├── .github ├── ISSUE_TEMPLATE │ ├── 🐞-bug-report.md │ └── 💡-feature-request.md └── workflows │ ├── publish.yml │ └── python-package.yml ├── .gitignore ├── .vscode └── settings.json ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── banner.png ├── benchmarks ├── README.md ├── _domain.py ├── create_db.bat ├── create_dbs.bat ├── create_dbs.sh ├── datasets │ └── 2wikimultihopqa.json ├── db │ └── .gitignore ├── evaluate_dbs.bat ├── evaluate_dbs.sh ├── graph_benchmark.py ├── lightrag_benchmark.py ├── nano_benchmark.py ├── questions │ ├── 2wikimultihopqa_101.json │ └── 2wikimultihopqa_51.json ├── results │ ├── graph │ │ ├── 2wikimultihopqa_101.json │ │ └── 2wikimultihopqa_51.json │ ├── lightrag │ │ ├── 2wikimultihopqa_101_local.json │ │ └── 2wikimultihopqa_51_local.json │ ├── nano │ │ ├── 2wikimultihopqa_101_local.json │ │ └── 2wikimultihopqa_51_local.json │ └── vdb │ │ ├── 2wikimultihopqa_101.json │ │ └── 2wikimultihopqa_51.json └── vdb_benchmark.py ├── demo.gif ├── examples ├── checkpointing.ipynb ├── custom_llm.py ├── gemini_example.py ├── gemini_vertexai_llm.py └── query_parameters.ipynb ├── fast_graphrag ├── __init__.py ├── _exceptions.py ├── _graphrag.py ├── _llm │ ├── __init__.py │ ├── _base.py │ ├── _default.py │ ├── _llm_genai.py │ ├── _llm_openai.py │ └── _llm_voyage.py ├── _models.py ├── _policies │ ├── __init__.py │ ├── _base.py │ ├── _graph_upsert.py │ └── _ranking.py ├── _prompt.py ├── _services │ ├── __init__.py │ ├── _base.py │ ├── _chunk_extraction.py │ ├── _information_extraction.py │ └── _state_manager.py ├── _storage │ ├── __init__.py │ ├── _base.py │ ├── _blob_pickle.py │ ├── _default.py │ ├── _gdb_igraph.py │ ├── _ikv_pickle.py │ ├── _namespace.py │ └── _vdb_hnswlib.py ├── _types.py └── _utils.py ├── mock_data.txt ├── poetry.lock ├── pyproject.toml └── tests ├── __init__.py ├── _graphrag_test.py ├── _llm ├── __init__.py ├── _base_test.py └── _llm_openai_test.py ├── _models_test.py ├── _policies ├── __init__.py ├── 
_graph_upsert_test.py └── _ranking_test.py ├── _services ├── __init__.py ├── _chunk_extraction_test.py └── _information_extraction_test.py ├── _storage ├── __init__.py ├── _base_test.py ├── _blob_pickle_test.py ├── _gdb_igraph_test.py ├── _ikv_pickle_test.py ├── _namespace_test.py └── _vdb_hnswlib_test.py ├── _types_test.py └── _utils_test.py /.github/ISSUE_TEMPLATE/🐞-bug-report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "\U0001F41E Bug report" 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Click on '....' 17 | 3. Scroll down to '....' 18 | 4. See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **Additional context** 27 | Add any other context about the problem here. 28 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/💡-feature-request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "\U0001F4A1 Feature request" 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 
18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish Package 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | workflow_dispatch: # TODO remove 8 | 9 | jobs: 10 | build-n-publish: 11 | name: Build and publish to PyPI 12 | runs-on: ubuntu-22.04 13 | environment: 14 | name: pypi 15 | url: https://pypi.org/p/fast-graphrag 16 | permissions: 17 | id-token: write 18 | steps: 19 | - uses: actions/checkout@master 20 | 21 | - name: Set up Python 3.11 22 | uses: actions/setup-python@v1 23 | with: 24 | python-version: 3.11 25 | 26 | - name: Install Poetry 27 | run: pipx install poetry==1.8.* 28 | 29 | - name: Cache Poetry virtual environment 30 | uses: actions/cache@v3 31 | with: 32 | path: ~/.cache/pypoetry 33 | key: ${{ runner.os }}-poetry-${{ hashFiles('**/poetry.lock') }} 34 | restore-keys: | 35 | ${{ runner.os }}-poetry- 36 | 37 | - name: Lock 38 | run: poetry lock 39 | - name: Build 40 | run: poetry build 41 | 42 | - name: pypi-publish 43 | uses: pypa/gh-action-pypi-publish@v1.10.3 -------------------------------------------------------------------------------- /.github/workflows/python-package.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python 3 | 4 | name: Python package 5 | 6 | on: 7 | push: 8 | branches: [ "main" ] 9 | pull_request: 10 | branches: [ "main" ] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | strategy: 17 | fail-fast: false 18 | matrix: 19 | python-version: ["3.10", "3.11", "3.12"] 20 | 21 | steps: 22 | 
- uses: actions/checkout@v4 23 | - name: Set up Python ${{ matrix.python-version }} 24 | uses: actions/setup-python@v3 25 | with: 26 | python-version: ${{ matrix.python-version }} 27 | - name: Install dependencies 28 | run: | 29 | python -m pip install --upgrade pip 30 | python -m pip install ruff 31 | pipx install poetry 32 | poetry lock 33 | poetry install 34 | - name: Lint with ruff 35 | run: | 36 | # Stop the build if there are Python syntax errors or undefined names 37 | ruff check . --select E9,F63,F7,F82 --show-files 38 | 39 | # Check with the same settings as the dev environment 40 | ruff check . --select E,W,F,I,B,C4,N,D --ignore C901,W191,D401 --show-files 41 | 42 | # Treat all errors as warnings with max line length and complexity constraints 43 | ruff check . --exit-zero --line-length 127 --select C901 44 | - name: Test with unittest 45 | run: | 46 | poetry run python -m unittest discover -s tests/ -p "*_test.py" 47 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # celery beat schedule file 88 | celerybeat-schedule 89 | 90 | # SageMath parsed files 91 | *.sage.py 92 | 93 | # Environments 94 | .env 95 | .venv 96 | env/ 97 | venv/ 98 | ENV/ 99 | env.bak/ 100 | venv.bak/ 101 | 102 | # Spyder project settings 103 | .spyderproject 104 | .spyproject 105 | 106 | # Rope project settings 107 | .ropeproject 108 | 109 | # mkdocs documentation 110 | /site 111 | 112 | # mypy 113 | .mypy_cache/ 114 | .dmypy.json 115 | dmypy.json 116 | 117 | # Pyre type checker 118 | .pyre/ 119 | 120 | # pytype static type analyzer 121 | .pytype/ 122 | 123 | # Cython debug symbols 124 | cython_debug/ 125 | book_example/ 126 | book.txt 127 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.analysis.typeCheckingMode": "strict", 3 | "python.testing.unittestArgs": [ 4 | "-v", 5 | "-s", 6 | ".", 7 | "-p", 8 | "*_test.py" 9 | ], 10 | "python.testing.pytestEnabled": false, 11 | "python.testing.unittestEnabled": true, 12 | "python.autoComplete.extraPaths": [ 13 | "${workspaceFolder}" 
14 | ], 15 | "python.analysis.extraPaths": [ 16 | "${workspaceFolder}" 17 | ] 18 | } -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct - Fast GraphRAG 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behaviour that contributes to a positive environment for our 15 | community include: 16 | 17 | * Demonstrating empathy and kindness toward other people 18 | * Being respectful of differing opinions, viewpoints, and experiences 19 | * Giving and gracefully accepting constructive feedback 20 | * Accepting responsibility and apologising to those affected by our mistakes, 21 | and learning from the experience 22 | * Focusing on what is best not just for us as individuals, but for the 23 | overall community 24 | 25 | Examples of unacceptable behaviour include: 26 | 27 | * The use of sexualised language or imagery, and sexual attention or advances 28 | * Trolling, insulting or derogatory comments, and personal or political attacks 29 | * Public or private harassment 30 | * Publishing others' private information, such as a physical or email 31 | address, without their explicit permission 32 | * Other conduct which could reasonably be considered inappropriate in a 33 | professional setting 34 | 35 | ## Our Responsibilities 36 | 37 | Project maintainers are responsible for clarifying and enforcing our standards of 38 | acceptable 
behaviour and will take appropriate and fair corrective action in 39 | response to any instances of unacceptable behaviour. 40 | 41 | Project maintainers have the right and responsibility to remove, edit, or reject 42 | comments, commits, code, wiki edits, issues, and other contributions that are 43 | not aligned to this Code of Conduct, or to ban 44 | temporarily or permanently any contributor for other behaviours that they deem 45 | inappropriate, threatening, offensive, or harmful. 46 | 47 | ## Scope 48 | 49 | This Code of Conduct applies within all community spaces, and also applies when 50 | an individual is officially representing the community in public spaces. 51 | Examples of representing our community include using an official e-mail address, 52 | posting via an official social media account, or acting as an appointed 53 | representative at an online or offline event. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behaviour may be 58 | reported to the community leaders responsible for enforcement at . 59 | All complaints will be reviewed and investigated promptly and fairly. 60 | 61 | All community leaders are obligated to respect the privacy and security of the 62 | reporter of any incident. 63 | 64 | ## Attribution 65 | 66 | This Code of Conduct is adapted from the [Contributor Covenant](https://contributor-covenant.org/), version 67 | [1.4](https://www.contributor-covenant.org/version/1/4/code-of-conduct/code_of_conduct.md) and 68 | [2.0](https://www.contributor-covenant.org/version/2/0/code_of_conduct/code_of_conduct.md), 69 | and was generated by [contributing-gen](https://github.com/bttger/contributing-gen). 70 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | 2 | # Contributing to Fast GraphRAG 3 | 4 | First off, thanks for taking the time to contribute! 
❤️ 5 | 6 | All types of contributions are encouraged and valued. See the [Table of Contents](#table-of-contents) for different ways to help and details about how this project handles them. Please make sure to read the relevant section before making your contribution. It will make it a lot easier for us maintainers and smooth out the experience for all involved. The community looks forward to your contributions. 🎉 7 | 8 | > And if you like the project, but just don't have time to contribute, that's fine. There are other easy ways to support the project and show your appreciation, which we would also be very happy about: 9 | > - Star the project 10 | > - Tweet about it 11 | > - Refer this project in your project's readme 12 | > - Mention the project at local meetups and tell your friends/colleagues 13 | 14 | 15 | ## Table of Contents 16 | 17 | - [Code of Conduct](#code-of-conduct) 18 | - [I Have a Question](#i-have-a-question) 19 | - [I Want To Contribute](#i-want-to-contribute) 20 | - [Reporting Bugs](#reporting-bugs) 21 | - [Suggesting Enhancements](#suggesting-enhancements) 22 | - [Your First Code Contribution](#your-first-code-contribution) 23 | - [Improving The Documentation](#improving-the-documentation) 24 | - [Styleguides](#styleguides) 25 | - [Commit Messages](#commit-messages) 26 | - [Join The Project Team](#join-the-project-team) 27 | 28 | 29 | ## Code of Conduct 30 | 31 | This project and everyone participating in it is governed by the 32 | [Fast GraphRAG Code of Conduct](https://github.com/circlemind-ai/fast-graphrag/blob/main/CODE_OF_CONDUCT.md). 33 | By participating, you are expected to uphold this code. Please report unacceptable behavior 34 | to . 35 | 36 | 37 | ## I Have a Question 38 | 39 | First off, make sure to join the discord community: https://discord.gg/McpuSEkR 40 | 41 | Before you ask a question, it is best to search for existing [Issues](https://github.com/circlemind-ai/fast-graphrag/issues) that might help you. 
In case you have found a suitable issue and still need clarification, you can write your question in this issue. It is also advisable to search the internet for answers first. 42 | 43 | If you then still feel the need to ask a question and need clarification, we recommend the following: 44 | 45 | - Open an [Issue](https://github.com/circlemind-ai/fast-graphrag/issues/new). 46 | - Provide as much context as you can about what you're running into. 47 | - Provide project and platform versions (python, os, etc), depending on what seems relevant. 48 | 49 | We will then take care of the issue as soon as possible. 50 | 51 | ## I Want To Contribute 52 | 53 | > ### Legal Notice 54 | > When contributing to this project, you must agree that you have authored 100% of the content, that you have the necessary rights to the content and that the content you contribute may be provided under the project licence. 55 | 56 | ### Reporting Bugs 57 | 58 | 59 | #### Before Submitting a Bug Report 60 | 61 | A good bug report shouldn't leave others needing to chase you up for more information. Therefore, we ask you to investigate carefully, collect information and describe the issue in detail in your report. Please complete the following steps in advance to help us fix any potential bug as fast as possible. 62 | 63 | - Make sure that you are using the latest version. 64 | - Determine if your bug is really a bug and not an error on your side e.g. using incompatible environment components/versions. If you are looking for support, you might want to check Discord first. 65 | - To see if other users have experienced (and potentially already solved) the same issue you are having, check if there is not already a bug report existing for your bug or error in the [bug tracker](https://github.com/circlemind-ai/fast-graphrag/issues?q=label%3Abug). 66 | - Also make sure to search the internet (including Stack Overflow) to see if users outside of the GitHub community have discussed the issue. 
67 | - Collect all important information about the bug 68 | 69 | 70 | #### How Do I Submit a Good Bug Report? 71 | 72 | > You must never report security related issues, vulnerabilities or bugs including sensitive information to the issue tracker, or elsewhere in public. Instead sensitive bugs must be sent by email to security@circlemind.co 73 | 74 | We use GitHub issues to track bugs and errors. If you run into an issue with the project: 75 | 76 | - Open an [Issue](https://github.com/circlemind-ai/fast-graphrag/issues/new). (Since we can't be sure at this point whether it is a bug or not, we ask you not to talk about a bug yet and not to label the issue.) 77 | - Explain the behavior you would expect and the actual behavior. 78 | - Please provide as much context as possible and describe the *reproduction steps* that someone else can follow to recreate the issue on their own. This usually includes your code. For good bug reports you should isolate the problem and create a reduced test case. 79 | - Provide the information you collected in the previous section. 80 | 81 | Once it's filed: 82 | 83 | - The project team will label the issue accordingly. 84 | - A team member will try to reproduce the issue with your provided steps. If there are no reproduction steps or no obvious way to reproduce the issue, the team will ask you for those steps and mark the issue as `needs-repro`. Bugs with the `needs-repro` tag will not be addressed until they are reproduced. 85 | - If the team is able to reproduce the issue, it will be marked `needs-fix`, as well as possibly other tags (such as `critical`), and the issue will be left to be [implemented by someone](#your-first-code-contribution). 86 | 87 | 88 | 89 | 90 | ### Suggesting Enhancements 91 | 92 | This section guides you through submitting an enhancement suggestion for Fast GraphRAG, **including completely new features and minor improvements to existing functionality**. 
Following these guidelines will help maintainers and the community to understand your suggestion and find related suggestions. 93 | 94 | 95 | #### Before Submitting an Enhancement 96 | 97 | - Make sure that you are using the latest version. 98 | - Perform a [search](https://github.com/circlemind-ai/fast-graphrag/issues) to see if the enhancement has already been suggested. If it has, add a comment to the existing issue instead of opening a new one. 99 | - Find out whether your idea fits with the scope and aims of the project. It's up to you to make a strong case to convince the project's developers of the merits of this feature. Keep in mind that we want features that will be useful to the majority of our users and not just a small subset. If you're just targeting a minority of users, consider writing an add-on/plugin library. 100 | 101 | 102 | #### How Do I Submit a Good Enhancement Suggestion? 103 | 104 | Enhancement suggestions are tracked as [GitHub issues](https://github.com/circlemind-ai/fast-graphrag/issues). 105 | 106 | - Use a **clear and descriptive title** for the issue to identify the suggestion. 107 | - Provide a **step-by-step description of the suggested enhancement** in as many details as possible. 108 | - **Describe the current behavior** and **explain which behavior you expected to see instead** and why. At this point you can also tell which alternatives do not work for you. 109 | - **Explain why this enhancement would be useful** to most Fast GraphRAG users. You may also want to point out the other projects that solved it better and which could serve as inspiration. 
110 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 circlemind-ai 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | circlemind fast-graphrag 3 |

4 |

5 | 6 | fast-graphrag is released under the MIT license. 7 | 8 | 9 | PRs welcome! 10 | 11 | 12 | Circlemind Page 13 | 14 | 15 |

16 |

17 |

Streamlined and promptable Fast GraphRAG framework designed for interpretable, high-precision, agent-driven retrieval workflows.
Looking for a Managed Service? »

18 |

19 | 20 |

21 | Install | 22 | Quickstart | 23 | Community | 24 | Report Bug | 25 | Request Feature 26 |

27 | 28 | > [!NOTE] 29 | > Using *The Wizard of Oz*, `fast-graphrag` costs $0.08 vs. `graphrag` $0.48 — **a 6x cost saving** that further improves with data size and number of insertions. 30 | 31 | ## News (and Coming Soon) 32 | - [ ] Support for IDF weighting of entities 33 | - [x] Support for generic entities and concepts (initial commit) 34 | - [x] [2024.12.02] Benchmarks comparing Fast GraphRAG to LightRAG, GraphRAG and VectorDBs released [here](https://github.com/circlemind-ai/fast-graphrag/blob/main/benchmarks/README.md) 35 | 36 | ## Features 37 | 38 | - **Interpretable and Debuggable Knowledge:** Graphs offer a human-navigable view of knowledge that can be queried, visualized, and updated. 39 | - **Fast, Low-cost, and Efficient:** Designed to run at scale without heavy resource or cost requirements. 40 | - **Dynamic Data:** Automatically generate and refine graphs to best fit your domain and ontology needs. 41 | - **Incremental Updates:** Supports real-time updates as your data evolves. 42 | - **Intelligent Exploration:** Leverages PageRank-based graph exploration for enhanced accuracy and dependability. 43 | - **Asynchronous & Typed:** Fully asynchronous, with complete type support for robust and predictable workflows. 44 | 45 | Fast GraphRAG is built to fit seamlessly into your retrieval pipeline, giving you the power of advanced RAG, without the overhead of building and designing agentic workflows. 46 | 47 | ## Install 48 | 49 | **Install from source (recommended for best performance)** 50 | 51 | ```bash 52 | # clone this repo first 53 | cd fast_graphrag 54 | poetry install 55 | ``` 56 | 57 | **Install from PyPi (recommended for stability)** 58 | 59 | ```bash 60 | pip install fast-graphrag 61 | ``` 62 | 63 | ## Quickstart 64 | 65 | Set the OpenAI API key in the environment: 66 | 67 | ```bash 68 | export OPENAI_API_KEY="sk-..." 
69 | ``` 70 | 71 | Download a copy of *A Christmas Carol* by Charles Dickens: 72 | 73 | ```bash 74 | curl https://raw.githubusercontent.com/circlemind-ai/fast-graphrag/refs/heads/main/mock_data.txt > ./book.txt 75 | ``` 76 | 77 | Optional: Set the limit for concurrent requests to the LLM (i.e., to control the number of tasks processed simultaneously by the LLM, this is helpful when running local models) 78 | ```bash 79 | export CONCURRENT_TASK_LIMIT=8 80 | ``` 81 | 82 | Use the Python snippet below: 83 | 84 | ```python 85 | from fast_graphrag import GraphRAG 86 | 87 | DOMAIN = "Analyze this story and identify the characters. Focus on how they interact with each other, the locations they explore, and their relationships." 88 | 89 | EXAMPLE_QUERIES = [ 90 | "What is the significance of Christmas Eve in A Christmas Carol?", 91 | "How does the setting of Victorian London contribute to the story's themes?", 92 | "Describe the chain of events that leads to Scrooge's transformation.", 93 | "How does Dickens use the different spirits (Past, Present, and Future) to guide Scrooge?", 94 | "Why does Dickens choose to divide the story into \"staves\" rather than chapters?" 95 | ] 96 | 97 | ENTITY_TYPES = ["Character", "Animal", "Place", "Object", "Activity", "Event"] 98 | 99 | grag = GraphRAG( 100 | working_dir="./book_example", 101 | domain=DOMAIN, 102 | example_queries="\n".join(EXAMPLE_QUERIES), 103 | entity_types=ENTITY_TYPES 104 | ) 105 | 106 | with open("./book.txt") as f: 107 | grag.insert(f.read()) 108 | 109 | print(grag.query("Who is Scrooge?").response) 110 | ``` 111 | 112 | The next time you initialize fast-graphrag from the same working directory, it will retain all the knowledge automatically. 
113 | 114 | ## Examples 115 | Please refer to the `examples` folder for a list of tutorials on common use cases of the library: 116 | - `custom_llm.py`: a brief example on how to configure fast-graphrag to run with different OpenAI API compatible language models and embedders; 117 | - `checkpointing.ipynb`: a tutorial on how to use checkpoints to avoid irreversible data corruption; 118 | - `query_parameters.ipynb`: a tutorial on how to use the different query parameters. In particular, it shows how to include references to the used information in the provided answer (using the `with_references=True` parameter). 119 | 120 | ## Contributing 121 | 122 | Whether it's big or small, we love contributions. Contributions are what make the open-source community such an amazing place to learn, inspire, and create. Any contributions you make are greatly appreciated. Check out our [guide](https://github.com/circlemind-ai/fast-graphrag/blob/main/CONTRIBUTING.md) to see how to get started. 123 | 124 | Not sure where to get started? You can join our [Discord](https://discord.gg/DvY2B8u4sA) and ask us any questions there. 125 | 126 | ## Philosophy 127 | 128 | Our mission is to increase the number of successful GenAI applications in the world. To do that, we build memory and data tools that enable LLM apps to leverage highly specialized retrieval pipelines without the complexity of setting up and maintaining agentic workflows. 129 | 130 | Fast GraphRAG currently exploits the personalized pagerank algorithm to explore the graph and find the most relevant pieces of information to answer your query. For an overview on why this works, you can check out the HippoRAG paper [here](https://arxiv.org/abs/2405.14831). 131 | 132 | ## Open-source or Managed Service 133 | 134 | This repo is under the MIT License. See [LICENSE.txt](https://github.com/circlemind-ai/fast-graphrag/blob/main/LICENSE) for more information. 
135 | 136 | The fastest and most reliable way to get started with Fast GraphRAG is using our managed service. Your first 100 requests are free every month, after which you pay based on usage. 137 | 138 |

139 | circlemind fast-graphrag demo 140 |

141 | 142 | To learn more about our managed service, [book a demo](https://circlemind.co/demo) or see our [docs](https://docs.circlemind.co/quickstart). 143 | -------------------------------------------------------------------------------- /banner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/circlemind-ai/fast-graphrag/b370efe01ef836af292a3713d59b2ec23d2fe7c4/banner.png -------------------------------------------------------------------------------- /benchmarks/README.md: -------------------------------------------------------------------------------- 1 | ## Benchmarks 2 | We validate the benchmark results provided in [HippoRAG](https://arxiv.org/abs/2405.14831), as well as comparing with other methods: 3 | - NaiveRAG (vector dbs) using the OpenAI embedder `text-embedding-3-small` 4 | - [LightRAG](https://github.com/HKUDS/LightRAG) 5 | - [GraphRAG](https://github.com/gusye1234/nano-graphrag) (we use the implementation provided by `nano-graphrag`, based on the original [Microsoft GraphRAG](https://github.com/microsoft/graphrag)) 6 | 7 | ### Results 8 | **2wikimultihopQA** 9 | | # Queries | Method | All queries % | Multihop only % | 10 | |----------:|:--------:|--------------:|----------------:| 11 | | 51|||| 12 | | | VectorDB| 0.49| 0.32| 13 | | | LightRAG| 0.47| 0.32| 14 | | | GraphRAG| 0.75| 0.68| 15 | | |**Circlemind**| **0.96**| **0.95**| 16 | | 101|||| 17 | | | VectorDB| 0.42| 0.23| 18 | | | LightRAG| 0.45| 0.28| 19 | | | GraphRAG| 0.73| 0.64| 20 | | |**Circlemind**| **0.93**| **0.90**| 21 | 22 | **Circlemind is up to 4x more accurate than VectorDB RAG.** 23 | 24 | **HotpotQA** 25 | | # Queries | Method | All queries % | 26 | |----------:|:--------:|--------------:| 27 | | 101||| 28 | | | VectorDB| 0.78| 29 | | | LightRAG| 0.55| 30 | | | GraphRAG| -*| 31 | | |**Circlemind**| **0.84**| 32 | 33 | *: crashes after half an hour of processing 34 | 35 | Below, find the insertion times for the 
2wikimultihopqa benchmark (~800 chunks): 36 | | Method | Time (minutes) | 37 | |:--------:|-----------------:| 38 | | VectorDB| ~0.3| 39 | | LightRAG| ~25| 40 | | GraphRAG| ~40| 41 | |**Circlemind**| ~1.5| 42 | 43 | **Circlemind is 27x faster than GraphRAG while also being over 40% more accurate in retrieval.** 44 | 45 | ### Run it yourself 46 | The scripts in this directory will generate and evaluate the 2wikimultihopqa datasets on subsets of 51 and 101 queries with the same methodology as in the HippoRAG paper. In particular, we evaluate the retrieval capabilities of each method, measuring the percentage of queries for which all the required evidence was retrieved. We preloaded the results so it is enough to run `evaluate_dbs.xx` to get the numbers. You can also run `create_dbs.xx` to regenerate the databases for the different methods. 47 | 48 | A couple of NOTES: 49 | - you will need to set an OPENAI_API_KEY; 50 | - LightRAG and GraphRAG could take over an hour to process and they can be expensive; 51 | - when pip installing LightRAG, not all dependencies are added; to run it we simply deleted all the imports of each missing dependency (since we use OpenAI they are not necessary). 52 | - we also benchmarked on the HotpotQA dataset (we will soon release the code for that as well). 53 | 54 | The output will look similar to the following (the exact numbers could vary based on your graph configuration) 55 | ``` 56 | Evaluation of the performance of different RAG methods on 2wikimultihopqa (51 queries) 57 | 58 | VectorDB 59 | Loading dataset... 60 | [all questions] Percentage of queries with perfect retrieval: 0.49019607843137253 61 | [multihop only] Percentage of queries with perfect retrieval: 0.32432432432432434 62 | 63 | LightRAG [local mode] 64 | Loading dataset... 
[all questions] Percentage of queries with perfect retrieval: 0.47058823529411764
Your goal is to create an RDF (Resource Description Framework) graph from the given text. 5 | IMPORTANT: among other entities and relationships you find, make sure to extract as separate entities (to be connected with the main one) a person's 6 | role as a family member (such as 'son', 'uncle', 'wife', ...), their profession (such as 'director'), and the location 7 | where they live or work. Pay attention to the spelling of the names.""", # noqa: E501 8 | "hotpotqa": """Analyse the following passage and identify all the entities mentioned in it and their relationships. Your goal is to create an RDF (Resource Description Framework) graph from the given text. 9 | Pay attention to the spelling of the entity names.""" 10 | } 11 | 12 | QUERIES: Dict[str, List[str]] = { 13 | "2wikimultihopqa": [ 14 | "When did Prince Arthur's mother die?", 15 | "What is the place of birth of Elizabeth II's husband?", 16 | "Which film has the director died later, Interstellar or Harry Potter I?", 17 | "Where does the singer who wrote the song Blank Space work at?", 18 | ], 19 | "hotpotqa": [ 20 | "Are Christopher Nolan and Sathish Kalathil both film directors?", 21 | "What language were books being translated into during the era of Haymo of Faversham?", 22 | "Who directed the film that was shot in or around Leland, North Carolina in 1986?", 23 | "Who wrote a song after attending a luau in the Koolauloa District on the island of Oahu in Honolulu County?" 
24 | ] 25 | } 26 | 27 | ENTITY_TYPES: Dict[str, List[str]] = { 28 | "2wikimultihopqa": [ 29 | "person", 30 | "familiy_role", 31 | "location", 32 | "organization", 33 | "creative_work", 34 | "profession", 35 | ], 36 | "hotpotqa": [ 37 | "person", 38 | "familiy_role", 39 | "location", 40 | "organization", 41 | "creative_work", 42 | "profession", 43 | "event", 44 | "year" 45 | ], 46 | } 47 | -------------------------------------------------------------------------------- /benchmarks/create_db.bat: -------------------------------------------------------------------------------- 1 | :: 2wikimultihopqa benchmark 2 | :: Creating databases 3 | python graph_benchmark.py -n 51 -c 4 | python graph_benchmark.py -n 101 -c 5 | 6 | :: Evaluation (create reports) 7 | python graph_benchmark.py -n 51 -b 8 | python graph_benchmark.py -n 101 -b -------------------------------------------------------------------------------- /benchmarks/create_dbs.bat: -------------------------------------------------------------------------------- 1 | :: 2wikimultihopqa benchmark 2 | :: Creating databases 3 | python vdb_benchmark.py -n 51 -c 4 | python vdb_benchmark.py -n 101 -c 5 | python lightrag_benchmark.py -n 51 -c 6 | python lightrag_benchmark.py -n 101 -c 7 | python nano_benchmark.py -n 51 -c 8 | python nano_benchmark.py -n 101 -c 9 | python graph_benchmark.py -n 51 -c 10 | python graph_benchmark.py -n 101 -c 11 | 12 | :: Evaluation (create reports) 13 | python vdb_benchmark.py -n 51 -b 14 | python vdb_benchmark.py -n 101 -b 15 | python lightrag_benchmark.py -n 51 -b --mode=local 16 | python lightrag_benchmark.py -n 101 -b --mode=local 17 | python nano_benchmark.py -n 51 -b --mode=local 18 | python nano_benchmark.py -n 101 -b --mode=local 19 | python graph_benchmark.py -n 51 -b 20 | python graph_benchmark.py -n 101 -b -------------------------------------------------------------------------------- /benchmarks/create_dbs.sh: 
-------------------------------------------------------------------------------- 1 | # 2wikimultihopqa benchmark 2 | # Creating databases 3 | python vdb_benchmark.py -n 51 -c 4 | python vdb_benchmark.py -n 101 -c 5 | python lightrag_benchmark.py -n 51 -c 6 | python lightrag_benchmark.py -n 101 -c 7 | python nano_benchmark.py -n 51 -c 8 | python nano_benchmark.py -n 101 -c 9 | python graph_benchmark.py -n 51 -c 10 | python graph_benchmark.py -n 101 -c 11 | 12 | # Evaluation (create reports) 13 | python vdb_benchmark.py -n 51 -b 14 | python vdb_benchmark.py -n 101 -b 15 | python lightrag_benchmark.py -n 51 -b --mode=local 16 | python lightrag_benchmark.py -n 101 -b --mode=local 17 | python nano_benchmark.py -n 51 -b --mode=local 18 | python nano_benchmark.py -n 101 -b --mode=local 19 | python graph_benchmark.py -n 51 -b 20 | python graph_benchmark.py -n 101 -b -------------------------------------------------------------------------------- /benchmarks/db/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /benchmarks/evaluate_dbs.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | echo Evaluation of the performance of different RAG methods on 2wikimultihopqa (51 queries) 3 | echo. 4 | echo VectorDB 5 | python vdb_benchmark.py -n 51 -s 6 | echo. 7 | echo LightRAG 8 | python lightrag_benchmark.py -n 51 -s --mode=local 9 | echo. 10 | echo GraphRAG 11 | python nano_benchmark.py -n 51 -s --mode=local 12 | echo. 13 | echo Circlemind 14 | python graph_benchmark.py -n 51 -s 15 | 16 | echo. 17 | echo. 18 | echo Evaluation of the performance of different RAG methods on 2wikimultihopqa (101 queries) 19 | echo. 20 | echo VectorDB 21 | python vdb_benchmark.py -n 101 -s 22 | echo. 23 | echo LightRAG 24 | python lightrag_benchmark.py -n 101 -s --mode=local 25 | echo. 
26 | echo GraphRAG 27 | python nano_benchmark.py -n 101 -s --mode=local 28 | echo. 29 | echo Circlemind 30 | python graph_benchmark.py -n 101 -s 31 | -------------------------------------------------------------------------------- /benchmarks/evaluate_dbs.sh: -------------------------------------------------------------------------------- 1 | echo "Evaluation of the performance of different RAG methods on the 2wikimultihopqa (51 queries)"; 2 | echo; 3 | echo "VectorDB"; 4 | python vdb_benchmark.py -n 51 -s 5 | echo; 6 | echo "LightRAG [local mode]"; 7 | python lightrag_benchmark.py -n 51 -s --mode=local 8 | echo; 9 | echo "GraphRAG [local mode]"; 10 | python nano_benchmark.py -n 51 -s --mode=local 11 | echo; 12 | echo "Circlemind" 13 | python graph_benchmark.py -n 51 -s 14 | 15 | echo "Evaluation of the performance of different RAG methods on the 2wikimultihopqa (101 queries)"; 16 | echo; 17 | echo "VectorDB"; 18 | python vdb_benchmark.py -n 101 -s 19 | echo; 20 | echo "LightRAG [local mode]"; 21 | python lightrag_benchmark.py -n 101 -s --mode=local 22 | echo; 23 | echo "GraphRAG [local mode]"; 24 | python nano_benchmark.py -n 101 -s --mode=local 25 | echo; 26 | echo "Circlemind"; 27 | python graph_benchmark.py -n 101 -s -------------------------------------------------------------------------------- /benchmarks/graph_benchmark.py: -------------------------------------------------------------------------------- 1 | """Benchmarking script for GraphRAG.""" 2 | 3 | import argparse 4 | import asyncio 5 | import json 6 | from dataclasses import dataclass, field 7 | from typing import Any, Dict, List, Tuple 8 | 9 | import numpy as np 10 | import xxhash 11 | from _domain import DOMAIN, ENTITY_TYPES, QUERIES 12 | from dotenv import load_dotenv 13 | from tqdm import tqdm 14 | 15 | from fast_graphrag import GraphRAG, QueryParam 16 | from fast_graphrag._utils import get_event_loop 17 | 18 | 19 | @dataclass 20 | class Query: 21 | """Dataclass for a query.""" 22 | 23 | 
question: str = field() 24 | answer: str = field() 25 | evidence: List[Tuple[str, int]] = field() 26 | 27 | 28 | def load_dataset(dataset_name: str, subset: int = 0) -> Any: 29 | """Load a dataset from the datasets folder.""" 30 | with open(f"./datasets/{dataset_name}.json", "r") as f: 31 | dataset = json.load(f) 32 | 33 | if subset: 34 | return dataset[:subset] 35 | else: 36 | return dataset 37 | 38 | 39 | def get_corpus(dataset: Any, dataset_name: str) -> Dict[int, Tuple[int | str, str]]: 40 | """Get the corpus from the dataset.""" 41 | if dataset_name == "2wikimultihopqa" or dataset_name == "hotpotqa": 42 | passages: Dict[int, Tuple[int | str, str]] = {} 43 | 44 | for datapoint in dataset: 45 | context = datapoint["context"] 46 | 47 | for passage in context: 48 | title, text = passage 49 | title = title.encode("utf-8").decode() 50 | text = "\n".join(text).encode("utf-8").decode() 51 | hash_t = xxhash.xxh3_64_intdigest(text) 52 | if hash_t not in passages: 53 | passages[hash_t] = (title, text) 54 | 55 | return passages 56 | else: 57 | raise NotImplementedError(f"Dataset {dataset_name} not supported.") 58 | 59 | 60 | def get_queries(dataset: Any): 61 | """Get the queries from the dataset.""" 62 | queries: List[Query] = [] 63 | 64 | for datapoint in dataset: 65 | queries.append( 66 | Query( 67 | question=datapoint["question"].encode("utf-8").decode(), 68 | answer=datapoint["answer"], 69 | evidence=list(datapoint["supporting_facts"]), 70 | ) 71 | ) 72 | 73 | return queries 74 | 75 | 76 | if __name__ == "__main__": 77 | load_dotenv() 78 | 79 | parser = argparse.ArgumentParser(description="GraphRAG CLI") 80 | parser.add_argument("-d", "--dataset", default="2wikimultihopqa", help="Dataset to use.") 81 | parser.add_argument("-n", type=int, default=0, help="Subset of corpus to use.") 82 | parser.add_argument("-c", "--create", action="store_true", help="Create the graph for the given dataset.") 83 | parser.add_argument("-b", "--benchmark", action="store_true", 
help="Benchmark the graph for the given dataset.") 84 | parser.add_argument("-s", "--score", action="store_true", help="Report scores after benchmarking.") 85 | args = parser.parse_args() 86 | 87 | print("Loading dataset...") 88 | dataset = load_dataset(args.dataset, subset=args.n) 89 | working_dir = f"./db/graph/{args.dataset}_{args.n}" 90 | corpus = get_corpus(dataset, args.dataset) 91 | 92 | if args.create: 93 | print("Dataset loaded. Corpus:", len(corpus)) 94 | grag = GraphRAG( 95 | working_dir=working_dir, 96 | domain=DOMAIN[args.dataset], 97 | example_queries="\n".join(QUERIES), 98 | entity_types=ENTITY_TYPES[args.dataset], 99 | ) 100 | grag.insert( 101 | [f"{title}: {corpus}" for _, (title, corpus) in tuple(corpus.items())], 102 | metadata=[{"id": title} for title in tuple(corpus.keys())], 103 | ) 104 | if args.benchmark: 105 | queries = get_queries(dataset) 106 | print("Dataset loaded. Queries:", len(queries)) 107 | grag = GraphRAG( 108 | working_dir=working_dir, 109 | domain=DOMAIN[args.dataset], 110 | example_queries="\n".join(QUERIES), 111 | entity_types=ENTITY_TYPES[args.dataset], 112 | ) 113 | 114 | async def _query_task(query: Query) -> Dict[str, Any]: 115 | answer = await grag.async_query(query.question, QueryParam(only_context=True)) 116 | return { 117 | "question": query.question, 118 | "answer": answer.response, 119 | "evidence": [ 120 | corpus[chunk.metadata["id"]][0] 121 | if isinstance(chunk.metadata["id"], int) 122 | else chunk.metadata["id"] 123 | for chunk, _ in answer.context.chunks 124 | ], 125 | "ground_truth": [e[0] for e in query.evidence], 126 | } 127 | 128 | async def _run(): 129 | await grag.state_manager.query_start() 130 | answers = [ 131 | await a 132 | for a in tqdm(asyncio.as_completed([_query_task(query) for query in queries]), total=len(queries)) 133 | ] 134 | await grag.state_manager.query_done() 135 | return answers 136 | 137 | answers = get_event_loop().run_until_complete(_run()) 138 | 139 | with 
open(f"./results/graph/{args.dataset}_{args.n}.json", "w") as f: 140 | json.dump(answers, f, indent=4) 141 | 142 | if args.benchmark or args.score: 143 | with open(f"./results/graph/{args.dataset}_{args.n}.json", "r") as f: 144 | answers = json.load(f) 145 | 146 | try: 147 | with open(f"./questions/{args.dataset}_{args.n}.json", "r") as f: 148 | questions_multihop = json.load(f) 149 | except FileNotFoundError: 150 | questions_multihop = [] 151 | 152 | # Compute retrieval metrics 153 | retrieval_scores: List[float] = [] 154 | retrieval_scores_multihop: List[float] = [] 155 | 156 | for answer in answers: 157 | ground_truth = answer["ground_truth"] 158 | predicted_evidence = answer["evidence"] 159 | 160 | p_retrieved: float = len(set(ground_truth).intersection(set(predicted_evidence))) / len(set(ground_truth)) 161 | retrieval_scores.append(p_retrieved) 162 | 163 | if answer["question"] in questions_multihop: 164 | retrieval_scores_multihop.append(p_retrieved) 165 | 166 | print( 167 | f"Percentage of queries with perfect retrieval: {np.mean([1 if s == 1.0 else 0 for s in retrieval_scores])}" 168 | ) 169 | if len(retrieval_scores_multihop): 170 | print( 171 | f"[multihop] Percentage of queries with perfect retrieval: { 172 | np.mean([1 if s == 1.0 else 0 for s in retrieval_scores_multihop]) 173 | }" 174 | ) 175 | -------------------------------------------------------------------------------- /benchmarks/lightrag_benchmark.py: -------------------------------------------------------------------------------- 1 | """Benchmarking script for GraphRAG.""" 2 | 3 | import argparse 4 | import asyncio 5 | import json 6 | import os 7 | import re 8 | from dataclasses import dataclass, field 9 | from typing import Any, Dict, List, Tuple 10 | 11 | import numpy as np 12 | import xxhash 13 | from dotenv import load_dotenv 14 | from lightrag import LightRAG, QueryParam 15 | from lightrag.lightrag import always_get_an_event_loop 16 | from lightrag.llm import gpt_4o_mini_complete 17 | 
from lightrag.utils import logging 18 | from tqdm import tqdm 19 | 20 | logging.getLogger("httpx").setLevel(logging.WARNING) 21 | logging.getLogger("nano-vectordb").setLevel(logging.WARNING) 22 | 23 | @dataclass 24 | class Query: 25 | """Dataclass for a query.""" 26 | 27 | question: str = field() 28 | answer: str = field() 29 | evidence: List[Tuple[str, int]] = field() 30 | 31 | 32 | def load_dataset(dataset_name: str, subset: int = 0) -> Any: 33 | """Load a dataset from the datasets folder.""" 34 | with open(f"./datasets/{dataset_name}.json", "r") as f: 35 | dataset = json.load(f) 36 | 37 | if subset: 38 | return dataset[:subset] 39 | else: 40 | return dataset 41 | 42 | 43 | def get_corpus(dataset: Any, dataset_name: str) -> Dict[int, Tuple[int | str, str]]: 44 | """Get the corpus from the dataset.""" 45 | if dataset_name == "2wikimultihopqa" or dataset_name == "hotpotqa": 46 | passages: Dict[int, Tuple[int | str, str]] = {} 47 | 48 | for datapoint in dataset: 49 | context = datapoint["context"] 50 | 51 | for passage in context: 52 | title, text = passage 53 | title = title.encode("utf-8").decode() 54 | text = "\n".join(text).encode("utf-8").decode() 55 | hash_t = xxhash.xxh3_64_intdigest(text) 56 | if hash_t not in passages: 57 | passages[hash_t] = (title, text) 58 | 59 | return passages 60 | else: 61 | raise NotImplementedError(f"Dataset {dataset_name} not supported.") 62 | 63 | 64 | def get_queries(dataset: Any): 65 | """Get the queries from the dataset.""" 66 | queries: List[Query] = [] 67 | 68 | for datapoint in dataset: 69 | queries.append( 70 | Query( 71 | question=datapoint["question"].encode("utf-8").decode(), 72 | answer=datapoint["answer"], 73 | evidence=list(datapoint["supporting_facts"]), 74 | ) 75 | ) 76 | 77 | return queries 78 | 79 | 80 | if __name__ == "__main__": 81 | load_dotenv() 82 | 83 | parser = argparse.ArgumentParser(description="LightRAG CLI") 84 | parser.add_argument("-d", "--dataset", default="2wikimultihopqa", help="Dataset to use.") 
85 | parser.add_argument("-n", type=int, default=0, help="Subset of corpus to use.") 86 | parser.add_argument("-c", "--create", action="store_true", help="Create the graph for the given dataset.") 87 | parser.add_argument("-b", "--benchmark", action="store_true", help="Benchmark the graph for the given dataset.") 88 | parser.add_argument("-s", "--score", action="store_true", help="Report scores after benchmarking.") 89 | parser.add_argument("--mode", default="local", help="LightRAG query mode.") 90 | args = parser.parse_args() 91 | 92 | print("Loading dataset...") 93 | dataset = load_dataset(args.dataset, subset=args.n) 94 | working_dir = f"./db/lightrag/{args.dataset}_{args.n}" 95 | corpus = get_corpus(dataset, args.dataset) 96 | 97 | if not os.path.exists(working_dir): 98 | os.mkdir(working_dir) 99 | if args.create: 100 | print("Dataset loaded. Corpus:", len(corpus)) 101 | grag = LightRAG( 102 | working_dir=working_dir, 103 | llm_model_func=gpt_4o_mini_complete, 104 | log_level=logging.WARNING 105 | ) 106 | grag.insert([f"{title}: {corpus}" for _, (title, corpus) in tuple(corpus.items())]) 107 | if args.benchmark: 108 | queries = get_queries(dataset) 109 | print("Dataset loaded. 
Queries:", len(queries)) 110 | grag = LightRAG( 111 | working_dir=working_dir, 112 | llm_model_func=gpt_4o_mini_complete, 113 | log_level=logging.WARNING 114 | ) 115 | 116 | async def _query_task(query: Query, mode: str) -> Dict[str, Any]: 117 | answer = await grag.aquery( 118 | query.question, QueryParam(mode=mode, only_need_context=True, max_token_for_text_unit=9000) 119 | ) 120 | chunks = [ 121 | c.split(",")[1].split(":")[0].lstrip('"') 122 | for c in re.findall(r"\n-----Sources-----\n```csv\n(.*?)\n```", answer, re.DOTALL)[0].split("\r\n")[ 123 | 1:-1 124 | ] 125 | ] 126 | return { 127 | "question": query.question, 128 | "answer": "", 129 | "evidence": chunks[:8], 130 | "ground_truth": [e[0] for e in query.evidence], 131 | } 132 | 133 | async def _run(mode: str): 134 | answers = [ 135 | await a 136 | for a in tqdm( 137 | asyncio.as_completed([_query_task(query, mode=mode) for query in queries]), total=len(queries) 138 | ) 139 | ] 140 | return answers 141 | 142 | answers = always_get_an_event_loop().run_until_complete(_run(mode=args.mode)) 143 | 144 | with open(f"./results/lightrag/{args.dataset}_{args.n}_{args.mode}.json", "w") as f: 145 | json.dump(answers, f, indent=4) 146 | 147 | if args.benchmark or args.score: 148 | with open(f"./results/lightrag/{args.dataset}_{args.n}_{args.mode}.json", "r") as f: 149 | answers = json.load(f) 150 | 151 | try: 152 | with open(f"./questions/{args.dataset}_{args.n}.json", "r") as f: 153 | questions_multihop = json.load(f) 154 | except FileNotFoundError: 155 | questions_multihop = [] 156 | 157 | # Compute retrieval metrics 158 | retrieval_scores: List[float] = [] 159 | retrieval_scores_multihop: List[float] = [] 160 | 161 | for answer in answers: 162 | ground_truth = answer["ground_truth"] 163 | predicted_evidence = answer["evidence"] 164 | 165 | p_retrieved: float = len(set(ground_truth).intersection(set(predicted_evidence))) / len(set(ground_truth)) 166 | retrieval_scores.append(p_retrieved) 167 | 168 | if 
answer["question"] in questions_multihop: 169 | retrieval_scores_multihop.append(p_retrieved) 170 | 171 | print( 172 | f"Percentage of queries with perfect retrieval: {np.mean([1 if s == 1.0 else 0 for s in retrieval_scores])}" 173 | ) 174 | if len(retrieval_scores_multihop): 175 | print( 176 | f"[multihop] Percentage of queries with perfect retrieval: { 177 | np.mean([1 if s == 1.0 else 0 for s in retrieval_scores_multihop]) 178 | }" 179 | ) 180 | -------------------------------------------------------------------------------- /benchmarks/nano_benchmark.py: -------------------------------------------------------------------------------- 1 | """Benchmarking script for GraphRAG.""" 2 | 3 | import argparse 4 | import asyncio 5 | import json 6 | import os 7 | import re 8 | from dataclasses import dataclass, field 9 | from typing import Any, Dict, List, Tuple 10 | 11 | import numpy as np 12 | import xxhash 13 | from dotenv import load_dotenv 14 | from nano_graphrag import GraphRAG, QueryParam 15 | from nano_graphrag._llm import gpt_4o_mini_complete 16 | from nano_graphrag._utils import always_get_an_event_loop, logging 17 | from tqdm import tqdm 18 | 19 | logging.getLogger("nano-graphrag").setLevel(logging.WARNING) 20 | logging.getLogger("httpx").setLevel(logging.WARNING) 21 | logging.getLogger("nano-vectordb").setLevel(logging.WARNING) 22 | 23 | @dataclass 24 | class Query: 25 | """Dataclass for a query.""" 26 | 27 | question: str = field() 28 | answer: str = field() 29 | evidence: List[Tuple[str, int]] = field() 30 | 31 | 32 | def load_dataset(dataset_name: str, subset: int = 0) -> Any: 33 | """Load a dataset from the datasets folder.""" 34 | with open(f"./datasets/{dataset_name}.json", "r") as f: 35 | dataset = json.load(f) 36 | 37 | if subset: 38 | return dataset[:subset] 39 | else: 40 | return dataset 41 | 42 | 43 | def get_corpus(dataset: Any, dataset_name: str) -> Dict[int, Tuple[int | str, str]]: 44 | """Get the corpus from the dataset.""" 45 | if dataset_name 
== "2wikimultihopqa" or dataset_name == "hotpotqa": 46 | passages: Dict[int, Tuple[int | str, str]] = {} 47 | 48 | for datapoint in dataset: 49 | context = datapoint["context"] 50 | 51 | for passage in context: 52 | title, text = passage 53 | title = title.encode("utf-8").decode() 54 | text = "\n".join(text).encode("utf-8").decode() 55 | hash_t = xxhash.xxh3_64_intdigest(text) 56 | if hash_t not in passages: 57 | passages[hash_t] = (title, text) 58 | 59 | return passages 60 | else: 61 | raise NotImplementedError(f"Dataset {dataset_name} not supported.") 62 | 63 | 64 | def get_queries(dataset: Any): 65 | """Get the queries from the dataset.""" 66 | queries: List[Query] = [] 67 | 68 | for datapoint in dataset: 69 | queries.append( 70 | Query( 71 | question=datapoint["question"].encode("utf-8").decode(), 72 | answer=datapoint["answer"], 73 | evidence=list(datapoint["supporting_facts"]), 74 | ) 75 | ) 76 | 77 | return queries 78 | 79 | 80 | if __name__ == "__main__": 81 | load_dotenv() 82 | 83 | parser = argparse.ArgumentParser(description="LightRAG CLI") 84 | parser.add_argument("-d", "--dataset", default="2wikimultihopqa", help="Dataset to use.") 85 | parser.add_argument("-n", type=int, default=0, help="Subset of corpus to use.") 86 | parser.add_argument("-c", "--create", action="store_true", help="Create the graph for the given dataset.") 87 | parser.add_argument("-b", "--benchmark", action="store_true", help="Benchmark the graph for the given dataset.") 88 | parser.add_argument("-s", "--score", action="store_true", help="Report scores after benchmarking.") 89 | parser.add_argument("--mode", default="local", help="LightRAG query mode.") 90 | args = parser.parse_args() 91 | 92 | print("Loading dataset...") 93 | dataset = load_dataset(args.dataset, subset=args.n) 94 | working_dir = f"./db/nano/{args.dataset}_{args.n}" 95 | corpus = get_corpus(dataset, args.dataset) 96 | 97 | if not os.path.exists(working_dir): 98 | os.mkdir(working_dir) 99 | if args.create: 100 | 
print("Dataset loaded. Corpus:", len(corpus)) 101 | grag = GraphRAG( 102 | working_dir=working_dir, 103 | best_model_func=gpt_4o_mini_complete 104 | ) 105 | grag.insert([f"{title}: {corpus}" for _, (title, corpus) in tuple(corpus.items())]) 106 | if args.benchmark: 107 | queries = get_queries(dataset) 108 | print("Dataset loaded. Queries:", len(queries)) 109 | grag = GraphRAG( 110 | working_dir=working_dir, 111 | best_model_func=gpt_4o_mini_complete 112 | ) 113 | 114 | async def _query_task(query: Query, mode: str) -> Dict[str, Any]: 115 | answer = await grag.aquery( 116 | query.question, QueryParam(mode=mode, only_need_context=True, local_max_token_for_text_unit=9000) 117 | ) 118 | chunks = [] 119 | for c in re.findall(r"\n-----Sources-----\n```csv\n(.*?)\n```", answer, re.DOTALL)[0].split("\n")[ 120 | 1:-1 121 | ]: 122 | try: 123 | chunks.append(c.split(",\t")[1].split(":")[0].lstrip('"')) 124 | except IndexError: 125 | pass 126 | return { 127 | "question": query.question, 128 | "answer": "", 129 | "evidence": chunks[:8], 130 | "ground_truth": [e[0] for e in query.evidence], 131 | } 132 | 133 | async def _run(mode: str): 134 | answers = [ 135 | await a 136 | for a in tqdm( 137 | asyncio.as_completed([_query_task(query, mode=mode) for query in queries]), total=len(queries) 138 | ) 139 | ] 140 | return answers 141 | 142 | answers = always_get_an_event_loop().run_until_complete(_run(mode=args.mode)) 143 | 144 | with open(f"./results/nano/{args.dataset}_{args.n}_{args.mode}.json", "w") as f: 145 | json.dump(answers, f, indent=4) 146 | 147 | if args.benchmark or args.score: 148 | with open(f"./results/nano/{args.dataset}_{args.n}_{args.mode}.json", "r") as f: 149 | answers = json.load(f) 150 | 151 | try: 152 | with open(f"./questions/{args.dataset}_{args.n}.json", "r") as f: 153 | questions_multihop = json.load(f) 154 | except FileNotFoundError: 155 | questions_multihop = [] 156 | 157 | # Compute retrieval metrics 158 | retrieval_scores: List[float] = [] 159 | 
retrieval_scores_multihop: List[float] = [] 160 | 161 | for answer in answers: 162 | ground_truth = answer["ground_truth"] 163 | predicted_evidence = answer["evidence"] 164 | 165 | p_retrieved: float = len(set(ground_truth).intersection(set(predicted_evidence))) / len(set(ground_truth)) 166 | retrieval_scores.append(p_retrieved) 167 | 168 | if answer["question"] in questions_multihop: 169 | retrieval_scores_multihop.append(p_retrieved) 170 | 171 | print( 172 | f"Percentage of queries with perfect retrieval: {np.mean([1 if s == 1.0 else 0 for s in retrieval_scores])}" 173 | ) 174 | if len(retrieval_scores_multihop): 175 | print( 176 | f"[multihop] Percentage of queries with perfect retrieval: { 177 | np.mean([1 if s == 1.0 else 0 for s in retrieval_scores_multihop]) 178 | }" 179 | ) 180 | -------------------------------------------------------------------------------- /benchmarks/questions/2wikimultihopqa_101.json: -------------------------------------------------------------------------------- 1 | [ 2 | "When did Lothair Ii's mother die?", 3 | "What is the place of birth of the performer of song Changed It?", 4 | "Which film has the director who is older, God'S Gift To Women or Aldri Annet Enn Bråk?", 5 | "Which film whose director was born first, El Tonto or The Heart Of Doreon?", 6 | "Who is Raghnall Mac Ruaidhrí's paternal grandfather?", 7 | "Do both films Interview With A Hitman and The Last Coupon have the directors from the same country?", 8 | "What nationality is the director of film Blood Street?", 9 | "What is the place of birth of the director of film Gaby: A True Story?", 10 | "What nationality is the performer of song When The Stars Go Blue?", 11 | "Who is the child of the performer of song Me And Bobby Mcgee?", 12 | "Where was the place of death of Maurice, Prince Of Orange's father?", 13 | "Which country Aleksander Koniecpolski (1620–1659)'s father is from?", 14 | "What is the date of death of the director of film Madame La Presidente?", 15 | "Do both 
directors of films Wrong Turn 5: Bloodlines and Dark River (2017 Film) have the same nationality?", 16 | "Which film has the director who died first, The Goose Woman or You Can No Longer Remain Silent?", 17 | "Where was the director of film The Private Life Of Cinema born?", 18 | "Where did Coulson Wallop's father study?", 19 | "Why did John Middleton Murry's wife die?", 20 | "Where was the composer of film Billy Elliot born?", 21 | "What is the place of birth of Lisbeth Palme's husband?", 22 | "Who is the father-in-law of Sisowath Kossamak?", 23 | "Who is Mugain's mother-in-law?", 24 | "Where did Theodore Salisbury Woolsey's father study?", 25 | "What is the place of birth of the director of film The Return Of Swamp Thing?", 26 | "Where was the place of death of the performer of song I Can'T See Myself Leaving You?", 27 | "Which film has the director who was born later, Playing It Wild or I'Ll Be Going Now?", 28 | "What nationality is Beatrice I, Countess Of Burgundy's husband?", 29 | "Which film has the director born later, Christ Walking On The Water or 45 Fathers?", 30 | "Where was the performer of song Come Dance With Me (Song) born?", 31 | "Where does the director of film Talk About A Stranger work at?", 32 | "What is the place of birth of the composer of film Inherent Vice (Film)?", 33 | "Which film whose director is younger, Dangerously They Live or Salad By The Roots?", 34 | "Where did Sylvia Burka's husband die?", 35 | "Which film whose director is younger, Phalitamsha or Gladiators Seven?", 36 | "Which film has the director who was born later, Henry Goes Arizona or The Blue Collar Worker And The Hairdresser In A Whirl Of Sex And Politics?", 37 | "Which film has the director who was born first, Tombstone Rashomon or Waiting For The Clouds?", 38 | "Where was the performer of song B Boy (Song) detained?", 39 | "Which film has the director who was born later, Illusions (1982 Film) or It'S A Wonderful Afterlife?", 40 | "Where did the composer of film The 
Straw Hat die?", 41 | "Where does the director of film A Nest Of Noblemen work at?", 42 | "Who is the father-in-law of John Ernest, Duke Of Saxe-Eisenach?", 43 | "Who is the father of the director of film Palo Alto (2013 Film)?", 44 | "Who is the child of the director of film Los Pagares De Mendieta?", 45 | "Who is the spouse of the director of film My Three Merry Widows?", 46 | "Which country the composer of film Thunder On The Hill is from?", 47 | "Which film has the director born later, Arrête Ton Cinéma or Agni (2004 Film)?", 48 | "Where did Prince Ferdinand Of Bavaria's mother die?", 49 | "Who is the child of the director of film An Event?", 50 | "What nationality is the director of film Good People (Film)?", 51 | "What is the place of birth of the director of film Fortunella (Film)?", 52 | "Where was the husband of Joanna Elisabeth Of Holstein-Gottorp born?", 53 | "Where was the place of death of Abdul-Aziz Bin Muhammad's father?", 54 | "Which film has the director who died later, Love, Honor And Oh-Baby! or I Cover The Underworld?", 55 | "Where was the director of film The Circus Cyclone born?", 56 | "Where did the director of film Dancing In The Rain (Film) die?", 57 | "Where did Saw Thanda's husband die?", 58 | "Which film has the director who was born earlier, The Marriage Of Princess Demidoff or The Pocket-Knife?", 59 | "Which film has the director who is older than the other, Airheads or Return To Cabin By The Lake? ", 60 | "Which film has the director died later, Lost In The Stratosphere or Blind Man'S Eyes?", 61 | "What nationality is the performer of song Am I Wrong (Étienne De Crécy Song)?", 62 | "Who is the mother of the director of film Atomised (Film)?", 63 | "What is the date of birth of Henry I Of Ziębice's father?", 64 | "Where was the composer of song Back In The U.S.A. 
born?", 65 | "Which film has the director who was born later, The First Day Of Freedom or Malabimba – The Malicious Whore?", 66 | "Who is Sibyl Hathaway's child-in-law?", 67 | "Which film has the director born first, Sins Of Madeleine or Captain Apache?", 68 | "Who is Godomar Ii's stepmother?", 69 | "Where was the director of film The Outlaw Express born?", 70 | "Which film has the director who died later, 45 Calibre Echo or Bons Baisers De Hong Kong?", 71 | "What nationality is Lamprocles's father?", 72 | "Which country the performer of song I Like Control is from?", 73 | "Which film has the director born later, Romance On The Run or The Palace Of Angels?", 74 | "Which film has the director born first, Mord Em'Ly or Ek Hi Bhool (1940 Film)?", 75 | "Which film has the director born earlier, The Korean Wedding Chest or True To The Navy?", 76 | "Who is Marianus V Of Arborea's mother?", 77 | "Who is the paternal grandfather of Margaret Of Bavaria, Marchioness Of Mantua?" 78 | ] 79 | -------------------------------------------------------------------------------- /benchmarks/questions/2wikimultihopqa_51.json: -------------------------------------------------------------------------------- 1 | [ 2 | "When did Lothair Ii's mother die?", 3 | "What is the place of birth of the performer of song Changed It?", 4 | "Which film has the director who is older, God'S Gift To Women or Aldri Annet Enn Bråk?", 5 | "Which film whose director was born first, El Tonto or The Heart Of Doreon?", 6 | "Who is Raghnall Mac Ruaidhrí's paternal grandfather?", 7 | "Do both films Interview With A Hitman and The Last Coupon have the directors from the same country?", 8 | "What nationality is the director of film Blood Street?", 9 | "What is the place of birth of the director of film Gaby: A True Story?", 10 | "What nationality is the performer of song When The Stars Go Blue?", 11 | "Who is the child of the performer of song Me And Bobby Mcgee?", 12 | "Where was the place of death of Maurice, 
Prince Of Orange's father?", 13 | "Which country Aleksander Koniecpolski (1620–1659)'s father is from?", 14 | "What is the date of death of the director of film Madame La Presidente?", 15 | "Do both directors of films Wrong Turn 5: Bloodlines and Dark River (2017 Film) have the same nationality?", 16 | "Which film has the director who died first, The Goose Woman or You Can No Longer Remain Silent?", 17 | "Where was the director of film The Private Life Of Cinema born?", 18 | "Where did Coulson Wallop's father study?", 19 | "Why did John Middleton Murry's wife die?", 20 | "Where was the composer of film Billy Elliot born?", 21 | "What is the place of birth of Lisbeth Palme's husband?", 22 | "Who is the father-in-law of Sisowath Kossamak?", 23 | "Who is Mugain's mother-in-law?", 24 | "Where did Theodore Salisbury Woolsey's father study?", 25 | "What is the place of birth of the director of film The Return Of Swamp Thing?", 26 | "Where was the place of death of the performer of song I Can'T See Myself Leaving You?", 27 | "Which film has the director who was born later, Playing It Wild or I'Ll Be Going Now?", 28 | "What nationality is Beatrice I, Countess Of Burgundy's husband?", 29 | "Which film has the director born later, Christ Walking On The Water or 45 Fathers?", 30 | "Where was the performer of song Come Dance With Me (Song) born?", 31 | "Where does the director of film Talk About A Stranger work at?", 32 | "What is the place of birth of the composer of film Inherent Vice (Film)?", 33 | "Which film whose director is younger, Dangerously They Live or Salad By The Roots?", 34 | "Where did Sylvia Burka's husband die?", 35 | "Which film whose director is younger, Phalitamsha or Gladiators Seven?", 36 | "Which film has the director who was born later, Henry Goes Arizona or The Blue Collar Worker And The Hairdresser In A Whirl Of Sex And Politics?", 37 | "Which film has the director who was born first, Tombstone Rashomon or Waiting For The Clouds?", 38 | "Where was 
the performer of song B Boy (Song) detained?", 39 | "Which film has the director who was born later, Illusions (1982 Film) or It'S A Wonderful Afterlife?", 40 | "Where did the composer of film The Straw Hat die?", 41 | "Where does the director of film A Nest Of Noblemen work at?" 42 | ] -------------------------------------------------------------------------------- /demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/circlemind-ai/fast-graphrag/b370efe01ef836af292a3713d59b2ec23d2fe7c4/demo.gif -------------------------------------------------------------------------------- /examples/checkpointing.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Checkpointing Tutorial\n", 8 | "\n", 9 | "To properly function, fast-graphrag mantains a state synchronised among different types of databases. It is highly unlikely, but it can happend that during any reading/writing operation any of these storages can get corrupted. So, we are introducing checkpointing to signficiantly reduce the impact of this unpleasant situation. To enable checkpointing, simply set `n_checkpoints = k`, with `k > 0`:" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "from fast_graphrag import GraphRAG\n", 19 | "\n", 20 | "DOMAIN = \"Analyze this story and identify the characters. 
Focus on how they interact with each other, the locations they explore, and their relationships.\"\n", 21 | "\n", 22 | "EXAMPLE_QUERIES = [\n", 23 | "    \"What is the significance of Christmas Eve in A Christmas Carol?\",\n", 24 | "    \"How does the setting of Victorian London contribute to the story's themes?\",\n", 25 | "    \"Describe the chain of events that leads to Scrooge's transformation.\",\n", 26 | "    \"How does Dickens use the different spirits (Past, Present, and Future) to guide Scrooge?\",\n", 27 | "    \"Why does Dickens choose to divide the story into \\\"staves\\\" rather than chapters?\"\n", 28 | "]\n", 29 | "\n", 30 | "ENTITY_TYPES = [\"Character\", \"Animal\", \"Place\", \"Object\", \"Activity\", \"Event\"]\n", 31 | "\n", 32 | "grag = GraphRAG(\n", 33 | "    working_dir=\"./book_example\",\n", 34 | "    n_checkpoints=2,  # Number of checkpoints to keep\n", 35 | "    domain=DOMAIN,\n", 36 | "    example_queries=\"\\n\".join(EXAMPLE_QUERIES),\n", 37 | "    entity_types=ENTITY_TYPES\n", 38 | ")\n", 39 | "\n", 40 | "with open(\"./book.txt\") as f:\n", 41 | "    grag.insert(f.read())\n", 42 | "\n", 43 | "print(grag.query(\"Who is Scrooge?\").response)" 44 | ] 45 | }, 46 | { 47 | "cell_type": "markdown", 48 | "metadata": {}, 49 | "source": [ 50 | "DONE! Now the library will automatically keep in memory the `k` most recent checkpoints and roll back to them if necessary." 51 | ] 52 | }, 53 | { 54 | "cell_type": "markdown", 55 | "metadata": {}, 56 | "source": [ 57 | "NOTES:\n", 58 | "- if you want to migrate a project from no checkpoints to checkpoints, simply set the flag and run an insert operation (even an empty one should do the job). Check that the checkpoint was created successfully by querying the graph. If everything worked correctly, you should see a new directory in your storage working dir (in the case above, it would be something like `./book_example/1731555907`). 
You can now safely remove all the files in the root dir `./book_example/*.*`.\n", 59 | "- if you want to stop using checkpoints, simply copy all the files from the most recent checkpoints folder in the root dir, delete all the \"number\" folders and unset `n_checkpoints`." 60 | ] 61 | } 62 | ], 63 | "metadata": { 64 | "kernelspec": { 65 | "display_name": "cm", 66 | "language": "python", 67 | "name": "python3" 68 | }, 69 | "language_info": { 70 | "name": "python", 71 | "version": "3.12.7" 72 | } 73 | }, 74 | "nbformat": 4, 75 | "nbformat_minor": 2 76 | } 77 | -------------------------------------------------------------------------------- /examples/custom_llm.py: -------------------------------------------------------------------------------- 1 | """Example usage of GraphRAG with custom LLM and Embedding services compatible with the OpenAI API.""" 2 | 3 | from typing import List 4 | 5 | import instructor 6 | from dotenv import load_dotenv 7 | 8 | from fast_graphrag import GraphRAG 9 | from fast_graphrag._llm import OpenAIEmbeddingService, OpenAILLMService 10 | 11 | load_dotenv() 12 | 13 | DOMAIN = "" 14 | QUERIES: List[str] = [] 15 | ENTITY_TYPES: List[str] = [] 16 | 17 | working_dir = "./examples/ignore/hp" 18 | grag = GraphRAG( 19 | working_dir=working_dir, 20 | domain=DOMAIN, 21 | example_queries="\n".join(QUERIES), 22 | entity_types=ENTITY_TYPES, 23 | config=GraphRAG.Config( 24 | llm_service=OpenAILLMService( 25 | model="your-llm-model", 26 | base_url="llm.api.url.com", 27 | api_key="your-api-key", 28 | mode=instructor.Mode.JSON, 29 | api_version="your-llm-api_version", 30 | client="openai or azure" 31 | ), 32 | embedding_service=OpenAIEmbeddingService( 33 | model="your-embedding-model", 34 | base_url="emb.api.url.com", 35 | api_key="your-api-key", 36 | embedding_dim=512, # the output embedding dim of the chosen model 37 | api_version="your-llm-api_version", 38 | client="openai or azure" 39 | ), 40 | ), 41 | ) 42 | 
"""Example: running fast-graphrag against the Gemini API."""

import asyncio
import os

from fast_graphrag import GraphRAG, QueryParam
from fast_graphrag._llm import GeminiEmbeddingService, GeminiLLMService  # , VoyageAIEmbeddingService
from fast_graphrag._utils import logger

# Workspace directory where fast-graphrag persists its state.
WORKING_DIR = "./book_example"
# makedirs(exist_ok=True) replaces the racy exists()+mkdir pair and also
# creates intermediate directories if WORKING_DIR is ever made nested.
os.makedirs(WORKING_DIR, exist_ok=True)

# Read the Gemini API key from the environment; None if unset.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")

# Domain prompt steering entity/relation extraction toward story analysis.
DOMAIN = "Analyze this story and identify the characters. Focus on how they interact with each other, the locations they explore, and their relationships."
async def process_text(file_path: str) -> str:
    """Read a text file, decoding with a fallback chain of encodings.

    Args:
        file_path: Path to the text file to read.

    Returns:
        The file contents normalized to plain ASCII (characters outside the
        ASCII range are silently dropped) and stripped of surrounding
        whitespace.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
        UnicodeError: If no supported encoding can decode the file.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"Text file not found: {file_path}")

    try:
        # Try UTF-8 first, then progressively more permissive fallbacks.
        # NOTE: 'iso-8859-1' maps every possible byte to a character, so the
        # loop always succeeds by the third attempt; the UnicodeError below is
        # a defensive guard only.
        encodings = ("utf-8", "ascii", "iso-8859-1", "cp1252")
        text = None

        for encoding in encodings:
            try:
                with open(file_path, "r", encoding=encoding) as f:
                    text = f.read()
                break
            except UnicodeDecodeError:
                continue

        if text is None:
            raise UnicodeError("Failed to decode file with any supported encoding")

        # Normalize to ASCII: drops any non-ASCII characters (e.g. accented
        # letters) rather than transliterating them.
        text = text.encode("ascii", "ignore").decode("ascii")
        return text.strip()

    except Exception:
        # Log with traceback, then re-raise so the caller still sees the error.
        logger.exception("An error occurred:", exc_info=True)
        raise
query = input("\nYou: ").strip() 86 | if query.lower() == 'exit': 87 | print("\nExiting chat...") 88 | break 89 | 90 | print("Assistant: ", end='', flush=True) 91 | 92 | try: 93 | # Higher token limits for Gemini context windows 94 | response = await rag.async_query( 95 | query, 96 | params=QueryParam(with_references=False, only_context=False, entities_max_tokens=250000, relations_max_tokens=200000, chunks_max_tokens=500000) 97 | ) 98 | print(response.response) 99 | except Exception as e: 100 | import traceback 101 | traceback.print_exc() 102 | 103 | except KeyboardInterrupt: 104 | print("\nInterrupted by user") 105 | break 106 | except Exception as e: 107 | logger.exception("An error occurred:", exc_info=True) 108 | continue 109 | 110 | ### fast-graphrag example for Gemini 111 | async def main(): 112 | try: 113 | grag = GraphRAG( 114 | working_dir=WORKING_DIR, 115 | domain=DOMAIN, 116 | example_queries="\n".join(EXAMPLE_QUERIES), 117 | entity_types=ENTITY_TYPES, 118 | config=GraphRAG.Config( 119 | # Ensure necessary APIs have been enabled in Google Cloud, namely the Generative API 120 | # Supports Vertex usage via passing project_id and locatio, or using an 'Express' API key if available. 
More aggressive rate limiting will be required 121 | # Recommendation optionis using API keys from AI Studio, enabling Billing on the studio's project for much higher rate limits (2000 RPM for 2.0 Flash as of Feb 2025) 122 | llm_service = GeminiLLMService( 123 | model="gemini-2.0-flash", 124 | api_key=GEMINI_API_KEY, 125 | temperature=0.7, 126 | rate_limit_per_minute=True, 127 | rate_limit_per_second=True, 128 | max_requests_per_minute=1950, 129 | max_requests_per_second=500 130 | ), 131 | embedding_service=GeminiEmbeddingService( 132 | api_key=GEMINI_API_KEY, 133 | max_requests_concurrent=100, 134 | ), 135 | # for Voyage AI embeddings with higher rate limits than Vertex 136 | #embedding_service=VoyageAIEmbeddingService( 137 | # model="voyage-3-large", 138 | # api_key=VOYAGE_API_KEY, 139 | # embedding_dim=1024, # the output embedding dim of the chosen model 140 | #), 141 | ), 142 | ) 143 | 144 | file = await process_text("./book.txt") 145 | await grag.async_insert(file) 146 | 147 | await streaming_query_loop(grag) 148 | except Exception as e: 149 | logger.exception("An error occurred:", exc_info=True) 150 | 151 | 152 | if __name__ == "__main__": 153 | asyncio.run(main()) 154 | -------------------------------------------------------------------------------- /fast_graphrag/__init__.py: -------------------------------------------------------------------------------- 1 | """Top-level package for GraphRAG.""" 2 | 3 | __all__ = ["GraphRAG", "QueryParam"] 4 | 5 | from dataclasses import dataclass, field 6 | from typing import Type 7 | 8 | from fast_graphrag._llm import DefaultEmbeddingService, DefaultLLMService 9 | from fast_graphrag._llm._base import BaseEmbeddingService 10 | from fast_graphrag._llm._llm_openai import BaseLLMService 11 | from fast_graphrag._policies._base import BaseGraphUpsertPolicy 12 | from fast_graphrag._policies._graph_upsert import ( 13 | DefaultGraphUpsertPolicy, 14 | EdgeUpsertPolicy_UpsertIfValidNodes, 15 | 
@dataclass
class GraphRAG(BaseGraphRAG[TEmbedding, THash, TChunk, TEntity, TRelation, TId]):
    """A class representing a Graph-based Retrieval-Augmented Generation system."""

    @dataclass
    class Config:
        """Configuration for the GraphRAG class."""

        # Service classes (instantiated in GraphRAG.__post_init__, not here).
        chunking_service_cls: Type[BaseChunkingService[TChunk]] = field(default=DefaultChunkingService)
        information_extraction_service_cls: Type[BaseInformationExtractionService[TChunk, TEntity, TRelation, TId]] = (
            field(default=DefaultInformationExtractionService)
        )
        # Upsert policy handed to the information-extraction service: summarize
        # node descriptions, insert edges only when both endpoints are valid.
        information_extraction_upsert_policy: BaseGraphUpsertPolicy[TEntity, TRelation, TId] = field(
            default_factory=lambda: DefaultGraphUpsertPolicy(
                config=NodeUpsertPolicy_SummarizeDescription.Config(),
                nodes_upsert_cls=NodeUpsertPolicy_SummarizeDescription,
                edges_upsert_cls=EdgeUpsertPolicy_UpsertIfValidNodes,
            )
        )
        state_manager_cls: Type[BaseStateManagerService[TEntity, TRelation, THash, TChunk, TId, TEmbedding]] = field(
            default=DefaultStateManagerService
        )

        # LLM/embedding backends; default_factory ensures a fresh instance per
        # Config rather than one shared across configs.
        llm_service: BaseLLMService = field(default_factory=lambda: DefaultLLMService())
        embedding_service: BaseEmbeddingService = field(default_factory=lambda: DefaultEmbeddingService())

        # Storage backends for the entity graph, entity vectors, and chunks.
        graph_storage: BaseGraphStorage[TEntity, TRelation, TId] = field(
            default_factory=lambda: DefaultGraphStorage(DefaultGraphStorageConfig(node_cls=TEntity, edge_cls=TRelation))
        )
        entity_storage: DefaultVectorStorage[TIndex, TEmbedding] = field(
            default_factory=lambda: DefaultVectorStorage(
                DefaultVectorStorageConfig()
            )
        )
        chunk_storage: DefaultIndexedKeyValueStorage[THash, TChunk] = field(
            default_factory=lambda: DefaultIndexedKeyValueStorage(None)
        )

        # Ranking policies applied at query time: entities are kept above a
        # score threshold; relations and chunks are truncated to a top-k.
        entity_ranking_policy: RankingPolicy_WithThreshold = field(
            default_factory=lambda: RankingPolicy_WithThreshold(RankingPolicy_WithThreshold.Config(threshold=0.005))
        )
        relation_ranking_policy: RankingPolicy_TopK = field(
            default_factory=lambda: RankingPolicy_TopK(RankingPolicy_TopK.Config(top_k=64))
        )
        chunk_ranking_policy: RankingPolicy_TopK = field(
            default_factory=lambda: RankingPolicy_TopK(RankingPolicy_TopK.Config(top_k=8))
        )
        # Upsert policies used by the state manager when merging new data.
        node_upsert_policy: NodeUpsertPolicy_SummarizeDescription = field(
            default_factory=lambda: NodeUpsertPolicy_SummarizeDescription()
        )
        edge_upsert_policy: EdgeUpsertPolicy_UpsertValidAndMergeSimilarByLLM = field(
            default_factory=lambda: EdgeUpsertPolicy_UpsertValidAndMergeSimilarByLLM()
        )

        def __post_init__(self):
            """Initialize the GraphRAG Config class."""
            # Keep the vector store's dimensionality in sync with whatever
            # embedding service was configured.
            self.entity_storage.embedding_dim = self.embedding_service.embedding_dim

    # Full configuration; a fresh default Config is built per GraphRAG instance.
    config: Config = field(default_factory=Config)

    def __post_init__(self):
        """Initialize the GraphRAG class."""
        # Wire the configured services and storages into this instance.
        self.llm_service = self.config.llm_service
        self.embedding_service = self.config.embedding_service
        self.chunking_service = self.config.chunking_service_cls()
        self.information_extraction_service = self.config.information_extraction_service_cls(
            graph_upsert=self.config.information_extraction_upsert_policy
        )
        # The state manager owns all storages and policies; the workspace
        # handles on-disk layout and checkpoint retention (keep_n).
        self.state_manager = self.config.state_manager_cls(
            workspace=Workspace.new(self.working_dir, keep_n=self.n_checkpoints),
            embedding_service=self.embedding_service,
            graph_storage=self.config.graph_storage,
            entity_storage=self.config.entity_storage,
            chunk_storage=self.config.chunk_storage,
            entity_ranking_policy=self.config.entity_ranking_policy,
            relation_ranking_policy=self.config.relation_ranking_policy,
            chunk_ranking_policy=self.config.chunk_ranking_policy,
            node_upsert_policy=self.config.node_upsert_policy,
            edge_upsert_policy=self.config.edge_upsert_policy,
        )
-------------------------------------------------------------------------------- 1 | __all__ = [ 2 | "BaseLLMService", 3 | "BaseEmbeddingService", 4 | "DefaultEmbeddingService", 5 | "DefaultLLMService", 6 | "format_and_send_prompt", 7 | "OpenAIEmbeddingService", 8 | "OpenAILLMService", 9 | "GeminiLLMService", 10 | "GeminiEmbeddingService", 11 | "VoyageAIEmbeddingService" 12 | ] 13 | 14 | from ._base import BaseEmbeddingService, BaseLLMService, format_and_send_prompt 15 | from ._default import DefaultEmbeddingService, DefaultLLMService 16 | from ._llm_genai import GeminiEmbeddingService, GeminiLLMService 17 | from ._llm_openai import OpenAIEmbeddingService, OpenAILLMService 18 | from ._llm_voyage import VoyageAIEmbeddingService 19 | -------------------------------------------------------------------------------- /fast_graphrag/_llm/_base.py: -------------------------------------------------------------------------------- 1 | """LLM Services module.""" 2 | 3 | import os 4 | import re 5 | from dataclasses import dataclass, field 6 | from typing import Any, Optional, Tuple, Type, TypeVar, Union 7 | 8 | import numpy as np 9 | from pydantic import BaseModel 10 | 11 | from fast_graphrag._models import BaseModelAlias 12 | from fast_graphrag._prompt import PROMPTS 13 | 14 | T_model = TypeVar("T_model", bound=Union[BaseModel, BaseModelAlias]) 15 | TOKEN_PATTERN = re.compile(r"\w+|[^\w\s]", re.UNICODE) 16 | 17 | 18 | async def format_and_send_prompt( 19 | prompt_key: str, 20 | llm: "BaseLLMService", 21 | format_kwargs: dict[str, Any], 22 | response_model: Type[T_model], 23 | **args: Any, 24 | ) -> Tuple[T_model, list[dict[str, str]]]: 25 | """Get a prompt, format it with the supplied args, and send it to the LLM. 26 | 27 | If a system prompt is provided (i.e. 
@dataclass
class BaseLLMService:
    """Base class for Language Model implementations."""

    model: str = field()
    base_url: Optional[str] = field(default=None)
    api_key: Optional[str] = field(default=None)
    # Backend client; set by subclasses, excluded from __init__.
    llm_async_client: Any = field(init=False, default=None)
    # NOTE: env var is read once at class-definition time, not per instance.
    max_requests_concurrent: int = field(default=int(os.getenv("CONCURRENT_TASK_LIMIT", 1024)))
    max_requests_per_minute: int = field(default=500)
    max_requests_per_second: int = field(default=60)
    # Flags selecting which of the limits above are actually enforced.
    rate_limit_concurrency: bool = field(default=True)
    rate_limit_per_minute: bool = field(default=False)
    rate_limit_per_second: bool = field(default=False)

    def count_tokens(self, text: str) -> int:
        """Returns the number of tokens for a given text using the encoding appropriate for the model."""
        # Approximation: counts word and punctuation runs via TOKEN_PATTERN,
        # not the model's real tokenizer, so counts may differ from the API's.
        return len(TOKEN_PATTERN.findall(text))

    def is_within_token_limit(self, text: str, token_limit: int):
        """Lightweight check to determine if `text` fits within `token_limit` tokens.

        Returns the token count (an int) if it is less than or equal to the limit,
        otherwise returns False.
        """
        # NOTE: a zero token count is falsy to callers doing a bare truth test.
        token_count = self.count_tokens(text)
        return token_count if token_count <= token_limit else False

    async def send_message(
        self,
        prompt: str,
        system_prompt: str | None = None,
        history_messages: list[dict[str, str]] | None = None,
        response_model: Type[T_model] | None = None,
        **kwargs: Any,
    ) -> Tuple[T_model, list[dict[str, str]]]:
        """Send a message to the language model and receive a response.

        Args:
            prompt (str): The input message to send to the language model.
            model (str): The name of the model to use.
            system_prompt (str, optional): The system prompt to set the context for the conversation. Defaults to None.
            history_messages (list, optional): A list of previous messages in the conversation. Defaults to empty.
            response_model (Type[T], optional): The Pydantic model to parse the response. Defaults to None.
            **kwargs: Additional keyword arguments that may be required by specific LLM implementations.

        Returns:
            str: The response from the language model.
        """
        # Abstract: concrete providers (OpenAI, Gemini, ...) implement this.
        raise NotImplementedError
class NoopAsyncContextManager:
    """Async context manager that does nothing.

    Stands in for a real semaphore/limiter when a rate-limiting flag is
    disabled, so call sites can use ``async with`` unconditionally.
    """

    async def __aenter__(self):
        # Hand the manager back so `async with ... as x` still binds a value.
        return self

    async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any):
        # Returning None never suppresses an in-flight exception.
        return None
    def __post_init__(self):
        """Build the concurrency/rate limiters and the VoyageAI async client."""
        # Cap concurrent in-flight requests; a no-op stand-in is used whenever
        # the corresponding rate_limit_* flag is disabled.
        self.embedding_max_requests_concurrent = (
            asyncio.Semaphore(self.max_requests_concurrent) if self.rate_limit_concurrency else NoopAsyncContextManager()
        )
        # Per-minute limiter: at most max_requests_per_minute per 60 s window.
        self.embedding_per_minute_limiter = (
            AsyncLimiter(self.max_requests_per_minute, 60) if self.rate_limit_per_minute else NoopAsyncContextManager()
        )
        # Per-second limiter: at most max_requests_per_second per 1 s window.
        self.embedding_per_second_limiter = (
            AsyncLimiter(self.max_requests_per_second, 1) if self.rate_limit_per_second else NoopAsyncContextManager()
        )
        # Voyage SDK async client; the SDK retries transient failures up to 4x.
        self.embedding_async_client: client_async.AsyncClient = client_async.AsyncClient(
            api_key=self.api_key, max_retries=4
        )
        logger.debug("Initialized VoyageAIEmbeddingService.")
57 | """ 58 | logger.debug(f"Getting embedding for texts: {texts}") 59 | model = model or self.model 60 | if model is None: 61 | raise ValueError("Model name must be provided.") 62 | 63 | batched_texts = [ 64 | texts[i * self.max_elements_per_request : (i + 1) * self.max_elements_per_request] 65 | for i in range((len(texts) + self.max_elements_per_request - 1) // self.max_elements_per_request) 66 | ] 67 | response = await asyncio.gather(*[self._embedding_request(b, model) for b in batched_texts]) 68 | 69 | data = chain(*[r.embeddings for r in response]) 70 | embeddings = np.array(list(data)) 71 | logger.debug(f"Received embedding response: {len(embeddings)} embeddings") 72 | 73 | return embeddings 74 | except Exception: 75 | logger.exception("An error occurred:", exc_info=True) 76 | raise 77 | 78 | async def _embedding_request(self, input: List[str], model: str) -> EmbeddingsObject: 79 | async with self.embedding_max_requests_concurrent: 80 | async with self.embedding_per_minute_limiter: 81 | async with self.embedding_per_second_limiter: 82 | return await self.embedding_async_client.embed(model=model, texts=input, output_dimension=self.embedding_dim) 83 | -------------------------------------------------------------------------------- /fast_graphrag/_models.py: -------------------------------------------------------------------------------- 1 | from itertools import chain 2 | from typing import Any, Dict, Iterable, List, Optional 3 | 4 | from pydantic import BaseModel, Field, field_validator 5 | from pydantic._internal import _model_construction 6 | 7 | #################################################################################################### 8 | # LLM Models 9 | #################################################################################################### 10 | 11 | 12 | def _json_schema_slim(schema: dict[str, Any]) -> None: 13 | schema.pop("required") 14 | for prop in schema.get("properties", {}).values(): 15 | prop.pop("title", None) 16 | 17 | 18 
def dump_to_csv(
    data: Iterable[object],
    fields: List[str],
    separator: str = "\t",
    with_header: bool = False,
    **values: Dict[str, List[Any]],
) -> List[str]:
    """Render objects as separator-delimited rows.

    Each row contains the given attribute ``fields`` of one object, optionally
    followed by extra per-object columns supplied as keyword lists in ``values``
    (zipped positionally with ``data``).

    Args:
        data: Objects whose attributes are dumped, one row per object.
        fields: Attribute names to read from each object, in column order.
        separator: Column separator (tab by default).
        with_header: If True, prepend a header row of field and extra-column names.
        **values: Extra columns; each keyword maps to a list parallel to ``data``.

    Returns:
        The list of rendered rows (header first, when requested).
    """

    def _cell(raw: Any) -> str:
        # Newlines and tabs inside a value would break the row/column structure.
        return str(raw).replace("\n", " ").replace("\t", " ")

    rows: List[str] = []
    if with_header:
        rows.append(separator.join(chain(fields, values.keys())))

    for item, *extras in zip(data, *values.values()):
        columns = chain(
            (_cell(getattr(item, name)) for name in fields),
            (_cell(extra) for extra in extras),
        )
        rows.append(separator.join(columns))

    return rows
class TQueryEntities(BaseModel):
    """Entities extracted from a user query, split by specificity."""

    named: List[str] = Field(
        ...,
        description="List of named entities extracted from the query",
    )
    generic: List[str] = Field(
        ...,
        description="List of generic entities extracted from the query",
    )

    @field_validator("named", mode="before")
    @classmethod
    def uppercase_named(cls, value: List[str]):
        """Normalise named entities to upper case before validation."""
        if not value:
            return value
        return [entity.upper() for entity in value]

    # NOTE: generic entities are deliberately left in their original casing
    # (a matching upper-casing validator was considered and not enabled).
@dataclass
class BasePolicy:
    """Common base for all policies; carries an opaque, policy-specific config."""

    # Policy-specific configuration object; concrete policies declare their own
    # nested `Config` dataclass and read it via `self.config`.
    config: Any = field()


####################################################################################################
# GRAPH UPSERT POLICIES
####################################################################################################


@dataclass
class BaseNodeUpsertPolicy(BasePolicy, Generic[GTNode, GTId]):
    """Strategy interface for inserting/merging source nodes into a target graph storage."""

    async def __call__(
        self, llm: BaseLLMService, target: BaseGraphStorage[GTNode, GTEdge, GTId], source_nodes: Iterable[GTNode]
    ) -> Tuple[BaseGraphStorage[GTNode, GTEdge, GTId], Iterable[Tuple[TIndex, GTNode]]]:
        # Per the annotation, implementations return the (possibly updated) target
        # storage together with the (index, node) pairs that were upserted.
        raise NotImplementedError


@dataclass
class BaseEdgeUpsertPolicy(BasePolicy, Generic[GTEdge, GTId]):
    """Strategy interface for inserting/merging source edges into a target graph storage."""

    async def __call__(
        self, llm: BaseLLMService, target: BaseGraphStorage[GTNode, GTEdge, GTId], source_edges: Iterable[GTEdge]
    ) -> Tuple[BaseGraphStorage[GTNode, GTEdge, GTId], Iterable[Tuple[TIndex, GTEdge]]]:
        # Same contract as BaseNodeUpsertPolicy, but for edges.
        raise NotImplementedError
class RankingPolicy_WithThreshold(BaseRankingPolicy):  # noqa: N801
    """Keep only scores above a threshold, capped at `max_entities` entries."""

    @dataclass
    class Config:
        threshold: float = field(default=0.05)
        max_entities: int = field(default=128)

    config: Config = field()

    def __call__(self, scores: csr_matrix) -> csr_matrix:
        # Remove scores below threshold (explicit zeros are pruned below).
        scores.data[scores.data < self.config.threshold] = 0
        if scores.nnz >= self.config.max_entities:
            # argpartition selects everything except the max_entities largest values.
            smallest_indices = np.argpartition(scores.data, -self.config.max_entities)[:-self.config.max_entities]
            scores.data[smallest_indices] = 0
        scores.eliminate_zeros()

        return scores


class RankingPolicy_TopK(BaseRankingPolicy):  # noqa: N801
    """Keep only the `top_k` highest scores."""

    @dataclass
    class Config:
        top_k: int = field(default=10)

    # FIX: this was declared as `top_k: Config = field()`, but `__call__` reads
    # `self.config.top_k` and BasePolicy's generated __init__ stores the config
    # under `config`; renamed for consistency with the other policies.
    config: Config = field()

    def __call__(self, scores: csr_matrix) -> csr_matrix:
        assert scores.shape[0] == 1, "TopK policy only supports batch size of 1"
        if scores.nnz <= self.config.top_k:
            return scores

        # Zero out everything except the top_k largest values, then prune.
        smallest_indices = np.argpartition(scores.data, -self.config.top_k)[:-self.config.top_k]
        scores.data[smallest_indices] = 0
        scores.eliminate_zeros()

        return scores


class RankingPolicy_Elbow(BaseRankingPolicy):  # noqa: N801
    """Keep scores above the largest gap ("elbow") in the sorted score curve."""

    def __call__(self, scores: csr_matrix) -> csr_matrix:
        assert scores.shape[0] == 1, "Elbow policy only supports batch size of 1"
        if scores.nnz <= 1:
            return scores

        sorted_scores = np.sort(scores.data)

        # Compute elbow: index just after the largest jump between consecutive
        # (ascending) scores; everything below that jump is discarded.
        diff = np.diff(sorted_scores)
        elbow = np.argmax(diff) + 1

        smallest_indices = np.argpartition(scores.data, elbow)[:elbow]
        scores.data[smallest_indices] = 0
        scores.eliminate_zeros()

        return scores


class RankingPolicy_WithConfidence(BaseRankingPolicy):  # noqa: N801
    """Placeholder for a confidence-based ranking policy."""

    def __call__(self, scores: csr_matrix) -> csr_matrix:
        raise NotImplementedError("Confidence policy is not supported yet.")
@dataclass
class BaseChunkingService(Generic[GTChunk]):
    """Base class for chunk extractor."""

    def __post_init__(self):
        # FIX: this hook was misspelled `__post__init__` (extra underscore), a
        # name the dataclass machinery never calls; `__post_init__` is the hook
        # actually invoked after the generated __init__.
        pass

    async def extract(self, data: Iterable[TDocument]) -> Iterable[Iterable[GTChunk]]:
        """Extract unique chunks from the given data."""
        raise NotImplementedError
@dataclass
class BaseStateManagerService(Generic[GTNode, GTEdge, GTHash, GTChunk, GTId, GTEmbedding]):
    """A class for managing state operations."""

    # Optional filesystem workspace; None implies purely in-memory operation
    # (the storages below decide how to interpret a missing workspace).
    workspace: Optional[Workspace] = field()

    # Backing storages for the knowledge graph, entity embeddings and raw chunks.
    graph_storage: BaseGraphStorage[GTNode, GTEdge, GTId] = field()
    entity_storage: BaseVectorStorage[TIndex, GTEmbedding] = field()
    chunk_storage: BaseIndexedKeyValueStorage[GTHash, GTChunk] = field()

    # Service used to embed entities/queries for vector search.
    embedding_service: BaseEmbeddingService = field()

    # Policies deciding how new nodes/edges are merged into the graph.
    node_upsert_policy: BaseNodeUpsertPolicy[GTNode, GTId] = field()
    edge_upsert_policy: BaseEdgeUpsertPolicy[GTEdge, GTId] = field()

    # Ranking policies default to the identity policy (scores passed through).
    entity_ranking_policy: BaseRankingPolicy = field(default_factory=lambda: BaseRankingPolicy(None))
    relation_ranking_policy: BaseRankingPolicy = field(default_factory=lambda: BaseRankingPolicy(None))
    chunk_ranking_policy: BaseRankingPolicy = field(default_factory=lambda: BaseRankingPolicy(None))

    node_specificity: bool = field(default=False)

    # Storage class used for persisting sparse score matrices.
    blob_storage_cls: Type[BaseBlobStorage[csr_matrix]] = field(default=BaseBlobStorage)

    async def insert_start(self) -> None:
        """Prepare the storage for indexing before adding new data."""
        raise NotImplementedError

    async def insert_done(self) -> None:
        """Commit the storage operations after indexing."""
        raise NotImplementedError

    async def query_start(self) -> None:
        """Prepare the storage for querying before retrieving data."""
        raise NotImplementedError

    async def query_done(self) -> None:
        """Release/commit the storage operations after querying."""
        raise NotImplementedError

    async def filter_new_chunks(self, chunks_per_data: Iterable[Iterable[GTChunk]]) -> List[List[GTChunk]]:
        """Filter the chunks to check for duplicates.

        This method takes a sequence of chunks and returns a sequence of new chunks
        that are not already present in the storage. It uses a hashing mechanism to
        efficiently identify duplicates.

        Args:
            chunks_per_data (Iterable[Iterable[TChunk]]): A sequence of chunks to be filtered.

        Returns:
            Iterable[Iterable[TChunk]]: A sequence of chunks that are not in the storage.
        """
        raise NotImplementedError

    async def upsert(
        self,
        llm: BaseLLMService,
        subgraphs: List[asyncio.Future[Optional[BaseGraphStorage[GTNode, GTEdge, GTId]]]],
        documents: Iterable[Iterable[GTChunk]],
        show_progress: bool = True
    ) -> None:
        """Clean and upsert entities, relationships, and chunks into the storage."""
        raise NotImplementedError

    async def get_context(
        self, query: str, entities: Dict[str, List[str]]
    ) -> Optional[TContext[GTNode, GTEdge, GTHash, GTChunk]]:
        """Retrieve relevant state from the storage."""
        raise NotImplementedError

    async def get_num_entities(self) -> int:
        """Get the number of entities in the storage."""
        raise NotImplementedError

    async def get_num_relations(self) -> int:
        """Get the number of relations in the storage."""
        raise NotImplementedError

    async def get_num_chunks(self) -> int:
        """Get the number of chunks in the storage."""
        raise NotImplementedError

    async def save_graphml(self, output_path: str) -> None:
        """Save the graph in GraphML format."""
        raise NotImplementedError
@dataclass
class DefaultChunkingService(BaseChunkingService[TChunk]):
    """Default class for chunk extractor."""

    config: DefaultChunkingServiceConfig = field(default_factory=DefaultChunkingServiceConfig)

    def __post_init__(self):
        # The capturing group makes re.split keep the separators in the output,
        # alternating text/separator — _merge_splits relies on that ordering.
        self._split_re = re.compile(f"({'|'.join(re.escape(s) for s in self.config.separators or [])})")
        # Token budgets are converted to character budgets via a fixed ratio.
        self._chunk_size = self.config.chunk_token_size * TOKEN_TO_CHAR_RATIO
        self._chunk_overlap = self.config.chunk_token_overlap * TOKEN_TO_CHAR_RATIO

    async def extract(self, data: Iterable[TDocument]) -> Iterable[Iterable[TChunk]]:
        """Extract unique chunks from the given data.

        Deduplication is per document: chunks with an already-seen id (content
        hash) are dropped within the same document.
        """
        chunks_per_data: List[List[TChunk]] = []

        for d in data:
            unique_chunk_ids: Set[THash] = set()
            extracted_chunks = await self._extract_chunks(d)
            chunks: List[TChunk] = []
            for chunk in extracted_chunks:
                if chunk.id not in unique_chunk_ids:
                    unique_chunk_ids.add(chunk.id)
                    chunks.append(chunk)
            chunks_per_data.append(chunks)

        return chunks_per_data

    async def _extract_chunks(self, data: TDocument) -> List[TChunk]:
        """Split one document into hashed TChunk objects."""
        # Sanitise input data:
        try:
            # Round-trip through UTF-8, replacing unencodable code points.
            data.data = data.data.encode(errors="replace").decode()
        except UnicodeDecodeError:
            # Default to replacing all unrecognised characters with a space
            data.data = re.sub(r"[\x00-\x09\x11-\x12\x14-\x1f]", " ", data.data)

        if len(data.data) <= self._chunk_size:
            chunks = [data.data]
        else:
            chunks = self._split_text(data.data)

        return [
            TChunk(
                # xxh3 digest halved — presumably to keep the id within a signed
                # 64-bit range for THash; TODO confirm against THash's definition.
                id=THash(xxhash.xxh3_64_intdigest(chunk) // 2),
                content=chunk,
                metadata=data.metadata,
            )
            for chunk in chunks
        ]

    def _split_text(self, text: str) -> List[str]:
        # re.split with a capturing group yields [text, sep, text, sep, ...].
        return self._merge_splits(self._split_re.split(text))

    def _merge_splits(self, splits: List[str]) -> List[str]:
        """Greedily pack alternating text/separator pieces into size-bounded chunks."""
        if not splits:
            return []

        # Add empty string to the end to have a separator at the end of the last chunk
        splits.append("")

        merged_splits: List[List[Tuple[str, int]]] = []
        current_chunk: List[Tuple[str, int]] = []
        current_chunk_length: int = 0

        for i, split in enumerate(splits):
            split_length: int = len(split)
            # Ignore splitting if it's a separator (odd indices are separators,
            # which must stay attached to the preceding text). For text pieces,
            # reserve overlap room so the overlap prefix added later still fits.
            if (i % 2 == 1) or (
                current_chunk_length + split_length <= self._chunk_size - (self._chunk_overlap if i > 0 else 0)
            ):
                current_chunk.append((split, split_length))
                current_chunk_length += split_length
            else:
                merged_splits.append(current_chunk)
                current_chunk = [(split, split_length)]
                current_chunk_length = split_length

        merged_splits.append(current_chunk)

        if self._chunk_overlap > 0:
            return self._enforce_overlap(merged_splits)
        else:
            r = ["".join((c[0] for c in chunk)) for chunk in merged_splits]

        return r

    def _enforce_overlap(self, chunks: List[List[Tuple[str, int]]]) -> List[str]:
        """Prefix each chunk (except the first) with the tail of the previous one."""
        result: List[str] = []
        for i, chunk in enumerate(chunks):
            if i == 0:
                result.append("".join((c[0] for c in chunk)))
            else:
                # Compute overlap: walk the previous chunk backwards, collecting
                # pieces until the overlap budget would be exceeded.
                overlap_length: int = 0
                overlap: List[str] = []
                for text, length in reversed(chunks[i - 1]):
                    if overlap_length + length > self._chunk_overlap:
                        break
                    overlap_length += length
                    overlap.append(text)
                result.append("".join(chain(reversed(overlap), (c[0] for c in chunk))))
        return result
extract_entities_from_query( 42 | self, llm: BaseLLMService, query: str, prompt_kwargs: Dict[str, str] 43 | ) -> Dict[str, List[str]]: 44 | """Extract entities from the given query.""" 45 | prompt_kwargs["query"] = query 46 | entities, _ = await format_and_send_prompt( 47 | prompt_key="entity_extraction_query", 48 | llm=llm, 49 | format_kwargs=prompt_kwargs, 50 | response_model=TQueryEntities, 51 | ) 52 | 53 | return { 54 | "named": entities.named, 55 | "generic": entities.generic 56 | } 57 | 58 | async def _extract( 59 | self, llm: BaseLLMService, chunks: Iterable[TChunk], prompt_kwargs: Dict[str, str], entity_types: List[str] 60 | ) -> Optional[BaseGraphStorage[TEntity, TRelation, GTId]]: 61 | """Extract both entities and relationships from the given chunks.""" 62 | # Extract entities and relatioships from each chunk 63 | try: 64 | chunk_graphs = await asyncio.gather( 65 | *[self._extract_from_chunk(llm, chunk, prompt_kwargs, entity_types) for chunk in chunks] 66 | ) 67 | if len(chunk_graphs) == 0: 68 | return None 69 | 70 | # Combine chunk graphs in document graph 71 | return await self._merge(llm, chunk_graphs) 72 | except Exception as e: 73 | logger.error(f"Error during information extraction from document: {e}") 74 | return None 75 | 76 | async def _gleaning( 77 | self, llm: BaseLLMService, initial_graph: TGraph, history: list[dict[str, str]] 78 | ) -> Optional[TGraph]: 79 | """Do gleaning steps until the llm says we are done or we reach the max gleaning steps.""" 80 | # Prompts 81 | current_graph = initial_graph 82 | 83 | try: 84 | for gleaning_count in range(self.max_gleaning_steps): 85 | # Do gleaning step 86 | gleaning_result, history = await format_and_send_prompt( 87 | prompt_key="entity_relationship_continue_extraction", 88 | llm=llm, 89 | format_kwargs={}, 90 | response_model=TGraph, 91 | history_messages=history, 92 | ) 93 | 94 | # Combine new entities, relationships with previously obtained ones 95 | 
current_graph.entities.extend(gleaning_result.entities) 96 | current_graph.relationships.extend(gleaning_result.relationships) 97 | 98 | # Stop gleaning if we don't need to keep going 99 | if gleaning_count == self.max_gleaning_steps - 1: 100 | break 101 | 102 | # Ask llm if we are done extracting entities and relationships 103 | gleaning_status, _ = await format_and_send_prompt( 104 | prompt_key="entity_relationship_gleaning_done_extraction", 105 | llm=llm, 106 | format_kwargs={}, 107 | response_model=TGleaningStatus, 108 | history_messages=history, 109 | ) 110 | 111 | # If we are done parsing, stop gleaning 112 | if gleaning_status.status == Literal["done"]: 113 | break 114 | except Exception as e: 115 | logger.error(f"Error during gleaning: {e}") 116 | 117 | return None 118 | 119 | return current_graph 120 | 121 | async def _extract_from_chunk( 122 | self, llm: BaseLLMService, chunk: TChunk, prompt_kwargs: Dict[str, str], entity_types: List[str] 123 | ) -> TGraph: 124 | """Extract entities and relationships from the given chunk.""" 125 | prompt_kwargs["input_text"] = chunk.content 126 | 127 | chunk_graph, history = await format_and_send_prompt( 128 | prompt_key="entity_relationship_extraction", 129 | llm=llm, 130 | format_kwargs=prompt_kwargs, 131 | response_model=TGraph, 132 | ) 133 | 134 | # Do gleaning 135 | chunk_graph_with_gleaning = await self._gleaning(llm, chunk_graph, history) 136 | if chunk_graph_with_gleaning: 137 | chunk_graph = chunk_graph_with_gleaning 138 | 139 | _clean_entity_types = [re.sub("[ _]", "", entity_type).upper() for entity_type in entity_types] 140 | for entity in chunk_graph.entities: 141 | if re.sub("[ _]", "", entity.type).upper() not in _clean_entity_types: 142 | entity.type = "UNKNOWN" 143 | 144 | # Assign chunk ids to relationships 145 | for relationship in chunk_graph.relationships: 146 | relationship.chunks = [chunk.id] 147 | 148 | return chunk_graph 149 | 150 | async def _merge(self, llm: BaseLLMService, graphs: List[TGraph]) 
-> BaseGraphStorage[TEntity, TRelation, GTId]: 151 | """Merge the given graphs into a single graph storage.""" 152 | graph_storage = IGraphStorage[TEntity, TRelation, GTId](config=IGraphStorageConfig(TEntity, TRelation)) 153 | 154 | await graph_storage.insert_start() 155 | 156 | try: 157 | # This is synchronous since each sub graph is inserted into the graph storage and conflicts are resolved 158 | for graph in graphs: 159 | await self.graph_upsert(llm, graph_storage, graph.entities, graph.relationships) 160 | finally: 161 | await graph_storage.insert_done() 162 | 163 | return graph_storage 164 | -------------------------------------------------------------------------------- /fast_graphrag/_storage/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'Namespace', 3 | 'BaseBlobStorage', 4 | 'BaseIndexedKeyValueStorage', 5 | 'BaseVectorStorage', 6 | 'BaseGraphStorage', 7 | 'DefaultBlobStorage', 8 | 'DefaultIndexedKeyValueStorage', 9 | 'DefaultVectorStorage', 10 | 'DefaultGraphStorage', 11 | 'DefaultGraphStorageConfig', 12 | 'DefaultVectorStorageConfig', 13 | ] 14 | 15 | from ._base import BaseBlobStorage, BaseGraphStorage, BaseIndexedKeyValueStorage, BaseVectorStorage, Namespace 16 | from ._default import ( 17 | DefaultBlobStorage, 18 | DefaultGraphStorage, 19 | DefaultGraphStorageConfig, 20 | DefaultIndexedKeyValueStorage, 21 | DefaultVectorStorage, 22 | DefaultVectorStorageConfig, 23 | ) 24 | -------------------------------------------------------------------------------- /fast_graphrag/_storage/_blob_pickle.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from dataclasses import dataclass, field 3 | from typing import Optional 4 | 5 | from fast_graphrag._exceptions import InvalidStorageError 6 | from fast_graphrag._types import GTBlob 7 | from fast_graphrag._utils import logger 8 | 9 | from ._base import BaseBlobStorage 10 | 11 | 12 | @dataclass 
@dataclass
class PickleBlobStorage(BaseBlobStorage[GTBlob]):
    """Blob storage that persists a single Python object to a pickle file."""

    # File name used inside the workspace directory for this storage's data.
    RESOURCE_NAME = "blob_data.pkl"
    # In-memory blob; None until loaded or set.
    _data: Optional[GTBlob] = field(init=False, default=None)

    async def get(self) -> Optional[GTBlob]:
        """Return the current blob (None if nothing is loaded)."""
        return self._data

    async def set(self, blob: GTBlob) -> None:
        """Replace the current blob in memory (persisted on _insert_done)."""
        self._data = blob

    async def _insert_start(self):
        # Load existing data if a namespace provides a path; otherwise start empty.
        # NOTE(review): pickle.load on a workspace file — safe only if the
        # workspace contents are trusted.
        if self.namespace:
            data_file_name = self.namespace.get_load_path(self.RESOURCE_NAME)
            if data_file_name:
                try:
                    with open(data_file_name, "rb") as f:
                        self._data = pickle.load(f)
                except Exception as e:
                    t = f"Error loading data file for blob storage {data_file_name}: {e}"
                    logger.error(t)
                    raise InvalidStorageError(t) from e
            else:
                logger.info(f"No data file found for blob storage {data_file_name}. Loading empty storage.")
                self._data = None
        else:
            self._data = None
            logger.debug("Creating new volatile blob storage.")

    async def _insert_done(self):
        # Persist the blob; save errors are logged but not raised.
        if self.namespace:
            data_file_name = self.namespace.get_save_path(self.RESOURCE_NAME)
            try:
                with open(data_file_name, "wb") as f:
                    pickle.dump(self._data, f)
                logger.debug(
                    f"Saving blob storage '{data_file_name}'."
                )
            except Exception as e:
                logger.error(f"Error saving data file for blob storage {data_file_name}: {e}")

    async def _query_start(self):
        # Unlike inserts, querying requires a namespace to load from.
        assert self.namespace, "Loading a blob storage requires a namespace."

        data_file_name = self.namespace.get_load_path(self.RESOURCE_NAME)
        if data_file_name:
            try:
                with open(data_file_name, "rb") as f:
                    self._data = pickle.load(f)
            except Exception as e:
                t = f"Error loading data file for blob storage {data_file_name}: {e}"
                logger.error(t)
                raise InvalidStorageError(t) from e
        else:
            logger.warning(f"No data file found for blob storage {data_file_name}. Loading empty blob.")
            self._data = None

    async def _query_done(self):
        # Read-only access: nothing to flush.
        pass
from ._base import BaseIndexedKeyValueStorage


@dataclass
class PickleIndexedKeyValueStorage(BaseIndexedKeyValueStorage[GTKey, GTValue]):
  """Key-value storage with integer indices, persisted as a single pickle file.

  Keys map to stable integer indices (`_key_to_index`); values are stored by
  index in `_data`. Indices of deleted entries are recycled via
  `_free_indices`. `_np_keys` is a lazily built numpy cache of all keys used
  only by `mask_new`; any mutation of the key set invalidates it.
  """

  RESOURCE_NAME = "kv_data.pkl"
  # index -> value; None key in the type covers lookups of missing keys.
  _data: Dict[Union[None, TIndex], GTValue] = field(init=False, default_factory=dict)
  # key -> index
  _key_to_index: Dict[GTKey, TIndex] = field(init=False, default_factory=dict)
  # recycled indices from deleted entries
  _free_indices: List[TIndex] = field(init=False, default_factory=list)
  # lazy numpy cache of keys for mask_new; None means "needs rebuild"
  _np_keys: Optional[npt.NDArray[np.object_]] = field(init=False, default=None)

  async def size(self) -> int:
    return len(self._data)

  async def get(self, keys: Iterable[GTKey]) -> Iterable[Optional[GTValue]]:
    # Missing keys resolve to index None, which is absent from _data -> None.
    return (self._data.get(self._key_to_index.get(key, None), None) for key in keys)

  async def get_by_index(self, indices: Iterable[TIndex]) -> Iterable[Optional[GTValue]]:
    return (self._data.get(index, None) for index in indices)

  async def get_index(self, keys: Iterable[GTKey]) -> Iterable[Optional[TIndex]]:
    return (self._key_to_index.get(key, None) for key in keys)

  async def upsert(self, keys: Iterable[GTKey], values: Iterable[GTValue]) -> None:
    """Insert or overwrite values; new keys reuse freed indices when available."""
    for key, value in zip(keys, values):
      index = self._key_to_index.get(key, None)
      if index is None:
        if len(self._free_indices) > 0:
          index = self._free_indices.pop()
        else:
          index = TIndex(len(self._data))
        self._key_to_index[key] = index

        # Invalidate cache (key set changed)
        self._np_keys = None
      self._data[index] = value

  async def delete(self, keys: Iterable[GTKey]) -> None:
    """Remove keys, recycling their indices; unknown keys only log a warning."""
    for key in keys:
      index = self._key_to_index.pop(key, None)
      if index is not None:
        self._free_indices.append(index)
        self._data.pop(index, None)

        # Invalidate cache (key set changed)
        self._np_keys = None
      else:
        logger.warning(f"Key '{key}' not found in indexed key-value storage.")

  async def mask_new(self, keys: Iterable[GTKey]) -> Iterable[bool]:
    """Return a boolean mask: True where the key is NOT already stored."""
    keys = list(keys)

    if len(keys) == 0:
      return np.array([], dtype=bool)

    if self._np_keys is None:
      # NOTE(review): dtype=type(keys[0]) assumes homogeneous keys and that
      # their type maps to a sensible numpy dtype (e.g. str) — TODO confirm.
      self._np_keys = np.fromiter(
        self._key_to_index.keys(),
        count=len(self._key_to_index),
        dtype=type(keys[0]),
      )
    keys_array = np.array(keys, dtype=type(keys[0]))

    return ~np.isin(keys_array, self._np_keys)

  async def _insert_start(self):
    """Load state from disk before an insert session, or start empty."""
    if self.namespace:
      data_file_name = self.namespace.get_load_path(self.RESOURCE_NAME)

      if data_file_name:
        try:
          with open(data_file_name, "rb") as f:
            self._data, self._free_indices, self._key_to_index = pickle.load(f)
          logger.debug(
            f"Loaded {len(self._data)} elements from indexed key-value storage '{data_file_name}'."
          )
        except Exception as e:
          t = f"Error loading data file for key-vector storage '{data_file_name}': {e}"
          logger.error(t)
          raise InvalidStorageError(t) from e
      else:
        logger.info(f"No data file found for key-vector storage '{data_file_name}'. Loading empty storage.")
        self._data = {}
        self._free_indices = []
        self._key_to_index = {}
    else:
      self._data = {}
      self._free_indices = []
      self._key_to_index = {}
      logger.debug("Creating new volatile indexed key-value storage.")
    # Always rebuild the key cache lazily after (re)loading.
    self._np_keys = None

  async def _insert_done(self):
    """Persist (data, free indices, key index) as one pickle tuple.

    Raises:
      InvalidStorageError: if the pickle dump fails.
    """
    if self.namespace:
      data_file_name = self.namespace.get_save_path(self.RESOURCE_NAME)
      try:
        with open(data_file_name, "wb") as f:
          pickle.dump((self._data, self._free_indices, self._key_to_index), f)
        logger.debug(f"Saving {len(self._data)} elements to indexed key-value storage '{data_file_name}'.")
      except Exception as e:
        t = f"Error saving data file for key-vector storage '{data_file_name}': {e}"
        logger.error(t)
        raise InvalidStorageError(t) from e

  async def _query_start(self):
    """Load state from disk before a query session; a namespace is mandatory."""
    assert self.namespace, "Loading a kv storage requires a namespace."
    data_file_name = self.namespace.get_load_path(self.RESOURCE_NAME)
    if data_file_name:
      try:
        with open(data_file_name, "rb") as f:
          self._data, self._free_indices, self._key_to_index = pickle.load(f)
        logger.debug(
          f"Loaded {len(self._data)} elements from indexed key-value storage '{data_file_name}'."
        )
      except Exception as e:
        t = f"Error loading data file for key-vector storage {data_file_name}: {e}"
        logger.error(t)
        raise InvalidStorageError(t) from e
    else:
      logger.warning(f"No data file found for key-vector storage '{data_file_name}'. Loading empty storage.")
Loading empty storage.") 132 | self._data = {} 133 | self._free_indices = [] 134 | self._key_to_index = {} 135 | self._np_keys = None 136 | 137 | async def _query_done(self): 138 | pass 139 | -------------------------------------------------------------------------------- /fast_graphrag/_storage/_namespace.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import time 4 | from typing import Any, Callable, List, Optional 5 | 6 | from fast_graphrag._exceptions import InvalidStorageError 7 | from fast_graphrag._utils import logger 8 | 9 | 10 | class Workspace: 11 | @staticmethod 12 | def new(working_dir: str, checkpoint: int = 0, keep_n: int = 0) -> "Workspace": 13 | return Workspace(working_dir, checkpoint, keep_n) 14 | 15 | @staticmethod 16 | def get_path(working_dir: str, checkpoint: Optional[int] = None) -> Optional[str]: 17 | if checkpoint is None: 18 | return None 19 | elif checkpoint == 0: 20 | return working_dir 21 | return os.path.join(working_dir, str(checkpoint)) 22 | 23 | def __init__(self, working_dir: str, checkpoint: int = 0, keep_n: int = 0): 24 | self.working_dir: str = working_dir 25 | self.keep_n: int = keep_n 26 | if not os.path.exists(working_dir): 27 | os.makedirs(working_dir) 28 | 29 | self.checkpoints = sorted( 30 | (int(x.name) for x in os.scandir(self.working_dir) if x.is_dir() and not x.name.startswith("0__err_")), 31 | reverse=True, 32 | ) 33 | if self.checkpoints: 34 | self.current_load_checkpoint = checkpoint if checkpoint else self.checkpoints[0] 35 | else: 36 | self.current_load_checkpoint = checkpoint 37 | self.save_checkpoint: Optional[int] = None 38 | self.failed_checkpoints: List[str] = [] 39 | 40 | def __del__(self): 41 | for checkpoint in self.failed_checkpoints: 42 | old_path = os.path.join(self.working_dir, checkpoint) 43 | new_path = os.path.join(self.working_dir, f"0__err_{checkpoint}") 44 | os.rename(old_path, new_path) 45 | 46 | if self.keep_n > 0: 47 | 
checkpoints = sorted((x.name for x in os.scandir(self.working_dir) if x.is_dir()), reverse=True) 48 | for checkpoint in checkpoints[self.keep_n + 1 :]: 49 | shutil.rmtree(os.path.join(self.working_dir, str(checkpoint))) 50 | 51 | def make_for(self, namespace: str) -> "Namespace": 52 | return Namespace(self, namespace) 53 | 54 | def get_load_path(self) -> Optional[str]: 55 | load_path = self.get_path(self.working_dir, self.current_load_checkpoint) 56 | if load_path == self.working_dir and len([x for x in os.scandir(load_path) if x.is_file()]) == 0: 57 | return None 58 | return load_path 59 | 60 | 61 | def get_save_path(self) -> str: 62 | if self.save_checkpoint is None: 63 | if self.keep_n > 0: 64 | self.save_checkpoint = int(time.time()) 65 | else: 66 | self.save_checkpoint = 0 67 | save_path = self.get_path(self.working_dir, self.save_checkpoint) 68 | 69 | assert save_path is not None, "Save path cannot be None." 70 | 71 | if not os.path.exists(save_path): 72 | os.makedirs(save_path) 73 | return os.path.join(save_path) 74 | 75 | def _rollback(self) -> bool: 76 | if self.current_load_checkpoint is None: 77 | return False 78 | # List all directories in the working directory and select the one 79 | # with the smallest number greater then the current load checkpoint. 80 | try: 81 | self.current_load_checkpoint = next(x for x in self.checkpoints if x < self.current_load_checkpoint) 82 | logger.warning("Rolling back to checkpoint: %s", self.current_load_checkpoint) 83 | except (StopIteration, ValueError): 84 | self.current_load_checkpoint = None 85 | logger.warning("No checkpoints to rollback to. 
Last checkpoint tried: %s", self.current_load_checkpoint) 86 | 87 | return True 88 | 89 | async def with_checkpoints(self, fn: Callable[[], Any]) -> Any: 90 | while True: 91 | try: 92 | return await fn() 93 | except Exception as e: 94 | logger.warning("Error occurred loading checkpoint: %s", e) 95 | if self.current_load_checkpoint is not None: 96 | self.failed_checkpoints.append(str(self.current_load_checkpoint)) 97 | if self._rollback() is False: 98 | break 99 | raise InvalidStorageError("No valid checkpoints to load or default storages cannot be created.") 100 | 101 | 102 | class Namespace: 103 | def __init__(self, workspace: Workspace, namespace: Optional[str] = None): 104 | self.namespace = namespace 105 | self.workspace = workspace 106 | 107 | def get_load_path(self, resource_name: str) -> Optional[str]: 108 | assert self.namespace is not None, "Namespace must be set to get resource load path." 109 | load_path = self.workspace.get_load_path() 110 | if load_path is None: 111 | return None 112 | return os.path.join(load_path, f"{self.namespace}_{resource_name}") 113 | 114 | def get_save_path(self, resource_name: str) -> str: 115 | assert self.namespace is not None, "Namespace must be set to get resource save path." 
    return os.path.join(self.workspace.get_save_path(), f"{self.namespace}_{resource_name}")


# ---- fast_graphrag/_storage/_vdb_hnswlib.py ----
import pickle
from dataclasses import dataclass, field
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import hnswlib
import numpy as np
import numpy.typing as npt
from scipy.sparse import csr_matrix

from fast_graphrag._exceptions import InvalidStorageError
from fast_graphrag._types import GTEmbedding, GTId, TScore
from fast_graphrag._utils import logger

from ._base import BaseVectorStorage


@dataclass
class HNSWVectorStorageConfig:
  """hnswlib index tuning knobs (see hnswlib docs for semantics)."""

  ef_construction: int = field(default=128)  # build-time accuracy/speed trade-off
  M: int = field(default=64)  # graph connectivity
  ef_search: int = field(default=96)  # query-time accuracy/speed trade-off
  num_threads: int = field(default=-1)  # -1 = hnswlib default (all cores)


@dataclass
class HNSWVectorStorage(BaseVectorStorage[GTId, GTEmbedding]):
  """Vector storage backed by an hnswlib cosine index plus a pickled metadata dict.

  Distances from hnswlib's cosine space lie in [0, 2] (best to worst); they are
  mapped to similarity scores in [0, 1] via score = 1 - d/2.
  """

  RESOURCE_NAME = "hnsw_index_{}.bin"  # formatted with the embedding dim
  RESOURCE_METADATA_NAME = "hnsw_metadata.pkl"
  INITIAL_MAX_ELEMENTS = 128000
  config: HNSWVectorStorageConfig = field()  # type: ignore
  _index: Any = field(init=False, default=None)  # type: ignore
  _metadata: Dict[GTId, Dict[str, Any]] = field(default_factory=dict)

  @property
  def size(self) -> int:
    # Number of elements currently in the index.
    return self._index.get_current_count()

  @property
  def max_size(self) -> int:
    # Current index capacity (falls back to the initial capacity when 0/None).
    return self._index.get_max_elements() or self.INITIAL_MAX_ELEMENTS

  async def upsert(
    self,
    ids: Iterable[GTId],
    embeddings: Iterable[GTEmbedding],
    metadata: Union[Iterable[Dict[str, Any]], None] = None,
  ) -> None:
    """Add (or replace) vectors by id, growing the index capacity as needed."""
    ids = list(ids)
    embeddings = np.array(list(embeddings), dtype=np.float32)
    metadata = list(metadata) if metadata else None

    assert (len(ids) == len(embeddings)) and (
      metadata is None or (len(metadata) == len(ids))
    ), "ids, embeddings, and metadata (if provided) must have the same length"

    # Double the capacity until the batch fits.
    if self.size + len(embeddings) >= self.max_size:
      new_size = self.max_size * 2
      while self.size + len(embeddings) >= new_size:
        new_size *= 2
      self._index.resize_index(new_size)
      logger.info("Resizing HNSW index.")

    if metadata:
      self._metadata.update(dict(zip(ids, metadata)))
    self._index.add_items(data=embeddings, ids=ids, num_threads=self.config.num_threads)

  async def get_knn(
    self, embeddings: Iterable[GTEmbedding], top_k: int
  ) -> Tuple[Iterable[Iterable[GTId]], npt.NDArray[TScore]]:
    """Return the top_k neighbour ids and similarity scores per query embedding."""
    if self.size == 0:
      empty_list: List[List[GTId]] = []
      logger.info("Querying knns in empty index.")
      return empty_list, np.array([], dtype=TScore)

    top_k = min(top_k, self.size)

    # hnswlib requires ef >= k for correct recall.
    if top_k > self.config.ef_search:
      self._index.set_ef(top_k)

    # distances is [0, 2] (best, worst)
    ids, distances = self._index.knn_query(data=embeddings, k=top_k, num_threads=self.config.num_threads)

    return ids, 1.0 - np.array(distances, dtype=TScore) * 0.5

  async def score_all(
    self, embeddings: Iterable[GTEmbedding], top_k: int = 1, threshold: Optional[float] = None
  ) -> csr_matrix:
    """Score queries against the index as a sparse (#queries, #elements) matrix.

    Only each query's top_k hits are populated; scores below `threshold` are
    zeroed (and thus dropped from the sparse structure).
    """
    if not isinstance(embeddings, np.ndarray):
      embeddings = np.array(list(embeddings), dtype=np.float32)

    if embeddings.size == 0 or self.size == 0:
      logger.warning(f"No provided embeddings ({embeddings.size}) or empty index ({self.size}).")
      return csr_matrix((0, self.size))

    top_k = min(top_k, self.size)
    if top_k > self.config.ef_search:
      self._index.set_ef(top_k)

    # distances is [0, 2] (best, worst)
    ids, distances = self._index.knn_query(data=embeddings, k=top_k, num_threads=self.config.num_threads)

    ids = np.array(ids)
    scores = 1.0 - np.array(distances, dtype=TScore) * 0.5

    if threshold is not None:
      scores[scores < threshold] = 0

    # Create sparse distance matrix with shape (#embeddings, #all_embeddings)
    flattened_ids = ids.ravel()
    flattened_scores = scores.ravel()

    scores = csr_matrix(
      (flattened_scores, (np.repeat(np.arange(len(ids)), top_k), flattened_ids)),
      shape=(len(ids), self.size),
    )

    return scores

  async def _insert_start(self):
    """Load index + metadata before an insert session, or init a fresh index."""
    self._index = hnswlib.Index(space="cosine", dim=self.embedding_dim)  # type: ignore

    if self.namespace:
      index_file_name = self.namespace.get_load_path(self.RESOURCE_NAME.format(self.embedding_dim))
      metadata_file_name = self.namespace.get_load_path(self.RESOURCE_METADATA_NAME)

      if index_file_name and metadata_file_name:
        try:
          self._index.load_index(index_file_name, allow_replace_deleted=True)
          with open(metadata_file_name, "rb") as f:
            self._metadata = pickle.load(f)
          logger.debug(
            f"Loaded {self.size} elements from vectordb storage '{index_file_name}'."
          )
          return  # All good
        except Exception as e:
          t = f"Error loading metadata file for vectordb storage '{metadata_file_name}': {e}"
          logger.error(t)
          raise InvalidStorageError(t) from e
      else:
        logger.info(f"No data file found for vectordb storage '{index_file_name}'. Loading empty vectordb.")
    else:
      logger.debug("Creating new volatile vectordb storage.")
    # Fresh index path (no persisted files or no namespace).
    self._index.init_index(
      max_elements=self.INITIAL_MAX_ELEMENTS,
      ef_construction=self.config.ef_construction,
      M=self.config.M,
      allow_replace_deleted=True
    )
    self._index.set_ef(self.config.ef_search)
    self._metadata = {}

  async def _insert_done(self):
    """Persist index + metadata after an insert session.

    Raises:
      InvalidStorageError: if saving either file fails.
    """
    if self.namespace:
      index_file_name = self.namespace.get_save_path(self.RESOURCE_NAME.format(self.embedding_dim))
      metadata_file_name = self.namespace.get_save_path(self.RESOURCE_METADATA_NAME)

      try:
        self._index.save_index(index_file_name)
        with open(metadata_file_name, "wb") as f:
          pickle.dump(self._metadata, f)
        logger.debug(f"Saving {self.size} elements from vectordb storage '{index_file_name}'.")
      except Exception as e:
        t = f"Error saving vectordb storage from {index_file_name}: {e}"
        logger.error(t)
        raise InvalidStorageError(t) from e

  async def _query_start(self):
    """Load index + metadata before a query session; a namespace is mandatory."""
    assert self.namespace, "Loading a vectordb requires a namespace."
    self._index = hnswlib.Index(space="cosine", dim=self.embedding_dim)  # type: ignore

    index_file_name = self.namespace.get_load_path(self.RESOURCE_NAME.format(self.embedding_dim))
    metadata_file_name = self.namespace.get_load_path(self.RESOURCE_METADATA_NAME)
    if index_file_name and metadata_file_name:
      try:
        self._index.load_index(index_file_name, allow_replace_deleted=True)
        with open(metadata_file_name, "rb") as f:
          self._metadata = pickle.load(f)
        logger.debug(f"Loaded {self.size} elements from vectordb storage '{index_file_name}'.")

        return  # All good
      except Exception as e:
        t = f"Error loading vectordb storage from {index_file_name}: {e}"
        logger.error(t)
        raise InvalidStorageError(t) from e
    else:
      logger.warning(f"No data file found for vectordb storage '{index_file_name}'. Loading empty vectordb.")
      self._index.init_index(
        max_elements=self.INITIAL_MAX_ELEMENTS,
        ef_construction=self.config.ef_construction,
        M=self.config.M,
        allow_replace_deleted=True
      )
      self._index.set_ef(self.config.ef_search)
      self._metadata = {}

  async def _query_done(self):
    # Queries never mutate the index, so there is nothing to persist.
    pass


# ---- fast_graphrag/_utils.py ----
import asyncio
import logging
import time
from functools import wraps
from typing import Any, Callable, List, Optional, Tuple, TypeVar, Union

import numpy as np
import numpy.typing as npt
from scipy.sparse import csr_matrix

from fast_graphrag._types import TIndex

logger = logging.getLogger("graphrag")
# Rough heuristic: ~4 characters per LLM token.
TOKEN_TO_CHAR_RATIO = 4


def timeit(func: Callable[..., Any]):
  """Decorator for async functions that records wall-clock durations.

  Durations are appended to the wrapper's `execution_times` list attribute.
  """
  @wraps(func)
  async def wrapper(*args: Any, **kwargs: Any) -> Any:
    start = time.time()
    result = await func(*args, **kwargs)
    duration = time.time() - start
    wrapper.execution_times.append(duration)  # type: ignore
    return result

  wrapper.execution_times = []  # type: ignore
  return wrapper


def throttle_async_func_call(
  max_concurrent: int = 2048,
  stagger_time: Optional[float] = None,
  waiting_time: float = 0.001,  # NOTE(review): currently unused — confirm intent
):
  """Decorator factory limiting concurrent invocations of an async function.

  At most `max_concurrent` calls run at once; each call optionally sleeps
  `stagger_time` seconds before executing. Errors are logged and re-raised.
  """
  _wrappedFn = TypeVar("_wrappedFn", bound=Callable[..., Any])

  def decorator(func: _wrappedFn) -> _wrappedFn:
    semaphore = asyncio.Semaphore(max_concurrent)

    @wraps(func)
    async def wait_func(*args: Any, **kwargs: Any) -> Any:
      async with semaphore:
        try:
          if stagger_time:
            await asyncio.sleep(stagger_time)
          return await func(*args, **kwargs)
        except Exception as e:
          logger.error(f"Error in throttled function {func.__name__}: {e}")
          raise e

    return wait_func  # type: ignore

  return decorator


def get_event_loop() -> asyncio.AbstractEventLoop:
  """Return the current thread's event loop, creating one if needed.

  NOTE(review): asyncio.get_event_loop() is deprecated for this get-or-create
  pattern on Python >= 3.10 — behaviour kept as-is, confirm before upgrading.
  """
  try:
    # If there is already an event loop, use it.
    loop = asyncio.get_event_loop()
  except RuntimeError:
    # If in a sub-thread, create a new event loop.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
  return loop


def extract_sorted_scores(
  row_vector: csr_matrix,
) -> Tuple[npt.NDArray[np.int64], npt.NDArray[np.float32]]:
  """Take a sparse row vector and return a list of non-zero (index, score) pairs sorted by score."""
  assert row_vector.shape[0] <= 1, "The input matrix must be a row vector."
  if row_vector.shape[0] == 0:
    return np.array([], dtype=np.int64), np.array([], dtype=np.float32)

  # Step 1: Get the indices of non-zero elements
  non_zero_indices = row_vector.nonzero()[1]

  # Step 2: Extract the probabilities of these indices
  probabilities = row_vector.data

  # Step 3: Use NumPy to create arrays for indices and probabilities
  indices_array = np.array(non_zero_indices)
  probabilities_array = np.array(probabilities)

  # Step 4: Sort the probabilities and get the sorted indices
  sorted_indices = np.argsort(probabilities_array)[::-1]

  # Step 5: Create sorted arrays for indices and probabilities
  sorted_indices_array = indices_array[sorted_indices]
  sorted_probabilities_array = probabilities_array[sorted_indices]

  return sorted_indices_array, sorted_probabilities_array


def csr_from_indices_list(
  data: List[List[Union[int, TIndex]]], shape: Tuple[int, int]
) -> csr_matrix:
  """Create a CSR matrix from a list of lists."""
  num_rows = len(data)

  # Flatten the list of lists and create corresponding row indices
  row_indices = np.repeat(np.arange(num_rows), [len(row) for row in data])
  col_indices = np.concatenate(data) if num_rows > 0 else np.array([], dtype=np.int64)

  # Data values (all ones in this case)
  values = np.broadcast_to(1, len(row_indices))

  # Create the CSR matrix
  return csr_matrix((values, (row_indices, col_indices)), shape=shape)


# ---- pyproject.toml ----
[tool.poetry]
name = "fast-graphrag"
version = "0.0.5"
description = ""
authors = ["Luca Pinchetti <luca@circlemind.co>", "Antonio Vespoli <antonio@circlemind.co>", "Yuhang Song <yuhang@circlemind.co>"]
packages = [{include = "fast_graphrag" }]
readme = "README.md"

[tool.poetry.dependencies]
python = ">=3.10,<3.13"
igraph = "^0.11.6"
xxhash = "^3.5.0"
pydantic = "^2.9.2"
scipy = "^1.14.1"
scikit-learn = "^1.5.2"
tenacity = "^9.0.0"
openai = "^1.52.1"
scipy-stubs = "^1.14.1.5"
hnswlib = "^0.8.0"
instructor = "^1.6.3"
requests = "^2.32.3"
python-dotenv = "^1.0.1"
tiktoken = "^0.8.0"
aiolimiter = "^1.1.0"
google-genai = "^1.3.0"
vertexai = "^1.71.1"
sentencepiece = "^0.2.0"
json-repair = "^0.39.1"
voyageai = "^0.3.2"


[tool.poetry.group.dev.dependencies]
ruff = "^0.7.0"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

[tool.ruff]
line-length = 120
indent-width = 2

[tool.ruff.lint]
select = [
  "E",  # pycodestyle errors
  "W",  # pycodestyle warnings
  "F",  # pyflakes
  "I",  # isort
  "B",  # flake8-bugbear
  "C4",  # flake8-comprehensions
  "N",  # PEP8 naming conventions
  "D"  # pydocstyle
]
ignore = [
  "C901",  # too complex
  "W191",  # indentation contains tabs
  "D401"  # imperative mood
]

[tool.ruff.lint.pydocstyle]
convention = "google"

[tool.ruff.lint.per-file-ignores]
"_prompt.py" = ["E501"]

# ---- tests/__init__.py ----
"""Testing."""

# ---- tests/_graphrag_test.py ----
# type: ignore
import unittest
from dataclasses import dataclass
from unittest.mock import AsyncMock, MagicMock, patch

from fast_graphrag._graphrag import BaseGraphRAG
from fast_graphrag._models import TAnswer
from fast_graphrag._types import TContext, TQueryResponse


class TestBaseGraphRAG(unittest.IsolatedAsyncioTestCase):
  """Unit tests for BaseGraphRAG with all collaborating services mocked out."""

  def setUp(self):
    self.llm_service = AsyncMock()
    self.chunking_service = AsyncMock()
    self.information_extraction_service = MagicMock()
    self.information_extraction_service.extract_entities_from_query = AsyncMock()
    self.state_manager = AsyncMock()
    self.state_manager.embedding_service.embedding_dim = self.state_manager.entity_storage.embedding_dim = 1

    # Subclass that skips __post_init__ so no real embedding validation runs.
    @dataclass
    class BaseGraphRAGNoEmbeddingValidation(BaseGraphRAG):
      def __post_init__(self):
        pass

    self.graph_rag = BaseGraphRAGNoEmbeddingValidation(
      working_dir="test_dir",
      domain="test_domain",
      example_queries="test_query",
      entity_types=["type1", "type2"],
    )
    self.graph_rag.llm_service = self.llm_service
    self.graph_rag.chunking_service = self.chunking_service
    self.graph_rag.information_extraction_service = self.information_extraction_service
    self.graph_rag.state_manager = self.state_manager

  async def test_async_insert(self):
    # Verify the insert pipeline: chunk -> filter new -> extract -> upsert.
    self.chunking_service.extract = AsyncMock(return_value=["chunked_data"])
    self.state_manager.filter_new_chunks = AsyncMock(return_value=["new_chunks"])
    self.information_extraction_service.extract = MagicMock(return_value=["subgraph"])
    self.state_manager.upsert = AsyncMock()

    await self.graph_rag.async_insert("test_content", {"meta": "data"})

    self.chunking_service.extract.assert_called_once()
    self.state_manager.filter_new_chunks.assert_called_once()
    self.information_extraction_service.extract.assert_called_once()
    self.state_manager.upsert.assert_called_once()

  @patch("fast_graphrag._graphrag.format_and_send_prompt", new_callable=AsyncMock)
  async def test_async_query(self, format_and_send_prompt):
    # Verify the query pipeline: entity extraction -> context -> prompt -> response.
    self.information_extraction_service.extract_entities_from_query = AsyncMock(return_value=["entities"])
    self.state_manager.get_context = AsyncMock(return_value=TContext([], [], []))
    format_and_send_prompt.return_value=(TAnswer(answer="response"), None)

    response = await self.graph_rag.async_query("test_query")

    self.information_extraction_service.extract_entities_from_query.assert_called_once()
    self.state_manager.get_context.assert_called_once()
    format_and_send_prompt.assert_called_once()
    self.assertIsInstance(response, TQueryResponse)


if __name__ == "__main__":
  unittest.main()

# ---- tests/_llm/__init__.py (empty) ----

# ---- tests/_llm/_base_test.py ----
# type: ignore
import unittest
from unittest.mock import AsyncMock, patch

from pydantic import BaseModel

from fast_graphrag._llm._base import BaseLLMService, format_and_send_prompt

# Assuming these are defined somewhere in your codebase
PROMPTS = {
  "example_prompt": "Hello, {name}!"
12 | } 13 | 14 | class TestModel(BaseModel): 15 | answer: str 16 | 17 | class TestFormatAndSendPrompt(unittest.IsolatedAsyncioTestCase): 18 | 19 | @patch("fast_graphrag._llm._base.PROMPTS", PROMPTS) 20 | async def test_format_and_send_prompt(self): 21 | mock_llm = AsyncMock(spec=BaseLLMService(model="")) 22 | answer = TestModel(answer="TEST") 23 | mock_response = (answer, [{"key": "value"}]) 24 | mock_llm.send_message = AsyncMock(return_value=mock_response) 25 | 26 | result = await format_and_send_prompt( 27 | prompt_key="example_prompt", 28 | llm=mock_llm, 29 | format_kwargs={"name": "World"}, 30 | response_model=TestModel 31 | ) 32 | 33 | mock_llm.send_message.assert_called_once_with( 34 | prompt="Hello, World!", 35 | response_model=TestModel 36 | ) 37 | self.assertEqual(result, mock_response) 38 | 39 | @patch("fast_graphrag._llm._base.PROMPTS", PROMPTS) 40 | async def test_format_and_send_prompt_with_additional_args(self): 41 | mock_llm = AsyncMock(spec=BaseLLMService(model="")) 42 | answer = TestModel(answer="TEST") 43 | mock_response = (answer, [{"key": "value"}]) 44 | mock_llm.send_message = AsyncMock(return_value=mock_response) 45 | 46 | result = await format_and_send_prompt( 47 | prompt_key="example_prompt", 48 | llm=mock_llm, 49 | format_kwargs={"name": "World"}, 50 | response_model=TestModel, 51 | model="test_model", 52 | max_tokens=100 53 | ) 54 | 55 | mock_llm.send_message.assert_called_once_with( 56 | prompt="Hello, World!", 57 | response_model=TestModel, 58 | model="test_model", 59 | max_tokens=100 60 | ) 61 | self.assertEqual(result, mock_response) 62 | 63 | if __name__ == "__main__": 64 | unittest.main() 65 | -------------------------------------------------------------------------------- /tests/_llm/_llm_openai_test.py: -------------------------------------------------------------------------------- 1 | # type: ignore 2 | import os 3 | import unittest 4 | from unittest.mock import AsyncMock, MagicMock 5 | 6 | import instructor 7 | from openai 
# type: ignore
import os
import unittest
from unittest.mock import AsyncMock, MagicMock

import instructor
from openai import APIConnectionError, AsyncOpenAI, RateLimitError
from tenacity import RetryError

from fast_graphrag._exceptions import LLMServiceNoResponseError
from fast_graphrag._llm._llm_openai import OpenAIEmbeddingService, OpenAILLMService

os.environ["OPENAI_API_KEY"] = ""


# Reusable 429 error instance shared by the retry tests below.
RateLimitError429 = RateLimitError(message="Rate limit exceeded", response=MagicMock(), body=None)


class TestOpenAILLMService(unittest.IsolatedAsyncioTestCase):
    """Tests for OpenAILLMService.send_message: success, failure and retry paths."""

    @staticmethod
    def _mocked_service(reply):
        """Service whose completion endpoint immediately returns `reply`."""
        service = OpenAILLMService(api_key="test")
        service.llm_async_client = AsyncMock()
        service.llm_async_client.chat.completions.create = AsyncMock(return_value=reply)
        return service

    @staticmethod
    def _retrying_service(side_effect):
        """Service backed by a real instructor wrapper whose create() follows `side_effect`."""
        service = OpenAILLMService()
        client = AsyncOpenAI(api_key="test")
        client.chat.completions.create = AsyncMock(side_effect=side_effect)
        service.llm_async_client = instructor.from_openai(client)
        return service

    async def test_send_message_success(self):
        reply = str("Hi!")
        service = self._mocked_service(reply)

        response, messages = await service.send_message(prompt="Hello")

        self.assertEqual(response, reply)
        self.assertEqual(messages[-1]["role"], "assistant")

    async def test_send_message_no_response(self):
        service = OpenAILLMService(api_key="test")
        service.llm_async_client = AsyncMock()
        service.llm_async_client.chat.completions.create.return_value = None

        with self.assertRaises(LLMServiceNoResponseError):
            await service.send_message(prompt="Hello")

    async def test_send_message_rate_limit_error(self):
        # First call raises 429, the retry succeeds.
        reply = str("Hi!")
        service = self._retrying_service((RateLimitError429, reply))

        response, messages = await service.send_message(prompt="Hello", response_model=None)

        self.assertEqual(response, reply)
        self.assertEqual(messages[-1]["role"], "assistant")

    async def test_send_message_api_connection_error(self):
        # First call fails to connect, the retry succeeds.
        reply = str("Hi!")
        service = self._retrying_service((APIConnectionError(request=MagicMock()), reply))

        response, messages = await service.send_message(prompt="Hello")

        self.assertEqual(response, reply)
        self.assertEqual(messages[-1]["role"], "assistant")

    async def test_send_message_with_system_prompt(self):
        reply = str("Hi!")
        service = self._mocked_service(reply)

        response, messages = await service.send_message(
            prompt="Hello", system_prompt="System prompt"
        )

        # The system prompt is prepended to the message list.
        self.assertEqual(response, reply)
        self.assertEqual(messages[0]["role"], "system")
        self.assertEqual(messages[0]["content"], "System prompt")

    async def test_send_message_with_history(self):
        reply = str("Hi!")
        service = self._mocked_service(reply)

        history = [{"role": "user", "content": "Previous message"}]
        response, messages = await service.send_message(prompt="Hello", history_messages=history)

        # History messages come before the new prompt.
        self.assertEqual(response, reply)
        self.assertEqual(messages[0]["role"], "user")
        self.assertEqual(messages[0]["content"], "Previous message")
class TestOpenAIEmbeddingService(unittest.IsolatedAsyncioTestCase):
    """Tests for OpenAIEmbeddingService.encode, including retry behaviour."""

    @staticmethod
    def _embedding_payload(vector):
        """Fake embeddings API response holding a single `vector`."""
        payload = AsyncMock()
        payload.data = [AsyncMock(embedding=vector)]
        return payload

    async def test_get_embedding_success(self):
        service = OpenAIEmbeddingService(api_key="test")
        service.embedding_async_client.embeddings.create = AsyncMock(
            return_value=self._embedding_payload([0.1, 0.2, 0.3])
        )

        embeddings = await service.encode(texts=["test"], model="text-embedding-3-small")

        self.assertEqual(embeddings.shape, (1, 3))
        self.assertEqual(embeddings[0][0], 0.1)

    async def test_get_embedding_rate_limit_error(self):
        # First call raises 429, the retry succeeds.
        service = OpenAIEmbeddingService(api_key="test")
        service.embedding_async_client.embeddings.create = AsyncMock(
            side_effect=(RateLimitError429, self._embedding_payload([0.1, 0.2, 0.3]))
        )

        embeddings = await service.encode(texts=["test"], model="text-embedding-3-small")

        self.assertEqual(embeddings.shape, (1, 3))
        self.assertEqual(embeddings[0][0], 0.1)

    async def test_get_embedding_api_connection_error(self):
        # First call fails to connect, the retry succeeds.
        service = OpenAIEmbeddingService(api_key="test")
        service.embedding_async_client.embeddings.create = AsyncMock(
            side_effect=(APIConnectionError(request=MagicMock()), self._embedding_payload([0.1, 0.2, 0.3]))
        )

        embeddings = await service.encode(texts=["test"], model="text-embedding-3-small")

        self.assertEqual(embeddings.shape, (1, 3))
        self.assertEqual(embeddings[0][0], 0.1)

    async def test_get_embedding_retry_failure(self):
        # Persistent 429s must exhaust the retries and surface a RetryError.
        service = OpenAIEmbeddingService(api_key="test")
        service.embedding_async_client.embeddings.create = AsyncMock(
            side_effect=RateLimitError429
        )

        with self.assertRaises(RetryError):
            await service.encode(texts=["test"], model="text-embedding-3-small")

    async def test_get_embedding_with_different_model(self):
        service = OpenAIEmbeddingService(api_key="test")
        service.embedding_async_client.embeddings.create = AsyncMock(
            return_value=self._embedding_payload([0.4, 0.5, 0.6])
        )

        embeddings = await service.encode(texts=["test"], model="text-embedding-3-large")

        self.assertEqual(embeddings.shape, (1, 3))
        self.assertEqual(embeddings[0][0], 0.4)


if __name__ == "__main__":
    unittest.main()
# type: ignore
import unittest

from pydantic import ValidationError

from fast_graphrag._models import (
    TEditRelation,
    TEditRelationList,
    TQueryEntities,
    dump_to_csv,
    dump_to_reference_list,
)
from fast_graphrag._types import TEntity


class _Row:
    """Minimal two-field record used as dump_to_csv input.

    Defined once at module level instead of being re-declared identically
    inside every TestDumpToCsv method, as the original did five times.
    """

    def __init__(self, field1, field2):
        self.field1 = field1
        self.field2 = field2


class TestModels(unittest.TestCase):
    """Tests for the pydantic models and CSV helpers in fast_graphrag._models."""

    def test_tqueryentities(self):
        # Named entities are upper-cased on validation; generic ones are kept as-is.
        query_entities = TQueryEntities(named=["Entity1", "Entity2"], generic=["Generic1", "Generic2"])
        self.assertEqual(query_entities.named, ["ENTITY1", "ENTITY2"])
        self.assertEqual(query_entities.generic, ["Generic1", "Generic2"])

        # Wrong field names must fail validation.
        with self.assertRaises(ValidationError):
            TQueryEntities(entities=["Entity1", "Entity2"], n="two")

    def test_teditrelationship(self):
        edit_relationship = TEditRelation(ids=[1, 2], description="Combined relationship description")
        self.assertEqual(edit_relationship.ids, [1, 2])
        self.assertEqual(edit_relationship.description, "Combined relationship description")

    def test_teditrelationshiplist(self):
        # `grouped_facts` is exposed through the `groups` attribute.
        edit_relationship = TEditRelation(ids=[1, 2], description="Combined relationship description")
        edit_relationship_list = TEditRelationList(grouped_facts=[edit_relationship])
        self.assertEqual(edit_relationship_list.groups, [edit_relationship])

    def test_dump_to_csv(self):
        data = [TEntity(name="Sample name", type="SAMPLE TYPE", description="Sample description")]
        fields = ["name", "type"]
        values = {"score": [0.9]}
        csv_output = dump_to_csv(data, fields, with_header=True, **values)
        expected_output = ["name\ttype\tscore", "Sample name\tSAMPLE TYPE\t0.9"]
        self.assertEqual(csv_output, expected_output)


class TestDumpToReferenceList(unittest.TestCase):
    """Tests for dump_to_reference_list numbering and separators."""

    def test_empty_list(self):
        self.assertEqual(dump_to_reference_list([]), [])

    def test_single_element(self):
        self.assertEqual(dump_to_reference_list(["item"]), ["[1] item\n=====\n\n"])

    def test_multiple_elements(self):
        data = ["item1", "item2", "item3"]
        expected = [
            "[1] item1\n=====\n\n",
            "[2] item2\n=====\n\n",
            "[3] item3\n=====\n\n",
        ]
        self.assertEqual(dump_to_reference_list(data), expected)

    def test_custom_separator(self):
        data = ["item1", "item2"]
        separator = " | "
        expected = [
            "[1] item1 | ",
            "[2] item2 | ",
        ]
        self.assertEqual(dump_to_reference_list(data, separator), expected)


class TestDumpToCsv(unittest.TestCase):
    """Tests for dump_to_csv with plain objects (see _Row above)."""

    def test_empty_data(self):
        self.assertEqual(dump_to_csv([], ["field1", "field2"]), [])

    def test_single_element(self):
        data = [_Row("value1", "value2")]
        self.assertEqual(dump_to_csv(data, ["field1", "field2"]), ["value1\tvalue2"])

    def test_multiple_elements(self):
        data = [_Row("value1", "value2"), _Row("value3", "value4")]
        expected = ["value1\tvalue2", "value3\tvalue4"]
        self.assertEqual(dump_to_csv(data, ["field1", "field2"]), expected)

    def test_with_header(self):
        data = [_Row("value1", "value2")]
        expected = ["field1\tfield2", "value1\tvalue2"]
        self.assertEqual(dump_to_csv(data, ["field1", "field2"], with_header=True), expected)

    def test_custom_separator(self):
        data = [_Row("value1", "value2")]
        self.assertEqual(dump_to_csv(data, ["field1", "field2"], separator=" | "), ["value1 | value2"])

    def test_additional_values(self):
        # Extra keyword columns are appended after the object fields.
        data = [_Row("value1", "value2")]
        expected = ["value1\tvalue2\tvalue3"]
        self.assertEqual(dump_to_csv(data, ["field1", "field2"], value3=["value3"]), expected)


if __name__ == "__main__":
    unittest.main()
import unittest

import numpy as np
from scipy.sparse import csr_matrix

from fast_graphrag._policies._ranking import (
    RankingPolicy_Elbow,
    RankingPolicy_TopK,
    # RankingPolicy_WithConfidence,
    RankingPolicy_WithThreshold,
)


class TestRankingPolicyWithThreshold(unittest.TestCase):
    """Scores below the configured cutoff are zeroed out."""

    @staticmethod
    def _rank(threshold, raw):
        ranker = RankingPolicy_WithThreshold(RankingPolicy_WithThreshold.Config(threshold))
        return ranker(csr_matrix(raw))

    def test_threshold(self):
        ranked = self._rank(0.1, [0.05, 0.2, 0.15, 0.05])
        np.testing.assert_array_equal(ranked.toarray(), csr_matrix([0, 0.2, 0.15, 0]).toarray())

    def test_all_below_threshold(self):
        ranked = self._rank(0.1, [0.05, 0.05, 0.05, 0.05])
        np.testing.assert_array_equal(ranked.toarray(), csr_matrix([], shape=(1, 4)).toarray())

    def test_explicit_batch_size_1(self):
        ranked = self._rank(0.1, [[0.05, 0.2, 0.15, 0.05]])
        np.testing.assert_array_equal(ranked.toarray(), csr_matrix([[0, 0.2, 0.15, 0]]).toarray())

    def test_all_above_threshold(self):
        ranked = self._rank(0.1, [0.15, 0.2, 0.25, 0.35])
        np.testing.assert_array_equal(ranked.toarray(), csr_matrix([0.15, 0.2, 0.25, 0.35]).toarray())


class TestRankingPolicyTopK(unittest.TestCase):
    """Only the k largest scores survive; the rest become zero."""

    @staticmethod
    def _rank(k, raw):
        ranker = RankingPolicy_TopK(RankingPolicy_TopK.Config(k))
        return ranker(csr_matrix(raw))

    def test_top_k(self):
        raw = [0.05, 0.05, 0.2, 0.15, 0.25]
        np.testing.assert_array_equal(
            self._rank(2, raw).toarray(), csr_matrix([0, 0, 0.2, 0.0, 0.25]).toarray()
        )
        np.testing.assert_array_equal(
            self._rank(1, raw).toarray(), csr_matrix([0, 0, 0.0, 0.0, 0.25]).toarray()
        )

    def test_top_k_less_than_k(self):
        # Fewer scores than k: everything is kept.
        ranked = self._rank(5, [0.05, 0.2, 0.0, 0.15])
        np.testing.assert_array_equal(ranked.toarray(), csr_matrix([0.05, 0.2, 0.0, 0.15]).toarray())

    def test_top_k_is_zero(self):
        # k == 0 leaves the scores untouched.
        ranked = self._rank(0, [0.05, 0.2, 0.15, 0.25])
        np.testing.assert_array_equal(ranked.toarray(), csr_matrix([0.05, 0.2, 0.15, 0.25]).toarray())

    def test_top_k_all_zero(self):
        ranked = self._rank(2, [0, 0, 0, 0, 0])
        np.testing.assert_array_equal(ranked.toarray(), csr_matrix([0, 0, 0, 0, 0]).toarray())


class TestRankingPolicyElbow(unittest.TestCase):
    """Tests for RankingPolicy_Elbow score selection."""

    @staticmethod
    def _rank(raw):
        return RankingPolicy_Elbow(config=None)(csr_matrix(raw))

    def test_elbow(self):
        ranked = self._rank([0.05, 0.2, 0.1, 0.25, 0.1])
        np.testing.assert_array_equal(ranked.toarray(), csr_matrix([0, 0.2, 0.0, 0.25, 0]).toarray())

    def test_elbow_all_zero(self):
        ranked = self._rank([0, 0, 0, 0, 0])
        np.testing.assert_array_equal(ranked.toarray(), csr_matrix([0, 0, 0, 0, 0]).toarray())

    def test_elbow_all_same(self):
        # NOTE(review): with uniform scores the first element is dropped — confirm intended.
        ranked = self._rank([0.05, 0.05, 0.05, 0.05, 0.05])
        np.testing.assert_array_equal(ranked.toarray(), csr_matrix([0, 0.05, 0.05, 0.05, 0.05]).toarray())


# class TestRankingPolicyWithConfidence(unittest.TestCase):
#     def test_not_implemented(self):
#         policy = RankingPolicy_WithConfidence()
#         scores = csr_matrix([0.05, 0.2, 0.15, 0.25, 0.1])
#         with self.assertRaises(NotImplementedError):
#             policy(scores)


if __name__ == '__main__':
    unittest.main()
# type: ignore
import unittest
from dataclasses import dataclass
from typing import Any, Dict
from unittest.mock import patch

import xxhash

from fast_graphrag._services._chunk_extraction import DefaultChunkingService
from fast_graphrag._types import THash


@dataclass
class MockDocument:
    data: str
    metadata: Dict[str, Any]


@dataclass
class MockChunk:
    id: THash
    content: str
    metadata: Dict[str, Any]


def _chunk_for(document):
    """Build the single MockChunk the patched _extract_chunks returns for `document`."""
    return MockChunk(
        id=THash(xxhash.xxh3_64_intdigest(document.data) // 2),
        content=document.data,
        metadata=document.metadata,
    )


class TestDefaultChunkingService(unittest.IsolatedAsyncioTestCase):
    """Tests for DefaultChunkingService.extract and its chunking helper."""

    def setUp(self):
        self.chunking_service = DefaultChunkingService()

    async def test_extract(self):
        first = MockDocument(data="test data 1", metadata={"meta": "data1"})
        second = MockDocument(data="test data 2", metadata={"meta": "data2"})

        with patch.object(
            self.chunking_service,
            "_extract_chunks",
            return_value=[_chunk_for(first)],
        ) as mock_extract_chunks:
            chunks = await self.chunking_service.extract([first, second])

        # One chunk list per input document.
        self.assertEqual(len(chunks), 2)
        self.assertEqual(len(chunks[0]), 1)
        self.assertEqual(chunks[0][0].content, "test data 1")
        self.assertEqual(chunks[0][0].metadata, {"meta": "data1"})
        mock_extract_chunks.assert_called()

    async def test_extract_with_duplicates(self):
        first = MockDocument(data="test data 1", metadata={"meta": "data1"})
        duplicate = MockDocument(data="test data 1", metadata={"meta": "data1"})

        with patch.object(
            self.chunking_service,
            "_extract_chunks",
            return_value=[_chunk_for(first)],
        ) as mock_extract_chunks:
            chunks = await self.chunking_service.extract([first, duplicate])

        # Identical documents still produce one chunk list each.
        self.assertEqual(len(chunks), 2)
        self.assertEqual(len(chunks[0]), 1)
        self.assertEqual(len(chunks[1]), 1)
        self.assertEqual(chunks[0][0].content, "test data 1")
        self.assertEqual(chunks[0][0].metadata, {"meta": "data1"})
        self.assertEqual(chunks[1][0].content, "test data 1")
        self.assertEqual(chunks[1][0].metadata, {"meta": "data1"})
        mock_extract_chunks.assert_called()

    async def test_extract_chunks(self):
        doc = MockDocument(data="test data", metadata={"meta": "data"})
        expected = _chunk_for(doc)

        chunks = await self.chunking_service._extract_chunks(doc)

        self.assertEqual(len(chunks), 1)
        self.assertEqual(chunks[0].id, expected.id)
        self.assertEqual(chunks[0].content, expected.content)
        self.assertEqual(chunks[0].metadata, expected.metadata)


if __name__ == "__main__":
    unittest.main()
# type: ignore
import asyncio
import unittest
from unittest.mock import AsyncMock, MagicMock, patch

from fast_graphrag._llm._base import BaseLLMService
from fast_graphrag._models import TQueryEntities
from fast_graphrag._policies._graph_upsert import BaseGraphUpsertPolicy
from fast_graphrag._services import DefaultInformationExtractionService
from fast_graphrag._storage._base import BaseGraphStorage
from fast_graphrag._types import TGraph


class TestDefaultInformationExtractionService(unittest.IsolatedAsyncioTestCase):
    """Tests for DefaultInformationExtractionService with a mocked LLM service."""

    def setUp(self):
        self.llm_service = MagicMock(spec=BaseLLMService)
        self.llm_service.send_message = AsyncMock()
        self.chunk = MagicMock()
        self.chunk.content = "test content"
        self.chunk.id = "chunk_id"
        self.document = [self.chunk]
        self.entity_types = ["entity_type"]
        self.prompt_kwargs = {"domain": "test_domain"}
        self.service = DefaultInformationExtractionService(graph_upsert=None)
        self.service.graph_upsert = AsyncMock(spec=BaseGraphUpsertPolicy)

    @patch("fast_graphrag._services._information_extraction.format_and_send_prompt", new_callable=AsyncMock)
    async def test_extract_entities_from_query(self, mock_format_and_send_prompt):
        mock_format_and_send_prompt.return_value = (
            TQueryEntities(named=["entity1", "entity2"], generic=["generic1"]),
            None,
        )

        result = await self.service.extract_entities_from_query(self.llm_service, "test query", self.prompt_kwargs)

        # Named entities are returned upper-cased; generic ones pass through.
        named, generic = result["named"], result["generic"]
        self.assertEqual(len(named), 2)
        self.assertEqual(named[0], "ENTITY1")
        self.assertEqual(named[1], "ENTITY2")
        self.assertEqual(len(generic), 1)

    @patch("fast_graphrag._services._information_extraction.format_and_send_prompt", new_callable=AsyncMock)
    async def test_extract(self, mock_format_and_send_prompt):
        mock_format_and_send_prompt.return_value = (TGraph(entities=[], relationships=[]), [])

        # extract() returns one task per document chunk list.
        tasks = self.service.extract(self.llm_service, [self.document], self.prompt_kwargs, self.entity_types)
        results = await asyncio.gather(*tasks)

        self.assertEqual(len(results), 1)
        self.assertIsInstance(results[0], BaseGraphStorage)


if __name__ == "__main__":
    unittest.main()
# type: ignore
import unittest
from unittest.mock import AsyncMock, patch

from fast_graphrag._storage._base import BaseStorage


class TestBaseStorage(unittest.IsolatedAsyncioTestCase):
    """Tests for BaseStorage transitions between insert and query modes."""

    def setUp(self):
        self.storage = BaseStorage(config=None)

    def _force_state(self, mode, in_progress):
        """Put the storage into a given internal state before a transition."""
        self.storage._mode = mode
        self.storage._in_progress = in_progress

    @patch.object(BaseStorage, '_insert_start', new_callable=AsyncMock)
    @patch.object(BaseStorage, '_query_done', new_callable=AsyncMock)
    @patch("fast_graphrag._storage._base.logger")
    async def test_insert_start_from_query_mode(self, mock_logger, mock_query_done, mock_insert_start):
        self._force_state("query", True)

        await self.storage.insert_start()

        # Switching from query closes the query session first and logs an error.
        mock_query_done.assert_called_once()
        mock_insert_start.assert_called_once()
        mock_logger.error.assert_called_once()
        self.assertEqual(self.storage._mode, "insert")
        self.assertFalse(self.storage._in_progress)

    @patch.object(BaseStorage, '_insert_start', new_callable=AsyncMock)
    async def test_insert_start_from_none_mode(self, mock_insert_start):
        self._force_state(None, False)

        await self.storage.insert_start()

        mock_insert_start.assert_called_once()
        self.assertEqual(self.storage._mode, "insert")
        self.assertFalse(self.storage._in_progress)

    @patch.object(BaseStorage, '_insert_done', new_callable=AsyncMock)
    async def test_insert_done_in_insert_mode(self, mock_insert_done):
        self._force_state("insert", True)

        await self.storage.insert_done()

        mock_insert_done.assert_called_once()

    @patch("fast_graphrag._storage._base.logger")
    async def test_insert_done_in_query_mode(self, mock_logger):
        # Finishing an insert while in query mode is an error.
        self._force_state("query", True)

        await self.storage.insert_done()

        mock_logger.error.assert_called_once()

    @patch.object(BaseStorage, '_query_start', new_callable=AsyncMock)
    @patch.object(BaseStorage, '_insert_done', new_callable=AsyncMock)
    @patch("fast_graphrag._storage._base.logger")
    async def test_query_start_from_insert_mode(self, mock_logger, mock_insert_done, mock_query_start):
        self._force_state("insert", True)

        await self.storage.query_start()

        # Switching from insert closes the insert session first and logs an error.
        mock_insert_done.assert_called_once()
        mock_query_start.assert_called_once()
        mock_logger.error.assert_called_once()
        self.assertEqual(self.storage._mode, "query")
        self.assertFalse(self.storage._in_progress)

    @patch.object(BaseStorage, '_query_start', new_callable=AsyncMock)
    async def test_query_start_from_none_mode(self, mock_query_start):
        self._force_state(None, False)

        await self.storage.query_start()

        mock_query_start.assert_called_once()
        self.assertEqual(self.storage._mode, "query")

    @patch.object(BaseStorage, '_query_done', new_callable=AsyncMock)
    async def test_query_done_in_query_mode(self, mock_query_done):
        self._force_state("query", True)

        await self.storage.query_done()

        mock_query_done.assert_called_once()

    @patch("fast_graphrag._storage._base.logger")
    async def test_query_done_in_insert_mode(self, mock_logger):
        # Finishing a query while in insert mode is an error.
        self._force_state("insert", True)

        await self.storage.query_done()

        mock_logger.error.assert_called_once()


if __name__ == "__main__":
    unittest.main()
# type: ignore
import pickle
import unittest
from unittest.mock import MagicMock, mock_open, patch

from fast_graphrag._exceptions import InvalidStorageError
from fast_graphrag._storage._blob_pickle import PickleBlobStorage


class TestPickleBlobStorage(unittest.IsolatedAsyncioTestCase):
    """Tests for PickleBlobStorage load/save behaviour.

    Fix: the patched ``builtins.open`` mock parameters were named ``mock_open``,
    shadowing the ``unittest.mock.mock_open`` helper imported above. They are
    renamed to ``mocked_open`` so the helper name stays unambiguous.
    """

    async def asyncSetUp(self):
        self.namespace = MagicMock()
        self.namespace.get_load_path.return_value = "blob_data.pkl"
        self.namespace.get_save_path.return_value = "blob_data.pkl"
        self.storage = PickleBlobStorage(namespace=self.namespace, config=None)

    async def test_get(self):
        self.storage._data = {"key": "value"}
        result = await self.storage.get()
        self.assertEqual(result, {"key": "value"})

    async def test_set(self):
        blob = {"key": "value"}
        await self.storage.set(blob)
        self.assertEqual(self.storage._data, blob)

    @patch("builtins.open", new_callable=mock_open, read_data=pickle.dumps({"key": "value"}))
    async def test_insert_start_with_existing_file(self, mocked_open):
        await self.storage._insert_start()
        self.assertEqual(self.storage._data, {"key": "value"})
        mocked_open.assert_called_once_with("blob_data.pkl", "rb")

    @patch("os.path.exists", return_value=False)
    async def test_insert_start_without_existing_file(self, mock_exists):
        # No load path: storage starts empty instead of reading a file.
        self.namespace.get_load_path.return_value = None
        await self.storage._insert_start()
        self.assertIsNone(self.storage._data)

    @patch("builtins.open", new_callable=mock_open)
    async def test_insert_done(self, mocked_open):
        self.storage._data = {"key": "value"}
        await self.storage._insert_done()
        mocked_open.assert_called_once_with("blob_data.pkl", "wb")
        mocked_open().write.assert_called_once()

    @patch("builtins.open", new_callable=mock_open, read_data=pickle.dumps({"key": "value"}))
    async def test_query_start_with_existing_file(self, mocked_open):
        await self.storage._query_start()
        self.assertEqual(self.storage._data, {"key": "value"})
        mocked_open.assert_called_once_with("blob_data.pkl", "rb")

    @patch("fast_graphrag._storage._blob_pickle.logger")
    async def test_query_start_without_existing_file(self, mock_logger):
        # Querying with no backing file only warns; it must not raise.
        self.namespace.get_load_path.return_value = None
        await self.storage._query_start()
        self.assertIsNone(self.storage._data)
        mock_logger.warning.assert_called_once()

    @patch("fast_graphrag._storage._blob_pickle.logger")
    async def test_insert_start_with_invalid_file(self, mock_logger):
        # The load path exists but cannot be read: expect a storage error.
        with self.assertRaises(InvalidStorageError):
            await self.storage._insert_start()
        mock_logger.error.assert_called_once()

    @patch("fast_graphrag._storage._blob_pickle.logger")
    async def test_query_start_with_invalid_file(self, mock_logger):
        with self.assertRaises(InvalidStorageError):
            await self.storage._query_start()
        mock_logger.error.assert_called_once()

    async def test_query_done(self):
        await self.storage._query_done()  # Should not raise any exceptions.


if __name__ == "__main__":
    unittest.main()
fast_graphrag._types import TEntity, TRelation 10 | 11 | 12 | class TestIGraphStorage(unittest.IsolatedAsyncioTestCase): 13 | def setUp(self): 14 | self.config = IGraphStorageConfig(node_cls=TEntity, edge_cls=TRelation) 15 | self.storage = IGraphStorage(config=self.config) 16 | self.storage._graph = MagicMock() 17 | 18 | async def test_node_count(self): 19 | self.storage._graph.vcount.return_value = 10 20 | count = await self.storage.node_count() 21 | self.assertEqual(count, 10) 22 | 23 | async def test_edge_count(self): 24 | self.storage._graph.ecount.return_value = 20 25 | count = await self.storage.edge_count() 26 | self.assertEqual(count, 20) 27 | 28 | async def test_get_node(self): 29 | node = MagicMock() 30 | node.name = "node1" 31 | node.attributes.return_value = {"name": "foo", "description": "value", "type": ""} 32 | self.storage._graph.vs.find.return_value = node 33 | 34 | result = await self.storage.get_node("node1") 35 | self.assertEqual(result, (TEntity(**node.attributes()), node.index)) 36 | 37 | async def test_get_node_not_found(self): 38 | self.storage._graph.vs.find.side_effect = ValueError 39 | result = await self.storage.get_node("node1") 40 | self.assertEqual(result, (None, None)) 41 | 42 | async def test_get_edges(self): 43 | self.storage._get_edge_indices = AsyncMock(return_value=[0, 1]) 44 | self.storage.get_edge_by_index = AsyncMock( 45 | side_effect=[TRelation(source="node1", target="node2", description="txt"), None] 46 | ) 47 | 48 | edges = await self.storage.get_edges("node1", "node2") 49 | self.assertEqual(edges, [(TRelation(source="node1", target="node2", description="txt"), 0)]) 50 | 51 | async def test_get_edge_indices(self): 52 | self.storage._graph.vs.find.side_effect = lambda name: MagicMock(index=name) 53 | self.storage._graph.es.select.return_value = [MagicMock(index=0), MagicMock(index=1)] 54 | 55 | indices = await self.storage._get_edge_indices("node1", "node2") 56 | self.assertEqual(list(indices), [0, 1]) 57 | 58 | async def 
test_get_node_by_index(self): 59 | node = MagicMock() 60 | node.attributes.return_value = {"name": "foo", "description": "value", "type": "type"} 61 | self.storage._graph.vs.__getitem__.return_value = node 62 | self.storage._graph.vcount.return_value = 1 63 | 64 | result = await self.storage.get_node_by_index(0) 65 | self.assertEqual(result, TEntity(**node.attributes())) 66 | 67 | async def test_get_edge_by_index(self): 68 | edge = MagicMock() 69 | edge.source = "node0" 70 | edge.target = "node1" 71 | edge.attributes.return_value = {"description": "value"} 72 | self.storage._graph.es.__getitem__.return_value = edge 73 | self.storage._graph.vs.__getitem__.side_effect = lambda idx: {"name": idx} 74 | self.storage._graph.ecount.return_value = 1 75 | 76 | result = await self.storage.get_edge_by_index(0) 77 | self.assertEqual(result, TRelation(source="node0", target="node1", **edge.attributes())) 78 | 79 | async def test_upsert_node(self): 80 | node = TEntity(name="node1", description="value", type="type") 81 | self.storage._graph.vcount.return_value = 1 82 | self.storage._graph.vs.__getitem__.return_value = MagicMock(index=0) 83 | 84 | index = await self.storage.upsert_node(node, 0) 85 | self.assertEqual(index, 0) 86 | 87 | async def test_upsert_edge(self): 88 | edge = TRelation(source="node1", target="node2", description="desc", chunks=[]) 89 | self.storage._graph.ecount.return_value = 1 90 | self.storage._graph.es.__getitem__.return_value = MagicMock(index=0) 91 | 92 | index = await self.storage.upsert_edge(edge, 0) 93 | self.assertEqual(index, 0) 94 | 95 | async def test_delete_edges_by_index(self): 96 | self.storage._graph.delete_edges = MagicMock() 97 | indices = [0, 1] 98 | await self.storage.delete_edges_by_index(indices) 99 | self.storage._graph.delete_edges.assert_called_with(indices) 100 | 101 | @patch("fast_graphrag._storage._gdb_igraph.logger") 102 | async def test_score_nodes_empty_graph(self, mock_logger): 103 | self.storage._graph.vcount.return_value = 0 
104 | scores = await self.storage.score_nodes(None) 105 | self.assertEqual(scores.shape, (1, 0)) 106 | mock_logger.info.assert_called_with("Trying to score nodes in an empty graph.") 107 | 108 | async def test_score_nodes(self): 109 | self.storage._graph.vcount.return_value = 3 110 | self.storage._graph.personalized_pagerank.return_value = [0.1, 0.2, 0.7] 111 | 112 | scores = await self.storage.score_nodes(None) 113 | self.assertTrue(np.array_equal(scores.toarray(), np.array([[0.1, 0.2, 0.7]], dtype=np.float32))) 114 | 115 | async def test_get_entities_to_relationships_map_empty_graph(self): 116 | self.storage._graph.vs = [] 117 | result = await self.storage.get_entities_to_relationships_map() 118 | self.assertEqual(result.shape, (0, 0)) 119 | 120 | @patch("fast_graphrag._storage._gdb_igraph.csr_from_indices_list") 121 | async def test_get_entities_to_relationships_map(self, mock_csr_from_indices_list): 122 | self.storage._graph.vs = [MagicMock(incident=lambda: [MagicMock(index=0), MagicMock(index=1)])] 123 | self.storage.node_count = AsyncMock(return_value=1) 124 | self.storage.edge_count = AsyncMock(return_value=2) 125 | 126 | await self.storage.get_entities_to_relationships_map() 127 | mock_csr_from_indices_list.assert_called_with([[0, 1]], shape=(1, 2)) 128 | 129 | async def test_get_relationships_attrs_empty_graph(self): 130 | self.storage._graph.es = [] 131 | result = await self.storage.get_relationships_attrs("key") 132 | self.assertEqual(result, []) 133 | 134 | async def test_get_relationships_attrs(self): 135 | self.storage._graph.es.__getitem__.return_value = [[1, 2], [3, 4]] 136 | self.storage._graph.es.__len__.return_value = 2 137 | result = await self.storage.get_relationships_attrs("key") 138 | self.assertEqual(result, [[1, 2], [3, 4]]) 139 | 140 | @patch("igraph.Graph.Read_Picklez") 141 | @patch("fast_graphrag._storage._gdb_igraph.logger") 142 | async def test_insert_start_with_existing_file(self, mock_logger, mock_read_picklez): 143 | 
self.storage.namespace = MagicMock() 144 | self.storage.namespace.get_load_path.return_value = "dummy_path" 145 | 146 | await self.storage._insert_start() 147 | 148 | mock_read_picklez.assert_called_with("dummy_path") 149 | mock_logger.debug.assert_called_with("Loaded graph storage 'dummy_path'.") 150 | 151 | @patch("igraph.Graph") 152 | @patch("fast_graphrag._storage._gdb_igraph.logger") 153 | async def test_insert_start_with_no_file(self, mock_logger, mock_graph): 154 | self.storage.namespace = MagicMock() 155 | self.storage.namespace.get_load_path.return_value = None 156 | 157 | await self.storage._insert_start() 158 | 159 | mock_graph.assert_called_with(directed=False) 160 | mock_logger.info.assert_called_with("No data file found for graph storage 'None'. Loading empty graph.") 161 | 162 | @patch("igraph.Graph") 163 | @patch("fast_graphrag._storage._gdb_igraph.logger") 164 | async def test_insert_start_with_no_namespace(self, mock_logger, mock_graph): 165 | self.storage.namespace = None 166 | 167 | await self.storage._insert_start() 168 | 169 | mock_graph.assert_called_with(directed=False) 170 | mock_logger.debug.assert_called_with("Creating new volatile graphdb storage.") 171 | 172 | @patch("igraph.Graph.write_picklez") 173 | @patch("fast_graphrag._storage._gdb_igraph.logger") 174 | async def test_insert_done(self, mock_logger, mock_write_picklez): 175 | self.storage.namespace = MagicMock() 176 | self.storage.namespace.get_save_path.return_value = "dummy_path" 177 | 178 | await self.storage._insert_done() 179 | 180 | mock_write_picklez.assert_called_with(self.storage._graph, "dummy_path") 181 | 182 | @patch("igraph.Graph.Read_Picklez") 183 | @patch("fast_graphrag._storage._gdb_igraph.logger") 184 | async def test_query_start_with_existing_file(self, mock_logger, mock_read_picklez): 185 | self.storage.namespace = MagicMock() 186 | self.storage.namespace.get_load_path.return_value = "dummy_path" 187 | 188 | await self.storage._query_start() 189 | 190 | 
mock_read_picklez.assert_called_with("dummy_path") 191 | mock_logger.debug.assert_called_with("Loaded graph storage 'dummy_path'.") 192 | 193 | @patch("igraph.Graph") 194 | @patch("fast_graphrag._storage._gdb_igraph.logger") 195 | async def test_query_start_with_no_file(self, mock_logger, mock_graph): 196 | self.storage.namespace = MagicMock() 197 | self.storage.namespace.get_load_path.return_value = None 198 | 199 | await self.storage._query_start() 200 | 201 | mock_graph.assert_called_with(directed=False) 202 | mock_logger.warning.assert_called_with( 203 | "No data file found for graph storage 'None'. Loading empty graph." 204 | ) 205 | 206 | 207 | if __name__ == "__main__": 208 | unittest.main() 209 | -------------------------------------------------------------------------------- /tests/_storage/_ikv_pickle_test.py: -------------------------------------------------------------------------------- 1 | # type: ignore 2 | import pickle 3 | import unittest 4 | from unittest.mock import MagicMock, mock_open, patch 5 | 6 | import numpy as np 7 | 8 | from fast_graphrag._exceptions import InvalidStorageError 9 | from fast_graphrag._storage._ikv_pickle import PickleIndexedKeyValueStorage 10 | 11 | 12 | class TestPickleIndexedKeyValueStorage(unittest.IsolatedAsyncioTestCase): 13 | async def asyncSetUp(self): 14 | self.storage = PickleIndexedKeyValueStorage(namespace=None, config=None) 15 | await self.storage._insert_start() 16 | 17 | async def test_size(self): 18 | self.storage._data = {1: "value1", 2: "value2"} 19 | size = await self.storage.size() 20 | self.assertEqual(size, 2) 21 | 22 | async def test_get(self): 23 | self.storage._data = {1: "value1", 2: "value2"} 24 | self.storage._key_to_index = {"key1": 1, "key2": 2} 25 | result = await self.storage.get(["key1", "key2", "key3"]) 26 | self.assertEqual(list(result), ["value1", "value2", None]) 27 | 28 | async def test_get_by_index(self): 29 | self.storage._data = {1: "value1", 2: "value2"} 30 | result = await 
self.storage.get_by_index([1, 2, 3]) 31 | self.assertEqual(list(result), ["value1", "value2", None]) 32 | 33 | async def test_get_index(self): 34 | self.storage._key_to_index = {"key1": 1, "key2": 2} 35 | result = await self.storage.get_index(["key1", "key2", "key3"]) 36 | self.assertEqual(list(result), [1, 2, None]) 37 | 38 | async def test_upsert(self): 39 | await self.storage.upsert(["key1", "key2"], ["value1", "value2"]) 40 | self.assertEqual(self.storage._data, {0: "value1", 1: "value2"}) 41 | self.assertEqual(self.storage._key_to_index, {"key1": 0, "key2": 1}) 42 | 43 | async def test_delete(self): 44 | self.storage._data = {0: "value1", 1: "value2"} 45 | self.storage._key_to_index = {"key1": 0, "key2": 1} 46 | await self.storage.delete(["key1"]) 47 | self.assertEqual(self.storage._data, {1: "value2"}) 48 | self.assertEqual(self.storage._key_to_index, {"key2": 1}) 49 | self.assertEqual(self.storage._free_indices, [0]) 50 | 51 | async def test_mask_new(self): 52 | self.storage._key_to_index = {"key1": 0, "key2": 1} 53 | result = await self.storage.mask_new([["key1", "key3"]]) 54 | self.assertTrue(np.array_equal(result, [[False, True]])) 55 | 56 | @patch("builtins.open", new_callable=mock_open, read_data=pickle.dumps(({0: "value"}, [1, 2, 3], {"key": 0}))) 57 | @patch("os.path.exists", return_value=True) 58 | @patch("fast_graphrag._storage._ikv_pickle.logger") 59 | async def test_insert_start_with_existing_file(self, mock_logger, mock_exists, mock_open): 60 | self.storage.namespace = MagicMock() 61 | self.storage.namespace.get_load_path.return_value = "dummy_path" 62 | 63 | # Call the function 64 | await self.storage._insert_start() 65 | 66 | # Check if data was loaded correctly 67 | self.assertEqual(self.storage._data, {0: "value"}) 68 | self.assertEqual(self.storage._free_indices, [1, 2, 3]) 69 | self.assertEqual(self.storage._key_to_index, {"key": 0}) 70 | mock_logger.debug.assert_called_once() 71 | 72 | @patch("fast_graphrag._storage._ikv_pickle.logger") 73 
| async def test_insert_start_with_no_file(self, mock_logger): 74 | self.storage.namespace = MagicMock() 75 | self.storage.namespace.get_load_path.return_value = "dummy_path" 76 | 77 | # Call the function 78 | with self.assertRaises(InvalidStorageError): 79 | await self.storage._insert_start() 80 | 81 | @patch("fast_graphrag._storage._ikv_pickle.logger") 82 | async def test_insert_start_with_no_namespace(self, mock_logger): 83 | self.storage.namespace = None 84 | 85 | # Call the function 86 | await self.storage._insert_start() 87 | 88 | # Check if data was initialized correctly 89 | self.assertEqual(self.storage._data, {}) 90 | self.assertEqual(self.storage._free_indices, []) 91 | mock_logger.debug.assert_called_with("Creating new volatile indexed key-value storage.") 92 | 93 | @patch("builtins.open", new_callable=mock_open) 94 | @patch("fast_graphrag._storage._ikv_pickle.logger") 95 | async def test_insert_done(self, mock_logger, mock_open): 96 | self.storage.namespace = MagicMock() 97 | self.storage.namespace.get_save_path.return_value = "dummy_path" 98 | self.storage._data = {0: "value"} 99 | self.storage._free_indices = [1, 2, 3] 100 | self.storage._key_to_index = {"key": 0} 101 | 102 | # Call the function 103 | await self.storage._insert_done() 104 | 105 | # Check if data was saved correctly 106 | mock_open.assert_called_with("dummy_path", "wb") 107 | mock_logger.debug.assert_called_with("Saving 1 elements to indexed key-value storage 'dummy_path'.") 108 | 109 | @patch("builtins.open", new_callable=mock_open, read_data=pickle.dumps(({0: "value"}, [1, 2, 3], {"key": 0}))) 110 | @patch("os.path.exists", return_value=True) 111 | @patch("fast_graphrag._storage._ikv_pickle.logger") 112 | async def test_query_start_with_existing_file(self, mock_logger, mock_exists, mock_open): 113 | self.storage.namespace = MagicMock() 114 | self.storage.namespace.get_load_path.return_value = "dummy_path" 115 | 116 | # Call the function 117 | await self.storage._query_start() 118 | 
119 | # Check if data was loaded correctly 120 | self.assertEqual(self.storage._data, {0: "value"}) 121 | self.assertEqual(self.storage._free_indices, [1, 2, 3]) 122 | self.assertEqual(self.storage._key_to_index, {"key": 0}) 123 | mock_logger.debug.assert_called_with("Loaded 1 elements from indexed key-value storage 'dummy_path'.") 124 | 125 | @patch("os.path.exists", return_value=False) 126 | @patch("fast_graphrag._storage._ikv_pickle.logger") 127 | async def test_query_start_with_no_file(self, mock_logger, mock_exists): 128 | self.storage.namespace = MagicMock() 129 | self.storage.namespace.get_load_path.return_value = None 130 | 131 | # Call the function 132 | await self.storage._query_start() 133 | 134 | # Check if data was initialized correctly 135 | self.assertEqual(self.storage._data, {}) 136 | self.assertEqual(self.storage._free_indices, []) 137 | mock_logger.warning.assert_called_with( 138 | "No data file found for key-vector storage 'None'. Loading empty storage." 139 | ) 140 | 141 | 142 | if __name__ == "__main__": 143 | unittest.main() 144 | -------------------------------------------------------------------------------- /tests/_storage/_namespace_test.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import os 3 | import shutil 4 | import unittest 5 | from typing import cast 6 | 7 | from fast_graphrag._exceptions import InvalidStorageError 8 | from fast_graphrag._storage._namespace import Namespace, Workspace 9 | 10 | 11 | class TestWorkspace(unittest.IsolatedAsyncioTestCase): 12 | def setUp(self): 13 | def _(self: Workspace) -> None: 14 | pass 15 | 16 | Workspace.__del__ = _ 17 | self.test_dir = "test_workspace" 18 | self.workspace = Workspace(self.test_dir) 19 | 20 | def tearDown(self): 21 | if os.path.exists(self.test_dir): 22 | shutil.rmtree(self.test_dir) 23 | 24 | def test_new_workspace(self): 25 | ws = Workspace.new(self.test_dir) 26 | self.assertIsInstance(ws, Workspace) 27 | 
self.assertEqual(ws.working_dir, self.test_dir) 28 | 29 | def test_get_load_path_no_checkpoint(self): 30 | self.assertEqual(self.workspace.get_load_path(), None) 31 | 32 | def test_get_save_path_creates_directory(self): 33 | save_path = self.workspace.get_save_path() 34 | self.assertTrue(os.path.exists(save_path)) 35 | 36 | async def test_with_checkpoint_failures(self): 37 | for checkpoint in [1, 2, 3]: 38 | os.makedirs(os.path.join(self.test_dir, str(checkpoint))) 39 | self.workspace = Workspace(self.test_dir) 40 | 41 | async def sample_fn(): 42 | if "1" not in cast(str, self.workspace.get_load_path()): 43 | raise Exception("Checkpoint not loaded") 44 | return "success" 45 | 46 | result = await self.workspace.with_checkpoints(sample_fn) 47 | self.assertEqual(result, "success") 48 | self.assertEqual(self.workspace.current_load_checkpoint, 1) 49 | self.assertEqual(self.workspace.failed_checkpoints, ["3", "2"]) 50 | 51 | async def test_with_checkpoint_no_failure(self): 52 | for checkpoint in [1, 2, 3]: 53 | os.makedirs(os.path.join(self.test_dir, str(checkpoint))) 54 | self.workspace = Workspace(self.test_dir) 55 | 56 | async def sample_fn(): 57 | return "success" 58 | 59 | result = await self.workspace.with_checkpoints(sample_fn) 60 | self.assertEqual(result, "success") 61 | self.assertEqual(self.workspace.current_load_checkpoint, 3) 62 | self.assertEqual(self.workspace.failed_checkpoints, []) 63 | 64 | async def test_with_checkpoint_all_failures(self): 65 | for checkpoint in [1, 2, 3]: 66 | os.makedirs(os.path.join(self.test_dir, str(checkpoint))) 67 | self.workspace = Workspace(self.test_dir) 68 | 69 | async def sample_fn(): 70 | raise Exception("Checkpoint not loaded") 71 | 72 | with self.assertRaises(InvalidStorageError): 73 | await self.workspace.with_checkpoints(sample_fn) 74 | self.assertEqual(self.workspace.current_load_checkpoint, None) 75 | self.assertEqual(self.workspace.failed_checkpoints, ["3", "2", "1"]) 76 | 77 | async def 
test_with_checkpoint_all_failures_accept_none(self): 78 | for checkpoint in [1, 2, 3]: 79 | os.makedirs(os.path.join(self.test_dir, str(checkpoint))) 80 | self.workspace = Workspace(self.test_dir) 81 | 82 | async def sample_fn(): 83 | if self.workspace.get_load_path() is not None: 84 | raise Exception("Checkpoint not loaded") 85 | 86 | result = await self.workspace.with_checkpoints(sample_fn) 87 | self.assertEqual(result, None) 88 | self.assertEqual(self.workspace.current_load_checkpoint, None) 89 | self.assertEqual(self.workspace.failed_checkpoints, ["3", "2", "1"]) 90 | 91 | 92 | class TestNamespace(unittest.TestCase): 93 | def setUp(self): 94 | def _(self: Workspace) -> None: 95 | pass 96 | 97 | Workspace.__del__ = _ 98 | 99 | def test_get_load_path_no_checkpoint_no_file(self): 100 | self.test_dir = "test_workspace" 101 | self.workspace = Workspace(self.test_dir) 102 | self.workspace.__del__ = lambda: None 103 | self.namespace = Namespace(self.workspace, "test_namespace") 104 | self.assertEqual(None, self.namespace.get_load_path("resource")) 105 | del self.workspace 106 | gc.collect() 107 | if os.path.exists(self.test_dir): 108 | shutil.rmtree(self.test_dir) 109 | 110 | def test_get_load_path_no_checkpoint_with_file(self): 111 | self.test_dir = "test_workspace" 112 | self.workspace = Workspace(self.test_dir) 113 | self.workspace.__del__ = lambda: None 114 | self.namespace = Namespace(self.workspace, "test_namespace") 115 | 116 | with open(os.path.join(self.test_dir, "test_namespace_resource"), "w") as f: 117 | f.write("test") 118 | self.assertEqual( 119 | os.path.join("test_workspace", "test_namespace_resource"), self.namespace.get_load_path("resource") 120 | ) 121 | del self.workspace 122 | gc.collect() 123 | if os.path.exists(self.test_dir): 124 | shutil.rmtree(self.test_dir) 125 | 126 | def test_get_load_path_with_checkpoint(self): 127 | self.test_dir = "test_workspace" 128 | self.workspace = Workspace(self.test_dir) 129 | self.namespace = 
Namespace(self.workspace, "test_namespace") 130 | self.workspace.current_load_checkpoint = 1 131 | load_path = self.namespace.get_load_path("resource") 132 | self.assertEqual(load_path, os.path.join(self.test_dir, "1", "test_namespace_resource")) 133 | 134 | del self.workspace 135 | gc.collect() 136 | if os.path.exists(self.test_dir): 137 | shutil.rmtree(self.test_dir) 138 | 139 | def test_get_save_path_creates_directory(self): 140 | self.test_dir = "test_workspace" 141 | self.workspace = Workspace(self.test_dir) 142 | self.namespace = Namespace(self.workspace, "test_namespace") 143 | save_path = self.namespace.get_save_path("resource") 144 | self.assertTrue(os.path.exists(os.path.join(*os.path.split(save_path)[:-1]))) 145 | 146 | del self.workspace 147 | gc.collect() 148 | if os.path.exists(self.test_dir): 149 | shutil.rmtree(self.test_dir) 150 | 151 | 152 | if __name__ == "__main__": 153 | unittest.main() 154 | -------------------------------------------------------------------------------- /tests/_types_test.py: -------------------------------------------------------------------------------- 1 | # type: ignore 2 | import re 3 | import unittest 4 | from dataclasses import asdict 5 | 6 | from fast_graphrag._types import ( 7 | TChunk, 8 | TContext, 9 | TDocument, 10 | TEntity, 11 | TGraph, 12 | TQueryResponse, 13 | TRelation, 14 | TScore, 15 | ) 16 | 17 | 18 | class TestTypes(unittest.TestCase): 19 | def test_tdocument(self): 20 | doc = TDocument(data="Sample data", metadata={"key": "value"}) 21 | self.assertEqual(doc.data, "Sample data") 22 | self.assertEqual(doc.metadata, {"key": "value"}) 23 | 24 | def test_tchunk(self): 25 | chunk = TChunk(id=123, content="Sample content", metadata={"key": "value"}) 26 | self.assertEqual(chunk.id, 123) 27 | self.assertEqual(chunk.content, "Sample content") 28 | self.assertEqual(chunk.metadata, {"key": "value"}) 29 | 30 | def test_tentity(self): 31 | entity = TEntity(name="Entity1", type="Type1", description="Description1") 32 | 
self.assertEqual(entity.name, "Entity1") 33 | self.assertEqual(entity.type, "Type1") 34 | self.assertEqual(entity.description, "Description1") 35 | 36 | pydantic_entity = TEntity.Model(name="Entity1", type="Type1", desc="Description1") 37 | entity.name = entity.name.upper() 38 | entity.type = entity.type.upper() 39 | self.assertEqual(asdict(entity), asdict(pydantic_entity.to_dataclass(pydantic_entity))) 40 | 41 | def test_trelation(self): 42 | relation = TRelation(source="Entity1", target="Entity2", description="Relation description") 43 | self.assertEqual(relation.source, "Entity1") 44 | self.assertEqual(relation.target, "Entity2") 45 | self.assertEqual(relation.description, "Relation description") 46 | 47 | pydantic_relation = TRelation.Model(source="Entity1", target="Entity2", desc="Relation description") 48 | 49 | relation.source = relation.source.upper() 50 | relation.target = relation.target.upper() 51 | self.assertEqual(asdict(relation), asdict(pydantic_relation.to_dataclass(pydantic_relation))) 52 | 53 | def test_tgraph(self): 54 | entity = TEntity(name="Entity1", type="Type1", description="Description1") 55 | relation = TRelation(source="Entity1", target="Entity2", description="Relation description") 56 | graph = TGraph(entities=[entity], relationships=[relation]) 57 | self.assertEqual(graph.entities, [entity]) 58 | self.assertEqual(graph.relationships, [relation]) 59 | 60 | pydantic_graph = TGraph.Model( 61 | entities=[TEntity.Model(name="Entity1", type="Type1", desc="Description1")], 62 | relationships=[TRelation.Model(source="Entity1", target="Entity2", desc="Relation description")], 63 | other_relationships=[], 64 | ) 65 | 66 | for entity in graph.entities: 67 | entity.name = entity.name.upper() 68 | entity.type = entity.type.upper() 69 | for relation in graph.relationships: 70 | relation.source = relation.source.upper() 71 | relation.target = relation.target.upper() 72 | self.assertEqual(asdict(graph), 
asdict(pydantic_graph.to_dataclass(pydantic_graph))) 73 | 74 | def test_tcontext(self): 75 | entities = [TEntity(name="Entity1", type="Type1", description="Sample description 1")] * 8 + [ 76 | TEntity(name="Entity2", type="Type2", description="Sample description 2") 77 | ] * 8 78 | relationships = [ 79 | TRelation(source="Entity1", target="Entity2", description="Relation description 12") 80 | ] * 8 + [ 81 | TRelation(source="Entity2", target="Entity1", description="Relation description 21") 82 | ] * 8 83 | chunks = [ 84 | TChunk(id=i, content=f"Long and repeated chunk content {i}" * 4, metadata={"key": f"value {i}"}) 85 | for i in range(16) 86 | ] 87 | 88 | for r, c in zip(relationships, chunks): 89 | r.chunks = [c.id] 90 | context = TContext( 91 | entities=[(e, TScore(0.9)) for e in entities], 92 | relations=[(r, TScore(0.8)) for r in relationships], 93 | chunks=[(c, TScore(0.7)) for c in chunks], 94 | ) 95 | max_chars = {"entities": 128, "relations": 128, "chunks": 512} 96 | csv = context.truncate(max_chars.copy(), True) 97 | 98 | csv_entities = re.findall(r"## Entities\n```csv\n(.*?)\n```", csv, re.DOTALL) 99 | csv_relationships = re.findall(r"## Relationships\n```csv\n(.*?)\n```", csv, re.DOTALL) 100 | csv_chunks = re.findall(r"## Sources\n.*=====", csv, re.DOTALL) 101 | 102 | self.assertEqual(len(csv_entities), 1) 103 | self.assertEqual(len(csv_relationships), 1) 104 | self.assertEqual(len(csv_chunks), 1) 105 | 106 | self.assertGreaterEqual( 107 | sum(max_chars.values()) + 16, len(csv_entities[0]) + len(csv_relationships[0]) + len(csv_chunks[0]) 108 | ) 109 | 110 | def test_tqueryresponse(self): 111 | context = TContext( 112 | entities=[("Entity1", TScore(0.9))], 113 | relations=[("Relation1", TScore(0.8))], 114 | chunks=[("Chunk1", TScore(0.7))], 115 | ) 116 | query_response = TQueryResponse(response="Sample response", context=context) 117 | self.assertEqual(query_response.response, "Sample response") 118 | self.assertEqual(query_response.context, context) 
119 | 120 | 121 | if __name__ == "__main__": 122 | unittest.main() 123 | -------------------------------------------------------------------------------- /tests/_utils_test.py: -------------------------------------------------------------------------------- 1 | # tests/test_utils.py 2 | 3 | import asyncio 4 | import threading 5 | import unittest 6 | from typing import List 7 | 8 | import numpy as np 9 | from scipy.sparse import csr_matrix 10 | 11 | from fast_graphrag._utils import csr_from_indices_list, extract_sorted_scores, get_event_loop 12 | 13 | 14 | class TestGetEventLoop(unittest.TestCase): 15 | def test_get_existing_event_loop(self): 16 | loop = asyncio.new_event_loop() 17 | asyncio.set_event_loop(loop) 18 | self.assertEqual(get_event_loop(), loop) 19 | loop.close() 20 | 21 | def test_get_event_loop_in_sub_thread(self): 22 | def target(): 23 | loop = get_event_loop() 24 | self.assertIsInstance(loop, asyncio.AbstractEventLoop) 25 | loop.close() 26 | 27 | thread = threading.Thread(target=target) 28 | thread.start() 29 | thread.join() 30 | 31 | 32 | # Not checked 33 | class TestExtractSortedScores(unittest.TestCase): 34 | def test_non_zero_elements(self): 35 | row_vector = csr_matrix([[0, 0.1, 0, 0.7, 0.5, 0]]) 36 | indices, scores = extract_sorted_scores(row_vector) 37 | np.testing.assert_array_equal(indices, np.array([3, 4, 1])) 38 | np.testing.assert_array_equal(scores, np.array([0.7, 0.5, 0.1])) 39 | 40 | def test_empty(self): 41 | row_vector = csr_matrix((0, 0)) 42 | indices, scores = extract_sorted_scores(row_vector) 43 | np.testing.assert_array_equal(indices, np.array([], dtype=np.int64)) 44 | np.testing.assert_array_equal(scores, np.array([], dtype=np.float32)) 45 | 46 | def test_empty_row_vector(self): 47 | row_vector = csr_matrix([[]]) 48 | indices, scores = extract_sorted_scores(row_vector) 49 | np.testing.assert_array_equal(indices, np.array([], dtype=np.int64)) 50 | np.testing.assert_array_equal(scores, np.array([], dtype=np.float32)) 51 | 52 | 
def test_single_element(self): 53 | row_vector = csr_matrix([[0.5]]) 54 | indices, scores = extract_sorted_scores(row_vector) 55 | np.testing.assert_array_equal(indices, np.array([0])) 56 | np.testing.assert_array_equal(scores, np.array([0.5])) 57 | 58 | def test_all_zero_elements(self): 59 | row_vector = csr_matrix([[0, 0, 0, 0, 0]]) 60 | indices, scores = extract_sorted_scores(row_vector) 61 | np.testing.assert_array_equal(indices, np.array([], dtype=np.int64)) 62 | np.testing.assert_array_equal(scores, np.array([], dtype=np.float32)) 63 | 64 | def test_duplicate_elements(self): 65 | row_vector = csr_matrix([[0, 0.1, 0, 0.7, 0.5, 0.7]]) 66 | indices, scores = extract_sorted_scores(row_vector) 67 | expected_indices_1 = np.array([5, 3, 4, 1]) 68 | expected_indices_2 = np.array([3, 5, 4, 1]) 69 | self.assertTrue( 70 | np.array_equal(indices, expected_indices_1) or np.array_equal(indices, expected_indices_2), 71 | f"indices {indices} do not match either {expected_indices_1} or {expected_indices_2}" 72 | ) 73 | np.testing.assert_array_equal(scores, np.array([0.7, 0.7, 0.5, 0.1])) 74 | 75 | 76 | class TestCsrFromListOfLists(unittest.TestCase): 77 | def test_repeated_elements(self): 78 | data: List[List[int]] = [[0, 0], [], []] 79 | expected_matrix = csr_matrix(([1, 1, 0], ([0, 0, 0], [0, 0, 0])), shape=(3, 3)) 80 | result_matrix = csr_from_indices_list(data, shape=(3, 3)) 81 | np.testing.assert_array_equal(result_matrix.toarray(), expected_matrix.toarray()) 82 | 83 | def test_non_zero_elements(self): 84 | data = [[0, 1, 2], [2, 3], [0, 3]] 85 | expected_matrix = csr_matrix([[1, 1, 1, 0, 0], [0, 0, 1, 1, 0], [1, 0, 0, 1, 0]], shape=(3, 5)) 86 | result_matrix = csr_from_indices_list(data, shape=(3, 5)) 87 | np.testing.assert_array_equal(result_matrix.toarray(), expected_matrix.toarray()) 88 | 89 | def test_empty_list_of_lists(self): 90 | data: List[List[int]] = [] 91 | expected_matrix = csr_matrix((0, 0)) 92 | result_matrix = csr_from_indices_list(data, shape=(0, 0)) 93 
| np.testing.assert_array_equal(result_matrix.toarray(), expected_matrix.toarray()) 94 | 95 | def test_empty_list_of_lists_with_unempty_shape(self): 96 | data: List[List[int]] = [] 97 | expected_matrix = csr_matrix((1, 1)) 98 | result_matrix = csr_from_indices_list(data, shape=(1, 1)) 99 | np.testing.assert_array_equal(result_matrix.toarray(), expected_matrix.toarray()) 100 | 101 | def test_list_with_empty_sublists(self): 102 | data: List[List[int]] = [[], [], []] 103 | expected_matrix = csr_matrix((3, 0)) 104 | result_matrix = csr_from_indices_list(data, shape=(3, 0)) 105 | np.testing.assert_array_equal(result_matrix.toarray(), expected_matrix.toarray()) 106 | 107 | 108 | if __name__ == "__main__": 109 | unittest.main() 110 | --------------------------------------------------------------------------------