├── .github
│   ├── ISSUE_TEMPLATE
│   │   ├── 🐞-bug-report.md
│   │   └── 💡-feature-request.md
│   └── workflows
│       ├── publish.yml
│       └── python-package.yml
├── .gitignore
├── .vscode
│   └── settings.json
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── banner.png
├── benchmarks
│   ├── README.md
│   ├── _domain.py
│   ├── create_db.bat
│   ├── create_dbs.bat
│   ├── create_dbs.sh
│   ├── datasets
│   │   └── 2wikimultihopqa.json
│   ├── db
│   │   └── .gitignore
│   ├── evaluate_dbs.bat
│   ├── evaluate_dbs.sh
│   ├── graph_benchmark.py
│   ├── lightrag_benchmark.py
│   ├── nano_benchmark.py
│   ├── questions
│   │   ├── 2wikimultihopqa_101.json
│   │   └── 2wikimultihopqa_51.json
│   ├── results
│   │   ├── graph
│   │   │   ├── 2wikimultihopqa_101.json
│   │   │   └── 2wikimultihopqa_51.json
│   │   ├── lightrag
│   │   │   ├── 2wikimultihopqa_101_local.json
│   │   │   └── 2wikimultihopqa_51_local.json
│   │   ├── nano
│   │   │   ├── 2wikimultihopqa_101_local.json
│   │   │   └── 2wikimultihopqa_51_local.json
│   │   └── vdb
│   │       ├── 2wikimultihopqa_101.json
│   │       └── 2wikimultihopqa_51.json
│   └── vdb_benchmark.py
├── demo.gif
├── examples
│   ├── checkpointing.ipynb
│   ├── custom_llm.py
│   ├── gemini_example.py
│   ├── gemini_vertexai_llm.py
│   └── query_parameters.ipynb
├── fast_graphrag
│   ├── __init__.py
│   ├── _exceptions.py
│   ├── _graphrag.py
│   ├── _llm
│   │   ├── __init__.py
│   │   ├── _base.py
│   │   ├── _default.py
│   │   ├── _llm_genai.py
│   │   ├── _llm_openai.py
│   │   └── _llm_voyage.py
│   ├── _models.py
│   ├── _policies
│   │   ├── __init__.py
│   │   ├── _base.py
│   │   ├── _graph_upsert.py
│   │   └── _ranking.py
│   ├── _prompt.py
│   ├── _services
│   │   ├── __init__.py
│   │   ├── _base.py
│   │   ├── _chunk_extraction.py
│   │   ├── _information_extraction.py
│   │   └── _state_manager.py
│   ├── _storage
│   │   ├── __init__.py
│   │   ├── _base.py
│   │   ├── _blob_pickle.py
│   │   ├── _default.py
│   │   ├── _gdb_igraph.py
│   │   ├── _ikv_pickle.py
│   │   ├── _namespace.py
│   │   └── _vdb_hnswlib.py
│   ├── _types.py
│   └── _utils.py
├── mock_data.txt
├── poetry.lock
├── pyproject.toml
└── tests
    ├── __init__.py
    ├── _graphrag_test.py
    ├── _llm
    │   ├── __init__.py
    │   ├── _base_test.py
    │   └── _llm_openai_test.py
    ├── _models_test.py
    ├── _policies
    │   ├── __init__.py
    │   ├── _graph_upsert_test.py
    │   └── _ranking_test.py
    ├── _services
    │   ├── __init__.py
    │   ├── _chunk_extraction_test.py
    │   └── _information_extraction_test.py
    ├── _storage
    │   ├── __init__.py
    │   ├── _base_test.py
    │   ├── _blob_pickle_test.py
    │   ├── _gdb_igraph_test.py
    │   ├── _ikv_pickle_test.py
    │   ├── _namespace_test.py
    │   └── _vdb_hnswlib_test.py
    ├── _types_test.py
    └── _utils_test.py
/.github/ISSUE_TEMPLATE/🐞-bug-report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: "\U0001F41E Bug report"
3 | about: Create a report to help us improve
4 | title: ''
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Describe the bug**
11 | A clear and concise description of what the bug is.
12 |
13 | **To Reproduce**
14 | Steps to reproduce the behavior:
15 | 1. Go to '...'
16 | 2. Click on '....'
17 | 3. Scroll down to '....'
18 | 4. See error
19 |
20 | **Expected behavior**
21 | A clear and concise description of what you expected to happen.
22 |
23 | **Screenshots**
24 | If applicable, add screenshots to help explain your problem.
25 |
26 | **Additional context**
27 | Add any other context about the problem here.
28 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/💡-feature-request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: "\U0001F4A1 Feature request"
3 | about: Suggest an idea for this project
4 | title: ''
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Is your feature request related to a problem? Please describe.**
11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
12 |
13 | **Describe the solution you'd like**
14 | A clear and concise description of what you want to happen.
15 |
16 | **Describe alternatives you've considered**
17 | A clear and concise description of any alternative solutions or features you've considered.
18 |
19 | **Additional context**
20 | Add any other context or screenshots about the feature request here.
21 |
--------------------------------------------------------------------------------
/.github/workflows/publish.yml:
--------------------------------------------------------------------------------
1 | name: Publish Package
2 |
3 | on:
4 | release:
5 | types: [published]
6 |
7 | workflow_dispatch: # TODO remove
8 |
9 | jobs:
10 | build-n-publish:
11 | name: Build and publish to PyPI
12 | runs-on: ubuntu-22.04
13 | environment:
14 | name: pypi
15 | url: https://pypi.org/p/fast-graphrag
16 | permissions:
17 | id-token: write
18 | steps:
19 | - uses: actions/checkout@master
20 |
21 | - name: Set up Python 3.11
22 | uses: actions/setup-python@v1
23 | with:
24 | python-version: 3.11
25 |
26 | - name: Install Poetry
27 | run: pipx install poetry==1.8.*
28 |
29 | - name: Cache Poetry virtual environment
30 | uses: actions/cache@v3
31 | with:
32 | path: ~/.cache/pypoetry
33 | key: ${{ runner.os }}-poetry-${{ hashFiles('**/poetry.lock') }}
34 | restore-keys: |
35 | ${{ runner.os }}-poetry-
36 |
37 | - name: Lock
38 | run: poetry lock
39 | - name: Build
40 | run: poetry build
41 |
42 | - name: pypi-publish
43 | uses: pypa/gh-action-pypi-publish@v1.10.3
--------------------------------------------------------------------------------
/.github/workflows/python-package.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
3 |
4 | name: Python package
5 |
6 | on:
7 | push:
8 | branches: [ "main" ]
9 | pull_request:
10 | branches: [ "main" ]
11 |
12 | jobs:
13 | build:
14 |
15 | runs-on: ubuntu-latest
16 | strategy:
17 | fail-fast: false
18 | matrix:
19 | python-version: ["3.10", "3.11", "3.12"]
20 |
21 | steps:
22 | - uses: actions/checkout@v4
23 | - name: Set up Python ${{ matrix.python-version }}
24 | uses: actions/setup-python@v3
25 | with:
26 | python-version: ${{ matrix.python-version }}
27 | - name: Install dependencies
28 | run: |
29 | python -m pip install --upgrade pip
30 | python -m pip install ruff
31 | pipx install poetry
32 | poetry lock
33 | poetry install
34 | - name: Lint with ruff
35 | run: |
36 | # Stop the build if there are Python syntax errors or undefined names
37 |         ruff check . --select E9,F63,F7,F82
38 |
39 | # Check with the same settings as the dev environment
40 |         ruff check . --select E,W,F,I,B,C4,N,D --ignore C901,W191,D401
41 |
42 | # Treat all errors as warnings with max line length and complexity constraints
43 | ruff check . --exit-zero --line-length 127 --select C901
44 | - name: Test with unittest
45 | run: |
46 | poetry run python -m unittest discover -s tests/ -p "*_test.py"
47 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # celery beat schedule file
88 | celerybeat-schedule
89 |
90 | # SageMath parsed files
91 | *.sage.py
92 |
93 | # Environments
94 | .env
95 | .venv
96 | env/
97 | venv/
98 | ENV/
99 | env.bak/
100 | venv.bak/
101 |
102 | # Spyder project settings
103 | .spyderproject
104 | .spyproject
105 |
106 | # Rope project settings
107 | .ropeproject
108 |
109 | # mkdocs documentation
110 | /site
111 |
112 | # mypy
113 | .mypy_cache/
114 | .dmypy.json
115 | dmypy.json
116 |
117 | # Pyre type checker
118 | .pyre/
119 |
120 | # pytype static type analyzer
121 | .pytype/
122 |
123 | # Cython debug symbols
124 | cython_debug/
125 | book_example/
126 | book.txt
127 |
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "python.analysis.typeCheckingMode": "strict",
3 | "python.testing.unittestArgs": [
4 | "-v",
5 | "-s",
6 | ".",
7 | "-p",
8 | "*_test.py"
9 | ],
10 | "python.testing.pytestEnabled": false,
11 | "python.testing.unittestEnabled": true,
12 | "python.autoComplete.extraPaths": [
13 | "${workspaceFolder}"
14 | ],
15 | "python.analysis.extraPaths": [
16 | "${workspaceFolder}"
17 | ]
18 | }
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Code of Conduct - Fast GraphRAG
2 |
3 | ## Our Pledge
4 |
5 | In the interest of fostering an open and welcoming environment, we as
6 | contributors and maintainers pledge to make participation in our project and
7 | our community a harassment-free experience for everyone, regardless of age, body
8 | size, disability, ethnicity, sex characteristics, gender identity and expression,
9 | level of experience, education, socio-economic status, nationality, personal
10 | appearance, race, religion, or sexual identity and orientation.
11 |
12 | ## Our Standards
13 |
14 | Examples of behaviour that contributes to a positive environment for our
15 | community include:
16 |
17 | * Demonstrating empathy and kindness toward other people
18 | * Being respectful of differing opinions, viewpoints, and experiences
19 | * Giving and gracefully accepting constructive feedback
20 | * Accepting responsibility and apologising to those affected by our mistakes,
21 | and learning from the experience
22 | * Focusing on what is best not just for us as individuals, but for the
23 | overall community
24 |
25 | Examples of unacceptable behaviour include:
26 |
27 | * The use of sexualised language or imagery, and sexual attention or advances
28 | * Trolling, insulting or derogatory comments, and personal or political attacks
29 | * Public or private harassment
30 | * Publishing others' private information, such as a physical or email
31 | address, without their explicit permission
32 | * Other conduct which could reasonably be considered inappropriate in a
33 | professional setting
34 |
35 | ## Our Responsibilities
36 |
37 | Project maintainers are responsible for clarifying and enforcing our standards of
38 | acceptable behaviour and will take appropriate and fair corrective action in
39 | response to any instances of unacceptable behaviour.
40 |
41 | Project maintainers have the right and responsibility to remove, edit, or reject
42 | comments, commits, code, wiki edits, issues, and other contributions that are
43 | not aligned to this Code of Conduct, or to ban
44 | temporarily or permanently any contributor for other behaviours that they deem
45 | inappropriate, threatening, offensive, or harmful.
46 |
47 | ## Scope
48 |
49 | This Code of Conduct applies within all community spaces, and also applies when
50 | an individual is officially representing the community in public spaces.
51 | Examples of representing our community include using an official e-mail address,
52 | posting via an official social media account, or acting as an appointed
53 | representative at an online or offline event.
54 |
55 | ## Enforcement
56 |
57 | Instances of abusive, harassing, or otherwise unacceptable behaviour may be
58 | reported to the community leaders responsible for enforcement at .
59 | All complaints will be reviewed and investigated promptly and fairly.
60 |
61 | All community leaders are obligated to respect the privacy and security of the
62 | reporter of any incident.
63 |
64 | ## Attribution
65 |
66 | This Code of Conduct is adapted from the [Contributor Covenant](https://contributor-covenant.org/), version
67 | [1.4](https://www.contributor-covenant.org/version/1/4/code-of-conduct/code_of_conduct.md) and
68 | [2.0](https://www.contributor-covenant.org/version/2/0/code_of_conduct/code_of_conduct.md),
69 | and was generated by [contributing-gen](https://github.com/bttger/contributing-gen).
70 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 |
2 | # Contributing to Fast GraphRAG
3 |
4 | First off, thanks for taking the time to contribute! ❤️
5 |
6 | All types of contributions are encouraged and valued. See the [Table of Contents](#table-of-contents) for different ways to help and details about how this project handles them. Please make sure to read the relevant section before making your contribution. It will make it a lot easier for us maintainers and smooth out the experience for all involved. The community looks forward to your contributions. 🎉
7 |
8 | > And if you like the project, but just don't have time to contribute, that's fine. There are other easy ways to support the project and show your appreciation, which we would also be very happy about:
9 | > - Star the project
10 | > - Tweet about it
11 | > - Refer this project in your project's readme
12 | > - Mention the project at local meetups and tell your friends/colleagues
13 |
14 |
15 | ## Table of Contents
16 |
17 | - [Code of Conduct](#code-of-conduct)
18 | - [I Have a Question](#i-have-a-question)
19 | - [I Want To Contribute](#i-want-to-contribute)
20 | - [Reporting Bugs](#reporting-bugs)
21 | - [Suggesting Enhancements](#suggesting-enhancements)
22 | - [Your First Code Contribution](#your-first-code-contribution)
23 | - [Improving The Documentation](#improving-the-documentation)
24 | - [Styleguides](#styleguides)
25 | - [Commit Messages](#commit-messages)
26 | - [Join The Project Team](#join-the-project-team)
27 |
28 |
29 | ## Code of Conduct
30 |
31 | This project and everyone participating in it is governed by the
32 | [Fast GraphRAG Code of Conduct](https://github.com/circlemind-ai/fast-graphrag/blob/main/CODE_OF_CONDUCT.md).
33 | By participating, you are expected to uphold this code. Please report unacceptable behavior
34 | to .
35 |
36 |
37 | ## I Have a Question
38 |
39 | First off, make sure to join the Discord community: https://discord.gg/McpuSEkR
40 |
41 | Before you ask a question, it is best to search for existing [Issues](https://github.com/circlemind-ai/fast-graphrag/issues) that might help you. In case you have found a suitable issue and still need clarification, you can write your question in this issue. It is also advisable to search the internet for answers first.
42 |
43 | If you then still feel the need to ask a question and need clarification, we recommend the following:
44 |
45 | - Open an [Issue](https://github.com/circlemind-ai/fast-graphrag/issues/new).
46 | - Provide as much context as you can about what you're running into.
47 | - Provide project and platform versions (python, os, etc), depending on what seems relevant.
48 |
49 | We will then take care of the issue as soon as possible.
50 |
51 | ## I Want To Contribute
52 |
53 | > ### Legal Notice
54 | > When contributing to this project, you must agree that you have authored 100% of the content, that you have the necessary rights to the content and that the content you contribute may be provided under the project licence.
55 |
56 | ### Reporting Bugs
57 |
58 |
59 | #### Before Submitting a Bug Report
60 |
61 | A good bug report shouldn't leave others needing to chase you up for more information. Therefore, we ask you to investigate carefully, collect information and describe the issue in detail in your report. Please complete the following steps in advance to help us fix any potential bug as fast as possible.
62 |
63 | - Make sure that you are using the latest version.
64 | - Determine if your bug is really a bug and not an error on your side, e.g. using incompatible environment components/versions. If you are looking for support, you might want to check Discord first.
65 | - To see if other users have experienced (and potentially already solved) the same issue, check whether a bug report for it already exists in the [bug tracker](https://github.com/circlemind-ai/fast-graphrag/issues?q=label%3Abug).
66 | - Also make sure to search the internet (including Stack Overflow) to see if users outside of the GitHub community have discussed the issue.
67 | - Collect all important information about the bug.
68 |
69 |
70 | #### How Do I Submit a Good Bug Report?
71 |
72 | > You must never report security-related issues, vulnerabilities, or bugs containing sensitive information to the issue tracker or elsewhere in public. Instead, sensitive bugs must be sent by email to security@circlemind.co.
73 |
74 | We use GitHub issues to track bugs and errors. If you run into an issue with the project:
75 |
76 | - Open an [Issue](https://github.com/circlemind-ai/fast-graphrag/issues/new). (Since we can't be sure at this point whether it is a bug or not, we ask you not to talk about a bug yet and not to label the issue.)
77 | - Explain the behavior you would expect and the actual behavior.
78 | - Please provide as much context as possible and describe the *reproduction steps* that someone else can follow to recreate the issue on their own. This usually includes your code. For good bug reports you should isolate the problem and create a reduced test case.
79 | - Provide the information you collected in the previous section.
80 |
81 | Once it's filed:
82 |
83 | - The project team will label the issue accordingly.
84 | - A team member will try to reproduce the issue with your provided steps. If there are no reproduction steps or no obvious way to reproduce the issue, the team will ask you for those steps and mark the issue as `needs-repro`. Bugs with the `needs-repro` tag will not be addressed until they are reproduced.
85 | - If the team is able to reproduce the issue, it will be marked `needs-fix`, as well as possibly other tags (such as `critical`), and the issue will be left to be [implemented by someone](#your-first-code-contribution).
86 |
87 |
88 |
89 |
90 | ### Suggesting Enhancements
91 |
92 | This section guides you through submitting an enhancement suggestion for Fast GraphRAG, **including completely new features and minor improvements to existing functionality**. Following these guidelines will help maintainers and the community to understand your suggestion and find related suggestions.
93 |
94 |
95 | #### Before Submitting an Enhancement
96 |
97 | - Make sure that you are using the latest version.
98 | - Perform a [search](https://github.com/circlemind-ai/fast-graphrag/issues) to see if the enhancement has already been suggested. If it has, add a comment to the existing issue instead of opening a new one.
99 | - Find out whether your idea fits with the scope and aims of the project. It's up to you to make a strong case to convince the project's developers of the merits of this feature. Keep in mind that we want features that will be useful to the majority of our users and not just a small subset. If you're just targeting a minority of users, consider writing an add-on/plugin library.
100 |
101 |
102 | #### How Do I Submit a Good Enhancement Suggestion?
103 |
104 | Enhancement suggestions are tracked as [GitHub issues](https://github.com/circlemind-ai/fast-graphrag/issues).
105 |
106 | - Use a **clear and descriptive title** for the issue to identify the suggestion.
107 | - Provide a **step-by-step description of the suggested enhancement** in as much detail as possible.
108 | - **Describe the current behavior** and **explain which behavior you expected to see instead** and why. At this point you can also mention which alternatives do not work for you.
109 | - **Explain why this enhancement would be useful** to most Fast GraphRAG users. You may also want to point out other projects that solved it better and could serve as inspiration.
110 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 circlemind-ai
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
16 |
17 | Streamlined and promptable Fast GraphRAG framework designed for interpretable, high-precision, agent-driven retrieval workflows.
18 | Looking for a Managed Service? »
19 |
20 |
27 |
28 | > [!NOTE]
29 | > Using *The Wizard of Oz*, `fast-graphrag` costs $0.08 vs. `graphrag` $0.48, **a 6x cost saving** that further improves with data size and number of insertions.
30 |
31 | ## News (and Coming Soon)
32 | - [ ] Support for IDF weighting of entities
33 | - [x] Support for generic entities and concepts (initial commit)
34 | - [x] [2024.12.02] Benchmarks comparing Fast GraphRAG to LightRAG, GraphRAG and VectorDBs released [here](https://github.com/circlemind-ai/fast-graphrag/blob/main/benchmarks/README.md)
35 |
36 | ## Features
37 |
38 | - **Interpretable and Debuggable Knowledge:** Graphs offer a human-navigable view of knowledge that can be queried, visualized, and updated.
39 | - **Fast, Low-cost, and Efficient:** Designed to run at scale without heavy resource or cost requirements.
40 | - **Dynamic Data:** Automatically generate and refine graphs to best fit your domain and ontology needs.
41 | - **Incremental Updates:** Supports real-time updates as your data evolves.
42 | - **Intelligent Exploration:** Leverages PageRank-based graph exploration for enhanced accuracy and dependability.
43 | - **Asynchronous & Typed:** Fully asynchronous, with complete type support for robust and predictable workflows.
44 |
45 | Fast GraphRAG is built to fit seamlessly into your retrieval pipeline, giving you the power of advanced RAG, without the overhead of building and designing agentic workflows.
46 |
47 | ## Install
48 |
49 | **Install from source (recommended for best performance)**
50 |
51 | ```bash
52 | # clone this repo first
53 | cd fast_graphrag
54 | poetry install
55 | ```
56 |
57 | **Install from PyPI (recommended for stability)**
58 |
59 | ```bash
60 | pip install fast-graphrag
61 | ```
62 |
63 | ## Quickstart
64 |
65 | Set the OpenAI API key in the environment:
66 |
67 | ```bash
68 | export OPENAI_API_KEY="sk-..."
69 | ```
70 |
71 | Download a copy of *A Christmas Carol* by Charles Dickens:
72 |
73 | ```bash
74 | curl https://raw.githubusercontent.com/circlemind-ai/fast-graphrag/refs/heads/main/mock_data.txt > ./book.txt
75 | ```
76 |
77 | Optional: set a limit on the number of concurrent requests to the LLM (i.e., how many tasks it processes simultaneously; this is helpful when running local models):
78 | ```bash
79 | export CONCURRENT_TASK_LIMIT=8
80 | ```
81 |
82 | Use the Python snippet below:
83 |
84 | ```python
85 | from fast_graphrag import GraphRAG
86 |
87 | DOMAIN = "Analyze this story and identify the characters. Focus on how they interact with each other, the locations they explore, and their relationships."
88 |
89 | EXAMPLE_QUERIES = [
90 | "What is the significance of Christmas Eve in A Christmas Carol?",
91 | "How does the setting of Victorian London contribute to the story's themes?",
92 | "Describe the chain of events that leads to Scrooge's transformation.",
93 | "How does Dickens use the different spirits (Past, Present, and Future) to guide Scrooge?",
94 | "Why does Dickens choose to divide the story into \"staves\" rather than chapters?"
95 | ]
96 |
97 | ENTITY_TYPES = ["Character", "Animal", "Place", "Object", "Activity", "Event"]
98 |
99 | grag = GraphRAG(
100 | working_dir="./book_example",
101 | domain=DOMAIN,
102 | example_queries="\n".join(EXAMPLE_QUERIES),
103 | entity_types=ENTITY_TYPES
104 | )
105 |
106 | with open("./book.txt") as f:
107 | grag.insert(f.read())
108 |
109 | print(grag.query("Who is Scrooge?").response)
110 | ```
111 |
112 | The next time you initialize fast-graphrag from the same working directory, it will retain all the knowledge automatically.
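
For instance, a minimal sketch of a later session that reuses the stored knowledge (same `DOMAIN`, `EXAMPLE_QUERIES`, and `ENTITY_TYPES` as above):

```python
from fast_graphrag import GraphRAG

# Re-opening the same working directory loads the previously built graph;
# there is no need to insert the book again.
grag = GraphRAG(
    working_dir="./book_example",
    domain=DOMAIN,
    example_queries="\n".join(EXAMPLE_QUERIES),
    entity_types=ENTITY_TYPES,
)

print(grag.query("Who is Scrooge?").response)
```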
113 |
114 | ## Examples
115 | Please refer to the `examples` folder for a list of tutorials on common use cases of the library:
116 | - `custom_llm.py`: a brief example of how to configure fast-graphrag to run with different OpenAI-API-compatible language models and embedders;
117 | - `checkpointing.ipynb`: a tutorial on how to use checkpoints to avoid irreversible data corruption;
118 | - `query_parameters.ipynb`: a tutorial on the different query parameters. In particular, it shows how to include references to the retrieved information in the answer (using the `with_references=True` parameter), as sketched below.
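
A hedged sketch of querying with references, assuming `with_references` is accepted as a `QueryParam` field as the tutorial suggests (see `query_parameters.ipynb` for the authoritative version; `DOMAIN`, `EXAMPLE_QUERIES`, and `ENTITY_TYPES` as in the Quickstart):

```python
from fast_graphrag import GraphRAG, QueryParam

grag = GraphRAG(
    working_dir="./book_example",
    domain=DOMAIN,
    example_queries="\n".join(EXAMPLE_QUERIES),
    entity_types=ENTITY_TYPES,
)

# Assumption: with_references is a QueryParam field, per the tutorial above.
answer = grag.query("Who is Scrooge?", QueryParam(with_references=True))
print(answer.response)
```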
119 |
120 | ## Contributing
121 |
122 | Whether it's big or small, we love contributions. Contributions are what make the open-source community such an amazing place to learn, inspire, and create. Any contributions you make are greatly appreciated. Check out our [guide](https://github.com/circlemind-ai/fast-graphrag/blob/main/CONTRIBUTING.md) to see how to get started.
123 |
124 | Not sure where to get started? You can join our [Discord](https://discord.gg/DvY2B8u4sA) and ask us any questions there.
125 |
126 | ## Philosophy
127 |
128 | Our mission is to increase the number of successful GenAI applications in the world. To do that, we build memory and data tools that enable LLM apps to leverage highly specialized retrieval pipelines without the complexity of setting up and maintaining agentic workflows.
129 |
130 | Fast GraphRAG currently exploits the personalized PageRank algorithm to explore the graph and find the most relevant pieces of information to answer your query. For an overview of why this works, you can check out the HippoRAG paper [here](https://arxiv.org/abs/2405.14831).
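
To illustrate the idea (a toy sketch, not the library's internal code), here is personalized PageRank over a small `igraph` graph, biasing the random walk toward the entities matched by a query:

```python
import igraph as ig

# Toy knowledge graph: vertices are entities, edges are relations.
g = ig.Graph(edges=[(0, 1), (1, 2), (2, 3), (0, 3)], directed=False)
g.vs["name"] = ["Scrooge", "Marley", "Ghost of Christmas Past", "London"]

# Seed the walk at the entity matched by the query ("Scrooge"):
# scores decay with graph distance from the seed, ranking entities by relevance.
scores = g.personalized_pagerank(reset_vertices=[0], damping=0.85)

for name, score in sorted(zip(g.vs["name"], scores), key=lambda x: -x[1]):
    print(f"{name}: {score:.3f}")
```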
131 |
132 | ## Open-source or Managed Service
133 |
134 | This repo is under the MIT License. See [LICENSE](https://github.com/circlemind-ai/fast-graphrag/blob/main/LICENSE) for more information.
135 |
136 | The fastest and most reliable way to get started with Fast GraphRAG is using our managed service. Your first 100 requests are free every month, after which you pay based on usage.
137 |
138 |
139 |
140 |
141 |
142 | To learn more about our managed service, [book a demo](https://circlemind.co/demo) or see our [docs](https://docs.circlemind.co/quickstart).
143 |
--------------------------------------------------------------------------------
/banner.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/circlemind-ai/fast-graphrag/b370efe01ef836af292a3713d59b2ec23d2fe7c4/banner.png
--------------------------------------------------------------------------------
/benchmarks/README.md:
--------------------------------------------------------------------------------
1 | ## Benchmarks
2 | We validate the benchmark results provided in [HippoRAG](https://arxiv.org/abs/2405.14831) and compare against other methods:
3 | - NaiveRAG (vector dbs) using the OpenAI embedder `text-embedding-3-small`
4 | - [LightRAG](https://github.com/HKUDS/LightRAG)
5 | - [GraphRAG](https://github.com/gusye1234/nano-graphrag) (we use the implementation provided by `nano-graphrag`, based on the original [Microsoft GraphRAG](https://github.com/microsoft/graphrag))
6 |
7 | ### Results
8 | **2wikimultihopQA**
9 | | # Queries | Method | All queries % | Multihop only % |
10 | |----------:|:--------:|--------------:|----------------:|
11 | | 51||||
12 | | | VectorDB| 0.49| 0.32|
13 | | | LightRAG| 0.47| 0.32|
14 | | | GraphRAG| 0.75| 0.68|
15 | | |**Circlemind**| **0.96**| **0.95**|
16 | | 101||||
17 | | | VectorDB| 0.42| 0.23|
18 | | | LightRAG| 0.45| 0.28|
19 | | | GraphRAG| 0.73| 0.64|
20 | | |**Circlemind**| **0.93**| **0.90**|
21 |
22 | **Circlemind is up to 4x more accurate than VectorDB RAG.**
23 |
24 | **HotpotQA**
25 | | # Queries | Method | All queries % |
26 | |----------:|:--------:|--------------:|
27 | | 101|||
28 | | | VectorDB| 0.78|
29 | | | LightRAG| 0.55|
30 | | | GraphRAG| -*|
31 | | |**Circlemind**| **0.84**|
32 |
33 | *: crashes after half an hour of processing
34 |
35 | Below, find the insertion times for the 2wikimultihopqa benchmark (~800 chunks):
36 | | Method | Time (minutes) |
37 | |:--------:|-----------------:|
38 | | VectorDB| ~0.3|
39 | | LightRAG| ~25|
40 | | GraphRAG| ~40|
41 | |**Circlemind**| ~1.5|
42 |
43 | **Circlemind is 27x faster than GraphRAG while also being over 40% more accurate in retrieval.**
44 |
45 | ### Run it yourself
46 | The scripts in this directory generate and evaluate the 2wikimultihopqa databases on subsets of 51 and 101 queries, following the same methodology as the HippoRAG paper. In particular, we evaluate the retrieval capabilities of each method, measuring the percentage of queries for which all the required evidence was retrieved. We preloaded the results, so it is enough to run `evaluate_dbs.xx` to get the numbers. You can also run `create_dbs.xx` to regenerate the databases for the different methods.
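
The scoring step itself is simple; a minimal sketch mirroring the logic in `graph_benchmark.py`:

```python
from typing import Any, Dict, List

def perfect_retrieval_rate(answers: List[Dict[str, Any]]) -> float:
    """Fraction of queries for which every ground-truth evidence title was retrieved."""
    hits = 0
    for answer in answers:
        ground_truth = set(answer["ground_truth"])
        predicted = set(answer["evidence"])
        # A query counts only if ALL required evidence appears in the retrieved set.
        if ground_truth <= predicted:
            hits += 1
    return hits / len(answers)
```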
47 |
48 | A couple of NOTES:
49 | - you will need to set an OPENAI_API_KEY;
50 | - LightRAG and GraphRAG can take over an hour to process the data, and they can be expensive;
51 | - when installing LightRAG with pip, not all of its dependencies are installed; to run it, we simply deleted the imports of each missing dependency (since we use OpenAI, they are not necessary);
52 | - we also benchmarked on the HotpotQA dataset (we will soon release the code for that as well).
53 |
54 | The output will look similar to the following (the exact numbers could vary based on your graph configuration):
55 | ```
56 | Evaluation of the performance of different RAG methods on 2wikimultihopqa (51 queries)
57 |
58 | VectorDB
59 | Loading dataset...
60 | [all questions] Percentage of queries with perfect retrieval: 0.49019607843137253
61 | [multihop only] Percentage of queries with perfect retrieval: 0.32432432432432434
62 |
63 | LightRAG [local mode]
64 | Loading dataset...
65 | Percentage of queries with perfect retrieval: 0.47058823529411764
66 | [multihop only] Percentage of queries with perfect retrieval: 0.32432432432432434
67 |
68 | GraphRAG [local mode]
69 | Loading dataset...
70 | [all questions] Percentage of queries with perfect retrieval: 0.7450980392156863
71 | [multihop only] Percentage of queries with perfect retrieval: 0.6756756756756757
72 |
73 | Circlemind
74 | Loading dataset...
75 | [all questions] Percentage of queries with perfect retrieval: 0.9607843137254902
76 | [multihop only] Percentage of queries with perfect retrieval: 0.9459459459459459
77 |
78 |
79 | Evaluation of the performance of different RAG methods on 2wikimultihopqa (101 queries)
80 |
81 | VectorDB
82 | Loading dataset...
83 | [all questions] Percentage of queries with perfect retrieval: 0.4158415841584158
84 | [multihop only] Percentage of queries with perfect retrieval: 0.2318840579710145
85 |
86 | LightRAG [local mode]
87 | Loading dataset...
88 | [all questions] Percentage of queries with perfect retrieval: 0.44554455445544555
89 | [multihop only] Percentage of queries with perfect retrieval: 0.2753623188405797
90 |
91 | GraphRAG [local mode]
92 | Loading dataset...
93 | [all questions] Percentage of queries with perfect retrieval: 0.7326732673267327
94 | [multihop only] Percentage of queries with perfect retrieval: 0.6376811594202898
95 |
96 | Circlemind
97 | Loading dataset...
98 | [all questions] Percentage of queries with perfect retrieval: 0.9306930693069307
99 | [multihop only] Percentage of queries with perfect retrieval: 0.8985507246376812
100 | ```
101 |
--------------------------------------------------------------------------------
/benchmarks/_domain.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List
2 |
3 | DOMAIN: Dict[str, str] = {
4 | "2wikimultihopqa": """Analyse the following passage and identify the people, creative works, and places mentioned in it. Your goal is to create an RDF (Resource Description Framework) graph from the given text.
5 | IMPORTANT: among other entities and relationships you find, make sure to extract as separate entities (to be connected with the main one) a person's
6 | role as a family member (such as 'son', 'uncle', 'wife', ...), their profession (such as 'director'), and the location
7 | where they live or work. Pay attention to the spelling of the names.""", # noqa: E501
8 | "hotpotqa": """Analyse the following passage and identify all the entities mentioned in it and their relationships. Your goal is to create an RDF (Resource Description Framework) graph from the given text.
9 | Pay attention to the spelling of the entity names."""
10 | }
11 |
12 | QUERIES: Dict[str, List[str]] = {
13 | "2wikimultihopqa": [
14 | "When did Prince Arthur's mother die?",
15 | "What is the place of birth of Elizabeth II's husband?",
16 | "Which film has the director died later, Interstellar or Harry Potter I?",
17 | "Where does the singer who wrote the song Blank Space work at?",
18 | ],
19 | "hotpotqa": [
20 | "Are Christopher Nolan and Sathish Kalathil both film directors?",
21 | "What language were books being translated into during the era of Haymo of Faversham?",
22 | "Who directed the film that was shot in or around Leland, North Carolina in 1986?",
23 | "Who wrote a song after attending a luau in the Koolauloa District on the island of Oahu in Honolulu County?"
24 | ]
25 | }
26 |
27 | ENTITY_TYPES: Dict[str, List[str]] = {
28 | "2wikimultihopqa": [
29 | "person",
30 |         "family_role",
31 | "location",
32 | "organization",
33 | "creative_work",
34 | "profession",
35 | ],
36 | "hotpotqa": [
37 | "person",
38 |         "family_role",
39 | "location",
40 | "organization",
41 | "creative_work",
42 | "profession",
43 | "event",
44 | "year"
45 | ],
46 | }
47 |
--------------------------------------------------------------------------------
/benchmarks/create_db.bat:
--------------------------------------------------------------------------------
1 | :: 2wikimultihopqa benchmark
2 | :: Creating databases
3 | python graph_benchmark.py -n 51 -c
4 | python graph_benchmark.py -n 101 -c
5 |
6 | :: Evaluation (create reports)
7 | python graph_benchmark.py -n 51 -b
8 | python graph_benchmark.py -n 101 -b
--------------------------------------------------------------------------------
/benchmarks/create_dbs.bat:
--------------------------------------------------------------------------------
1 | :: 2wikimultihopqa benchmark
2 | :: Creating databases
3 | python vdb_benchmark.py -n 51 -c
4 | python vdb_benchmark.py -n 101 -c
5 | python lightrag_benchmark.py -n 51 -c
6 | python lightrag_benchmark.py -n 101 -c
7 | python nano_benchmark.py -n 51 -c
8 | python nano_benchmark.py -n 101 -c
9 | python graph_benchmark.py -n 51 -c
10 | python graph_benchmark.py -n 101 -c
11 |
12 | :: Evaluation (create reports)
13 | python vdb_benchmark.py -n 51 -b
14 | python vdb_benchmark.py -n 101 -b
15 | python lightrag_benchmark.py -n 51 -b --mode=local
16 | python lightrag_benchmark.py -n 101 -b --mode=local
17 | python nano_benchmark.py -n 51 -b --mode=local
18 | python nano_benchmark.py -n 101 -b --mode=local
19 | python graph_benchmark.py -n 51 -b
20 | python graph_benchmark.py -n 101 -b
--------------------------------------------------------------------------------
/benchmarks/create_dbs.sh:
--------------------------------------------------------------------------------
1 | # 2wikimultihopqa benchmark
2 | # Creating databases
3 | python vdb_benchmark.py -n 51 -c
4 | python vdb_benchmark.py -n 101 -c
5 | python lightrag_benchmark.py -n 51 -c
6 | python lightrag_benchmark.py -n 101 -c
7 | python nano_benchmark.py -n 51 -c
8 | python nano_benchmark.py -n 101 -c
9 | python graph_benchmark.py -n 51 -c
10 | python graph_benchmark.py -n 101 -c
11 |
12 | # Evaluation (create reports)
13 | python vdb_benchmark.py -n 51 -b
14 | python vdb_benchmark.py -n 101 -b
15 | python lightrag_benchmark.py -n 51 -b --mode=local
16 | python lightrag_benchmark.py -n 101 -b --mode=local
17 | python nano_benchmark.py -n 51 -b --mode=local
18 | python nano_benchmark.py -n 101 -b --mode=local
19 | python graph_benchmark.py -n 51 -b
20 | python graph_benchmark.py -n 101 -b
--------------------------------------------------------------------------------
/benchmarks/db/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/benchmarks/evaluate_dbs.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 | echo Evaluation of the performance of different RAG methods on 2wikimultihopqa (51 queries)
3 | echo.
4 | echo VectorDB
5 | python vdb_benchmark.py -n 51 -s
6 | echo.
7 | echo LightRAG
8 | python lightrag_benchmark.py -n 51 -s --mode=local
9 | echo.
10 | echo GraphRAG
11 | python nano_benchmark.py -n 51 -s --mode=local
12 | echo.
13 | echo Circlemind
14 | python graph_benchmark.py -n 51 -s
15 |
16 | echo.
17 | echo.
18 | echo Evaluation of the performance of different RAG methods on 2wikimultihopqa (101 queries)
19 | echo.
20 | echo VectorDB
21 | python vdb_benchmark.py -n 101 -s
22 | echo.
23 | echo LightRAG
24 | python lightrag_benchmark.py -n 101 -s --mode=local
25 | echo.
26 | echo GraphRAG
27 | python nano_benchmark.py -n 101 -s --mode=local
28 | echo.
29 | echo Circlemind
30 | python graph_benchmark.py -n 101 -s
31 |
--------------------------------------------------------------------------------
/benchmarks/evaluate_dbs.sh:
--------------------------------------------------------------------------------
1 | echo "Evaluation of the performance of different RAG methods on the 2wikimultihopqa (51 queries)";
2 | echo;
3 | echo "VectorDB";
4 | python vdb_benchmark.py -n 51 -s
5 | echo;
6 | echo "LightRAG [local mode]";
7 | python lightrag_benchmark.py -n 51 -s --mode=local
8 | echo;
9 | echo "GraphRAG [local mode]";
10 | python nano_benchmark.py -n 51 -s --mode=local
11 | echo;
12 | echo "Circlemind"
13 | python graph_benchmark.py -n 51 -s
14 |
15 | echo "Evaluation of the performance of different RAG methods on the 2wikimultihopqa (101 queries)";
16 | echo;
17 | echo "VectorDB";
18 | python vdb_benchmark.py -n 101 -s
19 | echo;
20 | echo "LightRAG [local mode]";
21 | python lightrag_benchmark.py -n 101 -s --mode=local
22 | echo;
23 | echo "GraphRAG [local mode]";
24 | python nano_benchmark.py -n 101 -s --mode=local
25 | echo;
26 | echo "Circlemind";
27 | python graph_benchmark.py -n 101 -s
--------------------------------------------------------------------------------
/benchmarks/graph_benchmark.py:
--------------------------------------------------------------------------------
1 | """Benchmarking script for GraphRAG."""
2 |
3 | import argparse
4 | import asyncio
5 | import json
6 | from dataclasses import dataclass, field
7 | from typing import Any, Dict, List, Tuple
8 |
9 | import numpy as np
10 | import xxhash
11 | from _domain import DOMAIN, ENTITY_TYPES, QUERIES
12 | from dotenv import load_dotenv
13 | from tqdm import tqdm
14 |
15 | from fast_graphrag import GraphRAG, QueryParam
16 | from fast_graphrag._utils import get_event_loop
17 |
18 |
19 | @dataclass
20 | class Query:
21 | """Dataclass for a query."""
22 |
23 | question: str = field()
24 | answer: str = field()
25 | evidence: List[Tuple[str, int]] = field()
26 |
27 |
28 | def load_dataset(dataset_name: str, subset: int = 0) -> Any:
29 | """Load a dataset from the datasets folder."""
30 | with open(f"./datasets/{dataset_name}.json", "r") as f:
31 | dataset = json.load(f)
32 |
33 | if subset:
34 | return dataset[:subset]
35 | else:
36 | return dataset
37 |
38 |
39 | def get_corpus(dataset: Any, dataset_name: str) -> Dict[int, Tuple[int | str, str]]:
40 | """Get the corpus from the dataset."""
41 | if dataset_name == "2wikimultihopqa" or dataset_name == "hotpotqa":
42 | passages: Dict[int, Tuple[int | str, str]] = {}
43 |
44 | for datapoint in dataset:
45 | context = datapoint["context"]
46 |
47 | for passage in context:
48 | title, text = passage
49 | title = title.encode("utf-8").decode()
50 | text = "\n".join(text).encode("utf-8").decode()
51 | hash_t = xxhash.xxh3_64_intdigest(text)
52 | if hash_t not in passages:
53 | passages[hash_t] = (title, text)
54 |
55 | return passages
56 | else:
57 | raise NotImplementedError(f"Dataset {dataset_name} not supported.")
58 |
59 |
60 | def get_queries(dataset: Any):
61 | """Get the queries from the dataset."""
62 | queries: List[Query] = []
63 |
64 | for datapoint in dataset:
65 | queries.append(
66 | Query(
67 | question=datapoint["question"].encode("utf-8").decode(),
68 | answer=datapoint["answer"],
69 | evidence=list(datapoint["supporting_facts"]),
70 | )
71 | )
72 |
73 | return queries
74 |
75 |
76 | if __name__ == "__main__":
77 | load_dotenv()
78 |
79 | parser = argparse.ArgumentParser(description="GraphRAG CLI")
80 | parser.add_argument("-d", "--dataset", default="2wikimultihopqa", help="Dataset to use.")
81 | parser.add_argument("-n", type=int, default=0, help="Subset of corpus to use.")
82 | parser.add_argument("-c", "--create", action="store_true", help="Create the graph for the given dataset.")
83 | parser.add_argument("-b", "--benchmark", action="store_true", help="Benchmark the graph for the given dataset.")
84 | parser.add_argument("-s", "--score", action="store_true", help="Report scores after benchmarking.")
85 | args = parser.parse_args()
86 |
87 | print("Loading dataset...")
88 | dataset = load_dataset(args.dataset, subset=args.n)
89 | working_dir = f"./db/graph/{args.dataset}_{args.n}"
90 | corpus = get_corpus(dataset, args.dataset)
91 |
92 | if args.create:
93 | print("Dataset loaded. Corpus:", len(corpus))
94 | grag = GraphRAG(
95 | working_dir=working_dir,
96 | domain=DOMAIN[args.dataset],
97 |         example_queries="\n".join(QUERIES[args.dataset]),
98 | entity_types=ENTITY_TYPES[args.dataset],
99 | )
100 | grag.insert(
101 |         [f"{title}: {text}" for title, text in corpus.values()],
102 |         metadata=[{"id": chunk_id} for chunk_id in corpus.keys()],  # ids are the xxhash chunk keys
103 | )
104 | if args.benchmark:
105 | queries = get_queries(dataset)
106 | print("Dataset loaded. Queries:", len(queries))
107 | grag = GraphRAG(
108 | working_dir=working_dir,
109 | domain=DOMAIN[args.dataset],
110 |         example_queries="\n".join(QUERIES[args.dataset]),
111 | entity_types=ENTITY_TYPES[args.dataset],
112 | )
113 |
114 | async def _query_task(query: Query) -> Dict[str, Any]:
115 | answer = await grag.async_query(query.question, QueryParam(only_context=True))
116 | return {
117 | "question": query.question,
118 | "answer": answer.response,
119 | "evidence": [
120 | corpus[chunk.metadata["id"]][0]
121 | if isinstance(chunk.metadata["id"], int)
122 | else chunk.metadata["id"]
123 | for chunk, _ in answer.context.chunks
124 | ],
125 | "ground_truth": [e[0] for e in query.evidence],
126 | }
127 |
128 | async def _run():
129 | await grag.state_manager.query_start()
130 | answers = [
131 | await a
132 | for a in tqdm(asyncio.as_completed([_query_task(query) for query in queries]), total=len(queries))
133 | ]
134 | await grag.state_manager.query_done()
135 | return answers
136 |
137 | answers = get_event_loop().run_until_complete(_run())
138 |
139 | with open(f"./results/graph/{args.dataset}_{args.n}.json", "w") as f:
140 | json.dump(answers, f, indent=4)
141 |
142 | if args.benchmark or args.score:
143 | with open(f"./results/graph/{args.dataset}_{args.n}.json", "r") as f:
144 | answers = json.load(f)
145 |
146 | try:
147 | with open(f"./questions/{args.dataset}_{args.n}.json", "r") as f:
148 | questions_multihop = json.load(f)
149 | except FileNotFoundError:
150 | questions_multihop = []
151 |
152 | # Compute retrieval metrics
153 | retrieval_scores: List[float] = []
154 | retrieval_scores_multihop: List[float] = []
155 |
156 | for answer in answers:
157 | ground_truth = answer["ground_truth"]
158 | predicted_evidence = answer["evidence"]
159 |
160 | p_retrieved: float = len(set(ground_truth).intersection(set(predicted_evidence))) / len(set(ground_truth))
161 | retrieval_scores.append(p_retrieved)
162 |
163 | if answer["question"] in questions_multihop:
164 | retrieval_scores_multihop.append(p_retrieved)
165 |
166 | print(
167 | f"Percentage of queries with perfect retrieval: {np.mean([1 if s == 1.0 else 0 for s in retrieval_scores])}"
168 | )
169 | if len(retrieval_scores_multihop):
170 |         # Computed outside the f-string: newlines inside f-string expressions require Python 3.12+
171 |         multihop_perfect = np.mean([1 if s == 1.0 else 0 for s in retrieval_scores_multihop])
172 |         print(
173 |             f"[multihop] Percentage of queries with perfect retrieval: {multihop_perfect}"
174 |         )
175 |
--------------------------------------------------------------------------------
/benchmarks/lightrag_benchmark.py:
--------------------------------------------------------------------------------
1 | """Benchmarking script for GraphRAG."""
2 |
3 | import argparse
4 | import asyncio
5 | import json
6 | import os
7 | import re
8 | from dataclasses import dataclass, field
9 | from typing import Any, Dict, List, Tuple
10 |
11 | import numpy as np
12 | import xxhash
13 | from dotenv import load_dotenv
14 | from lightrag import LightRAG, QueryParam
15 | from lightrag.lightrag import always_get_an_event_loop
16 | from lightrag.llm import gpt_4o_mini_complete
17 | from lightrag.utils import logging
18 | from tqdm import tqdm
19 |
20 | logging.getLogger("httpx").setLevel(logging.WARNING)
21 | logging.getLogger("nano-vectordb").setLevel(logging.WARNING)
22 |
23 | @dataclass
24 | class Query:
25 | """Dataclass for a query."""
26 |
27 | question: str = field()
28 | answer: str = field()
29 | evidence: List[Tuple[str, int]] = field()
30 |
31 |
32 | def load_dataset(dataset_name: str, subset: int = 0) -> Any:
33 | """Load a dataset from the datasets folder."""
34 | with open(f"./datasets/{dataset_name}.json", "r") as f:
35 | dataset = json.load(f)
36 |
37 | if subset:
38 | return dataset[:subset]
39 | else:
40 | return dataset
41 |
42 |
43 | def get_corpus(dataset: Any, dataset_name: str) -> Dict[int, Tuple[int | str, str]]:
44 | """Get the corpus from the dataset."""
45 | if dataset_name == "2wikimultihopqa" or dataset_name == "hotpotqa":
46 | passages: Dict[int, Tuple[int | str, str]] = {}
47 |
48 | for datapoint in dataset:
49 | context = datapoint["context"]
50 |
51 | for passage in context:
52 | title, text = passage
53 | title = title.encode("utf-8").decode()
54 | text = "\n".join(text).encode("utf-8").decode()
55 | hash_t = xxhash.xxh3_64_intdigest(text)
56 | if hash_t not in passages:
57 | passages[hash_t] = (title, text)
58 |
59 | return passages
60 | else:
61 | raise NotImplementedError(f"Dataset {dataset_name} not supported.")
62 |
63 |
64 | def get_queries(dataset: Any):
65 | """Get the queries from the dataset."""
66 | queries: List[Query] = []
67 |
68 | for datapoint in dataset:
69 | queries.append(
70 | Query(
71 | question=datapoint["question"].encode("utf-8").decode(),
72 | answer=datapoint["answer"],
73 | evidence=list(datapoint["supporting_facts"]),
74 | )
75 | )
76 |
77 | return queries
78 |
79 |
80 | if __name__ == "__main__":
81 | load_dotenv()
82 |
83 | parser = argparse.ArgumentParser(description="LightRAG CLI")
84 | parser.add_argument("-d", "--dataset", default="2wikimultihopqa", help="Dataset to use.")
85 | parser.add_argument("-n", type=int, default=0, help="Subset of corpus to use.")
86 | parser.add_argument("-c", "--create", action="store_true", help="Create the graph for the given dataset.")
87 | parser.add_argument("-b", "--benchmark", action="store_true", help="Benchmark the graph for the given dataset.")
88 | parser.add_argument("-s", "--score", action="store_true", help="Report scores after benchmarking.")
89 | parser.add_argument("--mode", default="local", help="LightRAG query mode.")
90 | args = parser.parse_args()
91 |
92 | print("Loading dataset...")
93 | dataset = load_dataset(args.dataset, subset=args.n)
94 | working_dir = f"./db/lightrag/{args.dataset}_{args.n}"
95 | corpus = get_corpus(dataset, args.dataset)
96 |
97 |     # makedirs also creates the intermediate ./db/lightrag directory if it is missing
98 |     os.makedirs(working_dir, exist_ok=True)
99 | if args.create:
100 | print("Dataset loaded. Corpus:", len(corpus))
101 | grag = LightRAG(
102 | working_dir=working_dir,
103 | llm_model_func=gpt_4o_mini_complete,
104 | log_level=logging.WARNING
105 | )
106 |     grag.insert([f"{title}: {text}" for title, text in corpus.values()])
107 | if args.benchmark:
108 | queries = get_queries(dataset)
109 | print("Dataset loaded. Queries:", len(queries))
110 | grag = LightRAG(
111 | working_dir=working_dir,
112 | llm_model_func=gpt_4o_mini_complete,
113 | log_level=logging.WARNING
114 | )
115 |
116 | async def _query_task(query: Query, mode: str) -> Dict[str, Any]:
117 | answer = await grag.aquery(
118 | query.question, QueryParam(mode=mode, only_need_context=True, max_token_for_text_unit=9000)
119 | )
120 | chunks = [
121 | c.split(",")[1].split(":")[0].lstrip('"')
122 | for c in re.findall(r"\n-----Sources-----\n```csv\n(.*?)\n```", answer, re.DOTALL)[0].split("\r\n")[
123 | 1:-1
124 | ]
125 | ]
126 | return {
127 | "question": query.question,
128 | "answer": "",
129 | "evidence": chunks[:8],
130 | "ground_truth": [e[0] for e in query.evidence],
131 | }
132 |
133 | async def _run(mode: str):
134 | answers = [
135 | await a
136 | for a in tqdm(
137 | asyncio.as_completed([_query_task(query, mode=mode) for query in queries]), total=len(queries)
138 | )
139 | ]
140 | return answers
141 |
142 | answers = always_get_an_event_loop().run_until_complete(_run(mode=args.mode))
143 |
144 | with open(f"./results/lightrag/{args.dataset}_{args.n}_{args.mode}.json", "w") as f:
145 | json.dump(answers, f, indent=4)
146 |
147 | if args.benchmark or args.score:
148 | with open(f"./results/lightrag/{args.dataset}_{args.n}_{args.mode}.json", "r") as f:
149 | answers = json.load(f)
150 |
151 | try:
152 | with open(f"./questions/{args.dataset}_{args.n}.json", "r") as f:
153 | questions_multihop = json.load(f)
154 | except FileNotFoundError:
155 | questions_multihop = []
156 |
157 | # Compute retrieval metrics
158 | retrieval_scores: List[float] = []
159 | retrieval_scores_multihop: List[float] = []
160 |
161 | for answer in answers:
162 | ground_truth = answer["ground_truth"]
163 | predicted_evidence = answer["evidence"]
164 |
165 | p_retrieved: float = len(set(ground_truth).intersection(set(predicted_evidence))) / len(set(ground_truth))
166 | retrieval_scores.append(p_retrieved)
167 |
168 | if answer["question"] in questions_multihop:
169 | retrieval_scores_multihop.append(p_retrieved)
170 |
171 | print(
172 | f"Percentage of queries with perfect retrieval: {np.mean([1 if s == 1.0 else 0 for s in retrieval_scores])}"
173 | )
174 | if len(retrieval_scores_multihop):
175 |         # Computed outside the f-string: newlines inside f-string expressions require Python 3.12+
176 |         multihop_perfect = np.mean([1 if s == 1.0 else 0 for s in retrieval_scores_multihop])
177 |         print(
178 |             f"[multihop] Percentage of queries with perfect retrieval: {multihop_perfect}"
179 |         )
180 |
--------------------------------------------------------------------------------
/benchmarks/nano_benchmark.py:
--------------------------------------------------------------------------------
1 | """Benchmarking script for GraphRAG."""
2 |
3 | import argparse
4 | import asyncio
5 | import json
6 | import os
7 | import re
8 | from dataclasses import dataclass, field
9 | from typing import Any, Dict, List, Tuple
10 |
11 | import numpy as np
12 | import xxhash
13 | from dotenv import load_dotenv
14 | from nano_graphrag import GraphRAG, QueryParam
15 | from nano_graphrag._llm import gpt_4o_mini_complete
16 | from nano_graphrag._utils import always_get_an_event_loop, logging
17 | from tqdm import tqdm
18 |
19 | logging.getLogger("nano-graphrag").setLevel(logging.WARNING)
20 | logging.getLogger("httpx").setLevel(logging.WARNING)
21 | logging.getLogger("nano-vectordb").setLevel(logging.WARNING)
22 |
23 | @dataclass
24 | class Query:
25 | """Dataclass for a query."""
26 |
27 | question: str = field()
28 | answer: str = field()
29 | evidence: List[Tuple[str, int]] = field()
30 |
31 |
32 | def load_dataset(dataset_name: str, subset: int = 0) -> Any:
33 | """Load a dataset from the datasets folder."""
34 | with open(f"./datasets/{dataset_name}.json", "r") as f:
35 | dataset = json.load(f)
36 |
37 | if subset:
38 | return dataset[:subset]
39 | else:
40 | return dataset
41 |
42 |
43 | def get_corpus(dataset: Any, dataset_name: str) -> Dict[int, Tuple[int | str, str]]:
44 | """Get the corpus from the dataset."""
45 | if dataset_name == "2wikimultihopqa" or dataset_name == "hotpotqa":
46 | passages: Dict[int, Tuple[int | str, str]] = {}
47 |
48 | for datapoint in dataset:
49 | context = datapoint["context"]
50 |
51 | for passage in context:
52 | title, text = passage
53 | title = title.encode("utf-8").decode()
54 | text = "\n".join(text).encode("utf-8").decode()
55 | hash_t = xxhash.xxh3_64_intdigest(text)
56 | if hash_t not in passages:
57 | passages[hash_t] = (title, text)
58 |
59 | return passages
60 | else:
61 | raise NotImplementedError(f"Dataset {dataset_name} not supported.")
62 |
63 |
64 | def get_queries(dataset: Any):
65 | """Get the queries from the dataset."""
66 | queries: List[Query] = []
67 |
68 | for datapoint in dataset:
69 | queries.append(
70 | Query(
71 | question=datapoint["question"].encode("utf-8").decode(),
72 | answer=datapoint["answer"],
73 | evidence=list(datapoint["supporting_facts"]),
74 | )
75 | )
76 |
77 | return queries
78 |
79 |
80 | if __name__ == "__main__":
81 | load_dotenv()
82 |
83 |     parser = argparse.ArgumentParser(description="nano-graphrag CLI")
84 | parser.add_argument("-d", "--dataset", default="2wikimultihopqa", help="Dataset to use.")
85 | parser.add_argument("-n", type=int, default=0, help="Subset of corpus to use.")
86 | parser.add_argument("-c", "--create", action="store_true", help="Create the graph for the given dataset.")
87 | parser.add_argument("-b", "--benchmark", action="store_true", help="Benchmark the graph for the given dataset.")
88 | parser.add_argument("-s", "--score", action="store_true", help="Report scores after benchmarking.")
89 | parser.add_argument("--mode", default="local", help="LightRAG query mode.")
90 | args = parser.parse_args()
91 |
92 | print("Loading dataset...")
93 | dataset = load_dataset(args.dataset, subset=args.n)
94 | working_dir = f"./db/nano/{args.dataset}_{args.n}"
95 | corpus = get_corpus(dataset, args.dataset)
96 |
97 |     # makedirs also creates the intermediate ./db/nano directory if it is missing
98 |     os.makedirs(working_dir, exist_ok=True)
99 | if args.create:
100 | print("Dataset loaded. Corpus:", len(corpus))
101 | grag = GraphRAG(
102 | working_dir=working_dir,
103 | best_model_func=gpt_4o_mini_complete
104 | )
105 | grag.insert([f"{title}: {corpus}" for _, (title, corpus) in tuple(corpus.items())])
106 | if args.benchmark:
107 | queries = get_queries(dataset)
108 | print("Dataset loaded. Queries:", len(queries))
109 | grag = GraphRAG(
110 | working_dir=working_dir,
111 | best_model_func=gpt_4o_mini_complete
112 | )
113 |
114 | async def _query_task(query: Query, mode: str) -> Dict[str, Any]:
115 | answer = await grag.aquery(
116 | query.question, QueryParam(mode=mode, only_need_context=True, local_max_token_for_text_unit=9000)
117 | )
118 | chunks = []
119 | for c in re.findall(r"\n-----Sources-----\n```csv\n(.*?)\n```", answer, re.DOTALL)[0].split("\n")[
120 | 1:-1
121 | ]:
122 | try:
123 | chunks.append(c.split(",\t")[1].split(":")[0].lstrip('"'))
124 | except IndexError:
125 | pass
126 | return {
127 | "question": query.question,
128 | "answer": "",
129 | "evidence": chunks[:8],
130 | "ground_truth": [e[0] for e in query.evidence],
131 | }
132 |
133 | async def _run(mode: str):
134 | answers = [
135 | await a
136 | for a in tqdm(
137 | asyncio.as_completed([_query_task(query, mode=mode) for query in queries]), total=len(queries)
138 | )
139 | ]
140 | return answers
141 |
142 | answers = always_get_an_event_loop().run_until_complete(_run(mode=args.mode))
143 |
144 | with open(f"./results/nano/{args.dataset}_{args.n}_{args.mode}.json", "w") as f:
145 | json.dump(answers, f, indent=4)
146 |
147 | if args.benchmark or args.score:
148 | with open(f"./results/nano/{args.dataset}_{args.n}_{args.mode}.json", "r") as f:
149 | answers = json.load(f)
150 |
151 | try:
152 | with open(f"./questions/{args.dataset}_{args.n}.json", "r") as f:
153 | questions_multihop = json.load(f)
154 | except FileNotFoundError:
155 | questions_multihop = []
156 |
157 | # Compute retrieval metrics
158 | retrieval_scores: List[float] = []
159 | retrieval_scores_multihop: List[float] = []
160 |
161 | for answer in answers:
162 | ground_truth = answer["ground_truth"]
163 | predicted_evidence = answer["evidence"]
164 |
165 | p_retrieved: float = len(set(ground_truth).intersection(set(predicted_evidence))) / len(set(ground_truth))
166 | retrieval_scores.append(p_retrieved)
167 |
168 | if answer["question"] in questions_multihop:
169 | retrieval_scores_multihop.append(p_retrieved)
170 |
171 | print(
172 | f"Percentage of queries with perfect retrieval: {np.mean([1 if s == 1.0 else 0 for s in retrieval_scores])}"
173 | )
174 | if len(retrieval_scores_multihop):
175 |         # Compute the mean outside the f-string: multi-line f-string expressions need Python 3.12+.
176 |         perfect_multihop = np.mean([1 if s == 1.0 else 0 for s in retrieval_scores_multihop])
177 |         print(
178 |             f"[multihop] Percentage of queries with perfect retrieval: {perfect_multihop}"
179 |         )
180 |
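181 | # Usage sketch (mirrors the flags defined above; paths are the defaults):
182 | #   python nano_benchmark.py -d 2wikimultihopqa -n 51 -c       # build the graph
183 | #   python nano_benchmark.py -d 2wikimultihopqa -n 51 -b -s    # query and score
184 | # p_retrieved above is per-query evidence recall: |ground_truth ∩ predicted| / |ground_truth|.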
--------------------------------------------------------------------------------
/benchmarks/questions/2wikimultihopqa_101.json:
--------------------------------------------------------------------------------
1 | [
2 | "When did Lothair Ii's mother die?",
3 | "What is the place of birth of the performer of song Changed It?",
4 | "Which film has the director who is older, God'S Gift To Women or Aldri Annet Enn Bråk?",
5 | "Which film whose director was born first, El Tonto or The Heart Of Doreon?",
6 | "Who is Raghnall Mac Ruaidhrí's paternal grandfather?",
7 | "Do both films Interview With A Hitman and The Last Coupon have the directors from the same country?",
8 | "What nationality is the director of film Blood Street?",
9 | "What is the place of birth of the director of film Gaby: A True Story?",
10 | "What nationality is the performer of song When The Stars Go Blue?",
11 | "Who is the child of the performer of song Me And Bobby Mcgee?",
12 | "Where was the place of death of Maurice, Prince Of Orange's father?",
13 | "Which country Aleksander Koniecpolski (1620–1659)'s father is from?",
14 | "What is the date of death of the director of film Madame La Presidente?",
15 | "Do both directors of films Wrong Turn 5: Bloodlines and Dark River (2017 Film) have the same nationality?",
16 | "Which film has the director who died first, The Goose Woman or You Can No Longer Remain Silent?",
17 | "Where was the director of film The Private Life Of Cinema born?",
18 | "Where did Coulson Wallop's father study?",
19 | "Why did John Middleton Murry's wife die?",
20 | "Where was the composer of film Billy Elliot born?",
21 | "What is the place of birth of Lisbeth Palme's husband?",
22 | "Who is the father-in-law of Sisowath Kossamak?",
23 | "Who is Mugain's mother-in-law?",
24 | "Where did Theodore Salisbury Woolsey's father study?",
25 | "What is the place of birth of the director of film The Return Of Swamp Thing?",
26 | "Where was the place of death of the performer of song I Can'T See Myself Leaving You?",
27 | "Which film has the director who was born later, Playing It Wild or I'Ll Be Going Now?",
28 | "What nationality is Beatrice I, Countess Of Burgundy's husband?",
29 | "Which film has the director born later, Christ Walking On The Water or 45 Fathers?",
30 | "Where was the performer of song Come Dance With Me (Song) born?",
31 | "Where does the director of film Talk About A Stranger work at?",
32 | "What is the place of birth of the composer of film Inherent Vice (Film)?",
33 | "Which film whose director is younger, Dangerously They Live or Salad By The Roots?",
34 | "Where did Sylvia Burka's husband die?",
35 | "Which film whose director is younger, Phalitamsha or Gladiators Seven?",
36 | "Which film has the director who was born later, Henry Goes Arizona or The Blue Collar Worker And The Hairdresser In A Whirl Of Sex And Politics?",
37 | "Which film has the director who was born first, Tombstone Rashomon or Waiting For The Clouds?",
38 | "Where was the performer of song B Boy (Song) detained?",
39 | "Which film has the director who was born later, Illusions (1982 Film) or It'S A Wonderful Afterlife?",
40 | "Where did the composer of film The Straw Hat die?",
41 | "Where does the director of film A Nest Of Noblemen work at?",
42 | "Who is the father-in-law of John Ernest, Duke Of Saxe-Eisenach?",
43 | "Who is the father of the director of film Palo Alto (2013 Film)?",
44 | "Who is the child of the director of film Los Pagares De Mendieta?",
45 | "Who is the spouse of the director of film My Three Merry Widows?",
46 | "Which country the composer of film Thunder On The Hill is from?",
47 | "Which film has the director born later, Arrête Ton Cinéma or Agni (2004 Film)?",
48 | "Where did Prince Ferdinand Of Bavaria's mother die?",
49 | "Who is the child of the director of film An Event?",
50 | "What nationality is the director of film Good People (Film)?",
51 | "What is the place of birth of the director of film Fortunella (Film)?",
52 | "Where was the husband of Joanna Elisabeth Of Holstein-Gottorp born?",
53 | "Where was the place of death of Abdul-Aziz Bin Muhammad's father?",
54 | "Which film has the director who died later, Love, Honor And Oh-Baby! or I Cover The Underworld?",
55 | "Where was the director of film The Circus Cyclone born?",
56 | "Where did the director of film Dancing In The Rain (Film) die?",
57 | "Where did Saw Thanda's husband die?",
58 | "Which film has the director who was born earlier, The Marriage Of Princess Demidoff or The Pocket-Knife?",
59 | "Which film has the director who is older than the other, Airheads or Return To Cabin By The Lake? ",
60 | "Which film has the director died later, Lost In The Stratosphere or Blind Man'S Eyes?",
61 | "What nationality is the performer of song Am I Wrong (Étienne De Crécy Song)?",
62 | "Who is the mother of the director of film Atomised (Film)?",
63 | "What is the date of birth of Henry I Of Ziębice's father?",
64 | "Where was the composer of song Back In The U.S.A. born?",
65 | "Which film has the director who was born later, The First Day Of Freedom or Malabimba – The Malicious Whore?",
66 | "Who is Sibyl Hathaway's child-in-law?",
67 | "Which film has the director born first, Sins Of Madeleine or Captain Apache?",
68 | "Who is Godomar Ii's stepmother?",
69 | "Where was the director of film The Outlaw Express born?",
70 | "Which film has the director who died later, 45 Calibre Echo or Bons Baisers De Hong Kong?",
71 | "What nationality is Lamprocles's father?",
72 | "Which country the performer of song I Like Control is from?",
73 | "Which film has the director born later, Romance On The Run or The Palace Of Angels?",
74 | "Which film has the director born first, Mord Em'Ly or Ek Hi Bhool (1940 Film)?",
75 | "Which film has the director born earlier, The Korean Wedding Chest or True To The Navy?",
76 | "Who is Marianus V Of Arborea's mother?",
77 | "Who is the paternal grandfather of Margaret Of Bavaria, Marchioness Of Mantua?"
78 | ]
79 |
--------------------------------------------------------------------------------
/benchmarks/questions/2wikimultihopqa_51.json:
--------------------------------------------------------------------------------
1 | [
2 | "When did Lothair Ii's mother die?",
3 | "What is the place of birth of the performer of song Changed It?",
4 | "Which film has the director who is older, God'S Gift To Women or Aldri Annet Enn Bråk?",
5 | "Which film whose director was born first, El Tonto or The Heart Of Doreon?",
6 | "Who is Raghnall Mac Ruaidhrí's paternal grandfather?",
7 | "Do both films Interview With A Hitman and The Last Coupon have the directors from the same country?",
8 | "What nationality is the director of film Blood Street?",
9 | "What is the place of birth of the director of film Gaby: A True Story?",
10 | "What nationality is the performer of song When The Stars Go Blue?",
11 | "Who is the child of the performer of song Me And Bobby Mcgee?",
12 | "Where was the place of death of Maurice, Prince Of Orange's father?",
13 | "Which country Aleksander Koniecpolski (1620–1659)'s father is from?",
14 | "What is the date of death of the director of film Madame La Presidente?",
15 | "Do both directors of films Wrong Turn 5: Bloodlines and Dark River (2017 Film) have the same nationality?",
16 | "Which film has the director who died first, The Goose Woman or You Can No Longer Remain Silent?",
17 | "Where was the director of film The Private Life Of Cinema born?",
18 | "Where did Coulson Wallop's father study?",
19 | "Why did John Middleton Murry's wife die?",
20 | "Where was the composer of film Billy Elliot born?",
21 | "What is the place of birth of Lisbeth Palme's husband?",
22 | "Who is the father-in-law of Sisowath Kossamak?",
23 | "Who is Mugain's mother-in-law?",
24 | "Where did Theodore Salisbury Woolsey's father study?",
25 | "What is the place of birth of the director of film The Return Of Swamp Thing?",
26 | "Where was the place of death of the performer of song I Can'T See Myself Leaving You?",
27 | "Which film has the director who was born later, Playing It Wild or I'Ll Be Going Now?",
28 | "What nationality is Beatrice I, Countess Of Burgundy's husband?",
29 | "Which film has the director born later, Christ Walking On The Water or 45 Fathers?",
30 | "Where was the performer of song Come Dance With Me (Song) born?",
31 | "Where does the director of film Talk About A Stranger work at?",
32 | "What is the place of birth of the composer of film Inherent Vice (Film)?",
33 | "Which film whose director is younger, Dangerously They Live or Salad By The Roots?",
34 | "Where did Sylvia Burka's husband die?",
35 | "Which film whose director is younger, Phalitamsha or Gladiators Seven?",
36 | "Which film has the director who was born later, Henry Goes Arizona or The Blue Collar Worker And The Hairdresser In A Whirl Of Sex And Politics?",
37 | "Which film has the director who was born first, Tombstone Rashomon or Waiting For The Clouds?",
38 | "Where was the performer of song B Boy (Song) detained?",
39 | "Which film has the director who was born later, Illusions (1982 Film) or It'S A Wonderful Afterlife?",
40 | "Where did the composer of film The Straw Hat die?",
41 | "Where does the director of film A Nest Of Noblemen work at?"
42 | ]
--------------------------------------------------------------------------------
/demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/circlemind-ai/fast-graphrag/b370efe01ef836af292a3713d59b2ec23d2fe7c4/demo.gif
--------------------------------------------------------------------------------
/examples/checkpointing.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "## Checkpointing Tutorial\n",
8 | "\n",
9 | "To properly function, fast-graphrag mantains a state synchronised among different types of databases. It is highly unlikely, but it can happend that during any reading/writing operation any of these storages can get corrupted. So, we are introducing checkpointing to signficiantly reduce the impact of this unpleasant situation. To enable checkpointing, simply set `n_checkpoints = k`, with `k > 0`:"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": null,
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "from fast_graphrag import GraphRAG\n",
19 | "\n",
20 | "DOMAIN = \"Analyze this story and identify the characters. Focus on how they interact with each other, the locations they explore, and their relationships.\"\n",
21 | "\n",
22 | "EXAMPLE_QUERIES = [\n",
23 | " \"What is the significance of Christmas Eve in A Christmas Carol?\",\n",
24 | " \"How does the setting of Victorian London contribute to the story's themes?\",\n",
25 | " \"Describe the chain of events that leads to Scrooge's transformation.\",\n",
26 | " \"How does Dickens use the different spirits (Past, Present, and Future) to guide Scrooge?\",\n",
27 | " \"Why does Dickens choose to divide the story into \\\"staves\\\" rather than chapters?\"\n",
28 | "]\n",
29 | "\n",
30 | "ENTITY_TYPES = [\"Character\", \"Animal\", \"Place\", \"Object\", \"Activity\", \"Event\"]\n",
31 | "\n",
32 | "grag = GraphRAG(\n",
33 | " working_dir=\"./book_example\",\n",
34 | " n_checkpoints=2, # Number of checkpoints to keep\n",
35 | " domain=DOMAIN,\n",
36 | " example_queries=\"\\n\".join(EXAMPLE_QUERIES),\n",
37 | " entity_types=ENTITY_TYPES\n",
38 | ")\n",
39 | "\n",
40 | "with open(\"./book.txt\") as f:\n",
41 | " grag.insert(f.read())\n",
42 | "\n",
43 | "print(grag.query(\"Who is Scrooge?\").response)"
44 | ]
45 | },
46 | {
47 | "cell_type": "markdown",
48 | "metadata": {},
49 | "source": [
50 | "DONE! Now the library will automatically keep in memory the `k` most recent checkpoints and rollback to them if necessary."
51 | ]
52 | },
53 | {
54 | "cell_type": "markdown",
55 | "metadata": {},
56 | "source": [
57 | "NOTES:\n",
58 | "- if you want to migrate a project from no checkpoints to checkpoints, simply set the flag and run a insert operation (even an empty one should do the job). Check that the checkpoint was created succesfully by querying the graph. If eveything worked correctly, you should see a new directory in you storage working dir (in the case above, it would be something like `./book_example/1731555907`). You can now safely remove all the files in the root dir `./book_example/*.*`.\n",
59 | "- if you want to stop using checkpoints, simply copy all the files from the most recent checkpoints folder in the root dir, delete all the \"number\" folders and unset `n_checkpoints`."
60 | ]
61 | }
62 | ],
63 | "metadata": {
64 | "kernelspec": {
65 | "display_name": "cm",
66 | "language": "python",
67 | "name": "python3"
68 | },
69 | "language_info": {
70 | "name": "python",
71 | "version": "3.12.7"
72 | }
73 | },
74 | "nbformat": 4,
75 | "nbformat_minor": 2
76 | }
77 |
--------------------------------------------------------------------------------
/examples/custom_llm.py:
--------------------------------------------------------------------------------
1 | """Example usage of GraphRAG with custom LLM and Embedding services compatible with the OpenAI API."""
2 |
3 | from typing import List
4 |
5 | import instructor
6 | from dotenv import load_dotenv
7 |
8 | from fast_graphrag import GraphRAG
9 | from fast_graphrag._llm import OpenAIEmbeddingService, OpenAILLMService
10 |
11 | load_dotenv()
12 |
13 | DOMAIN = ""
14 | QUERIES: List[str] = []
15 | ENTITY_TYPES: List[str] = []
16 |
17 | working_dir = "./examples/ignore/hp"
18 | grag = GraphRAG(
19 | working_dir=working_dir,
20 | domain=DOMAIN,
21 | example_queries="\n".join(QUERIES),
22 | entity_types=ENTITY_TYPES,
23 | config=GraphRAG.Config(
24 | llm_service=OpenAILLMService(
25 | model="your-llm-model",
26 | base_url="llm.api.url.com",
27 | api_key="your-api-key",
28 | mode=instructor.Mode.JSON,
29 | api_version="your-llm-api_version",
30 | client="openai or azure"
31 | ),
32 | embedding_service=OpenAIEmbeddingService(
33 | model="your-embedding-model",
34 | base_url="emb.api.url.com",
35 | api_key="your-api-key",
36 | embedding_dim=512, # the output embedding dim of the chosen model
37 | api_version="your-llm-api_version",
38 | client="openai or azure"
39 | ),
40 | ),
41 | )
42 |
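43 | # Usage sketch (hypothetical text and question; the configured `grag` behaves
44 | # like the default instance):
45 | #   grag.insert("some text to index")
46 | #   print(grag.query("your question").response)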
--------------------------------------------------------------------------------
/examples/gemini_example.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from fast_graphrag import GraphRAG, QueryParam
4 | import asyncio
5 | from fast_graphrag._utils import logger
6 | from fast_graphrag._llm import GeminiLLMService, GeminiEmbeddingService #, VoyageAIEmbeddingService
7 |
8 | WORKING_DIR="./book_example"
9 | if not os.path.exists(WORKING_DIR):
10 | os.mkdir(WORKING_DIR)
11 |
12 | GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
13 |
14 | DOMAIN = "Analyze this story and identify the characters. Focus on how they interact with each other, the locations they explore, and their relationships."
15 |
16 | EXAMPLE_QUERIES = [
17 | "What is the significance of Christmas Eve in A Christmas Carol?",
18 | "How does the setting of Victorian London contribute to the story's themes?",
19 | "Describe the chain of events that leads to Scrooge's transformation.",
20 | "How does Dickens use the different spirits (Past, Present, and Future) to guide Scrooge?",
21 | "Why does Dickens choose to divide the story into \"staves\" rather than chapters?"
22 | ]
23 |
24 | # Custom entity types for story analysis
25 | ENTITY_TYPES = ["Character", "Animal", "Place", "Object", "Activity", "Event"]
26 |
27 | # For VoyageAI Embeddings, higher rate limits
28 | # from fast_graphrag._llm import VoyageAIEmbeddingService
29 | # VOYAGE_API_KEY = os.getenv("VOYAGE_API_KEY")
30 |
31 | # For PDF Processing, uses langchain and PyMuPDF
32 | #from langchain_community.document_loaders import PyMuPDFLoader
33 | #async def process_pdf(file_path: str) -> str:
34 | # """Process PDFs with error handling"""
35 | # if not os.path.exists(file_path):
36 | # raise FileNotFoundError(f"PDF file not found: {file_path}")
37 | #
38 | # try:
39 | # loader = PyMuPDFLoader(file_path)
40 | # pages = ""
41 | # for page in loader.lazy_load():
42 | # pages += page.page_content
43 | # return pages
44 | #
45 | # except Exception as e:
46 | # raise
47 |
48 | # Text Processing Function
49 | async def process_text(file_path: str) -> str:
50 | """Process text file with encoding handling"""
51 | if not os.path.exists(file_path):
52 | raise FileNotFoundError(f"Text file not found: {file_path}")
53 |
54 | try:
55 | # Try UTF-8 first, fallback to other encodings
56 | encodings = ['utf-8', 'ascii', 'iso-8859-1', 'cp1252']
57 | text = None
58 |
59 | for encoding in encodings:
60 | try:
61 | with open(file_path, "r", encoding=encoding) as f:
62 | text = f.read()
63 | break
64 | except UnicodeDecodeError:
65 | continue
66 |
67 | if text is None:
68 | raise UnicodeError("Failed to decode file with any supported encoding")
69 |
70 | # Clean and normalize text
71 | text = text.encode('ascii', 'ignore').decode('ascii')
72 | return text.strip()
73 |
74 | except Exception as e:
75 | logger.exception("An error occurred:", exc_info=True)
76 | raise
77 |
78 | async def streaming_query_loop(rag: GraphRAG):
79 | """Basic query loop for repeated questions"""
80 | print("\nStreaming Query Interface (type 'exit' to quit)")
81 | print("="*50)
82 |
83 | while True:
84 | try:
85 | query = input("\nYou: ").strip()
86 | if query.lower() == 'exit':
87 | print("\nExiting chat...")
88 | break
89 |
90 | print("Assistant: ", end='', flush=True)
91 |
92 | try:
93 | # Higher token limits for Gemini context windows
94 | response = await rag.async_query(
95 | query,
96 | params=QueryParam(with_references=False, only_context=False, entities_max_tokens=250000, relations_max_tokens=200000, chunks_max_tokens=500000)
97 | )
98 | print(response.response)
99 | except Exception as e:
100 | import traceback
101 | traceback.print_exc()
102 |
103 | except KeyboardInterrupt:
104 | print("\nInterrupted by user")
105 | break
106 | except Exception as e:
107 | logger.exception("An error occurred:", exc_info=True)
108 | continue
109 |
110 | ### fast-graphrag example for Gemini
111 | async def main():
112 | try:
113 | grag = GraphRAG(
114 | working_dir=WORKING_DIR,
115 | domain=DOMAIN,
116 | example_queries="\n".join(EXAMPLE_QUERIES),
117 | entity_types=ENTITY_TYPES,
118 | config=GraphRAG.Config(
119 | # Ensure necessary APIs have been enabled in Google Cloud, namely the Generative API
120 |             # Supports Vertex usage via passing project_id and location, or using an 'Express' API key if available. More aggressive rate limiting will be required.
121 |             # The recommended option is using API keys from AI Studio, enabling Billing on the studio's project for much higher rate limits (2000 RPM for 2.0 Flash as of Feb 2025).
122 | llm_service = GeminiLLMService(
123 | model="gemini-2.0-flash",
124 | api_key=GEMINI_API_KEY,
125 | temperature=0.7,
126 | rate_limit_per_minute=True,
127 | rate_limit_per_second=True,
128 | max_requests_per_minute=1950,
129 | max_requests_per_second=500
130 | ),
131 | embedding_service=GeminiEmbeddingService(
132 | api_key=GEMINI_API_KEY,
133 | max_requests_concurrent=100,
134 | ),
135 | # for Voyage AI embeddings with higher rate limits than Vertex
136 | #embedding_service=VoyageAIEmbeddingService(
137 | # model="voyage-3-large",
138 | # api_key=VOYAGE_API_KEY,
139 | # embedding_dim=1024, # the output embedding dim of the chosen model
140 | #),
141 | ),
142 | )
143 |
144 | file = await process_text("./book.txt")
145 | await grag.async_insert(file)
146 |
147 | await streaming_query_loop(grag)
148 | except Exception as e:
149 | logger.exception("An error occurred:", exc_info=True)
150 |
151 |
152 | if __name__ == "__main__":
153 | asyncio.run(main())
154 |
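155 | # To run (sketch): set GEMINI_API_KEY in the environment and place the input
156 | # text at ./book.txt, then run `python gemini_example.py`.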
--------------------------------------------------------------------------------
/fast_graphrag/__init__.py:
--------------------------------------------------------------------------------
1 | """Top-level package for GraphRAG."""
2 |
3 | __all__ = ["GraphRAG", "QueryParam"]
4 |
5 | from dataclasses import dataclass, field
6 | from typing import Type
7 |
8 | from fast_graphrag._llm import DefaultEmbeddingService, DefaultLLMService
9 | from fast_graphrag._llm._base import BaseEmbeddingService
10 | from fast_graphrag._llm._llm_openai import BaseLLMService
11 | from fast_graphrag._policies._base import BaseGraphUpsertPolicy
12 | from fast_graphrag._policies._graph_upsert import (
13 | DefaultGraphUpsertPolicy,
14 | EdgeUpsertPolicy_UpsertIfValidNodes,
15 | EdgeUpsertPolicy_UpsertValidAndMergeSimilarByLLM,
16 | NodeUpsertPolicy_SummarizeDescription,
17 | )
18 | from fast_graphrag._policies._ranking import RankingPolicy_TopK, RankingPolicy_WithThreshold
19 | from fast_graphrag._services import (
20 | BaseChunkingService,
21 | BaseInformationExtractionService,
22 | BaseStateManagerService,
23 | DefaultChunkingService,
24 | DefaultInformationExtractionService,
25 | DefaultStateManagerService,
26 | )
27 | from fast_graphrag._storage import (
28 | DefaultGraphStorage,
29 | DefaultGraphStorageConfig,
30 | DefaultIndexedKeyValueStorage,
31 | DefaultVectorStorage,
32 | DefaultVectorStorageConfig,
33 | )
34 | from fast_graphrag._storage._base import BaseGraphStorage
35 | from fast_graphrag._storage._namespace import Workspace
36 | from fast_graphrag._types import TChunk, TEmbedding, TEntity, THash, TId, TIndex, TRelation
37 |
38 | from ._graphrag import BaseGraphRAG, QueryParam
39 |
40 |
41 | @dataclass
42 | class GraphRAG(BaseGraphRAG[TEmbedding, THash, TChunk, TEntity, TRelation, TId]):
43 | """A class representing a Graph-based Retrieval-Augmented Generation system."""
44 |
45 | @dataclass
46 | class Config:
47 | """Configuration for the GraphRAG class."""
48 |
49 | chunking_service_cls: Type[BaseChunkingService[TChunk]] = field(default=DefaultChunkingService)
50 | information_extraction_service_cls: Type[BaseInformationExtractionService[TChunk, TEntity, TRelation, TId]] = (
51 | field(default=DefaultInformationExtractionService)
52 | )
53 | information_extraction_upsert_policy: BaseGraphUpsertPolicy[TEntity, TRelation, TId] = field(
54 | default_factory=lambda: DefaultGraphUpsertPolicy(
55 | config=NodeUpsertPolicy_SummarizeDescription.Config(),
56 | nodes_upsert_cls=NodeUpsertPolicy_SummarizeDescription,
57 | edges_upsert_cls=EdgeUpsertPolicy_UpsertIfValidNodes,
58 | )
59 | )
60 | state_manager_cls: Type[BaseStateManagerService[TEntity, TRelation, THash, TChunk, TId, TEmbedding]] = field(
61 | default=DefaultStateManagerService
62 | )
63 |
64 | llm_service: BaseLLMService = field(default_factory=lambda: DefaultLLMService())
65 | embedding_service: BaseEmbeddingService = field(default_factory=lambda: DefaultEmbeddingService())
66 |
67 | graph_storage: BaseGraphStorage[TEntity, TRelation, TId] = field(
68 | default_factory=lambda: DefaultGraphStorage(DefaultGraphStorageConfig(node_cls=TEntity, edge_cls=TRelation))
69 | )
70 | entity_storage: DefaultVectorStorage[TIndex, TEmbedding] = field(
71 | default_factory=lambda: DefaultVectorStorage(
72 | DefaultVectorStorageConfig()
73 | )
74 | )
75 | chunk_storage: DefaultIndexedKeyValueStorage[THash, TChunk] = field(
76 | default_factory=lambda: DefaultIndexedKeyValueStorage(None)
77 | )
78 |
79 | entity_ranking_policy: RankingPolicy_WithThreshold = field(
80 | default_factory=lambda: RankingPolicy_WithThreshold(RankingPolicy_WithThreshold.Config(threshold=0.005))
81 | )
82 | relation_ranking_policy: RankingPolicy_TopK = field(
83 | default_factory=lambda: RankingPolicy_TopK(RankingPolicy_TopK.Config(top_k=64))
84 | )
85 | chunk_ranking_policy: RankingPolicy_TopK = field(
86 | default_factory=lambda: RankingPolicy_TopK(RankingPolicy_TopK.Config(top_k=8))
87 | )
88 | node_upsert_policy: NodeUpsertPolicy_SummarizeDescription = field(
89 | default_factory=lambda: NodeUpsertPolicy_SummarizeDescription()
90 | )
91 | edge_upsert_policy: EdgeUpsertPolicy_UpsertValidAndMergeSimilarByLLM = field(
92 | default_factory=lambda: EdgeUpsertPolicy_UpsertValidAndMergeSimilarByLLM()
93 | )
94 |
95 | def __post_init__(self):
96 | """Initialize the GraphRAG Config class."""
97 | self.entity_storage.embedding_dim = self.embedding_service.embedding_dim
98 |
99 | config: Config = field(default_factory=Config)
100 |
101 | def __post_init__(self):
102 | """Initialize the GraphRAG class."""
103 | self.llm_service = self.config.llm_service
104 | self.embedding_service = self.config.embedding_service
105 | self.chunking_service = self.config.chunking_service_cls()
106 | self.information_extraction_service = self.config.information_extraction_service_cls(
107 | graph_upsert=self.config.information_extraction_upsert_policy
108 | )
109 | self.state_manager = self.config.state_manager_cls(
110 | workspace=Workspace.new(self.working_dir, keep_n=self.n_checkpoints),
111 | embedding_service=self.embedding_service,
112 | graph_storage=self.config.graph_storage,
113 | entity_storage=self.config.entity_storage,
114 | chunk_storage=self.config.chunk_storage,
115 | entity_ranking_policy=self.config.entity_ranking_policy,
116 | relation_ranking_policy=self.config.relation_ranking_policy,
117 | chunk_ranking_policy=self.config.chunk_ranking_policy,
118 | node_upsert_policy=self.config.node_upsert_policy,
119 | edge_upsert_policy=self.config.edge_upsert_policy,
120 | )
121 |
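122 | # Minimal usage sketch (strings and paths are placeholders; mirrors the
123 | # examples/ folder):
124 | #   grag = GraphRAG(working_dir="./wd", domain="...", example_queries="...",
125 | #                   entity_types=["Character", "Place"])
126 | #   grag.insert("document text")
127 | #   print(grag.query("question").response)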
--------------------------------------------------------------------------------
/fast_graphrag/_exceptions.py:
--------------------------------------------------------------------------------
1 | class InvalidStorageError(Exception):
2 | """Exception raised for errors in the storage operations."""
3 |
4 | def __init__(self, message: str = "Invalid storage operation"):
5 | self.message = message
6 | super().__init__(self.message)
7 |
8 |
9 | class InvalidStorageUsageError(Exception):
10 | """Exception raised for errors in the usage of the storage."""
11 |
12 | def __init__(self, message: str = "Invalid usage of the storage"):
13 | self.message = message
14 | super().__init__(self.message)
15 |
16 |
17 | class LLMServiceNoResponseError(Exception):
18 | """Exception raised when the LLM service does not provide a response."""
19 |
20 | def __init__(self, message: str = "LLM service did not provide a response"):
21 | self.message = message
22 | super().__init__(self.message)
23 |
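24 | # Handling sketch (hypothetical call site): these exceptions let callers tell
25 | # storage problems apart from missing LLM responses.
26 | #   try:
27 | #       ...  # storage or LLM operation
28 | #   except (InvalidStorageError, LLMServiceNoResponseError) as e:
29 | #       print(e.message)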
--------------------------------------------------------------------------------
/fast_graphrag/_llm/__init__.py:
--------------------------------------------------------------------------------
1 | __all__ = [
2 | "BaseLLMService",
3 | "BaseEmbeddingService",
4 | "DefaultEmbeddingService",
5 | "DefaultLLMService",
6 | "format_and_send_prompt",
7 | "OpenAIEmbeddingService",
8 | "OpenAILLMService",
9 | "GeminiLLMService",
10 | "GeminiEmbeddingService",
11 | "VoyageAIEmbeddingService"
12 | ]
13 |
14 | from ._base import BaseEmbeddingService, BaseLLMService, format_and_send_prompt
15 | from ._default import DefaultEmbeddingService, DefaultLLMService
16 | from ._llm_genai import GeminiEmbeddingService, GeminiLLMService
17 | from ._llm_openai import OpenAIEmbeddingService, OpenAILLMService
18 | from ._llm_voyage import VoyageAIEmbeddingService
19 |
--------------------------------------------------------------------------------
/fast_graphrag/_llm/_base.py:
--------------------------------------------------------------------------------
1 | """LLM Services module."""
2 |
3 | import os
4 | import re
5 | from dataclasses import dataclass, field
6 | from typing import Any, Optional, Tuple, Type, TypeVar, Union
7 |
8 | import numpy as np
9 | from pydantic import BaseModel
10 |
11 | from fast_graphrag._models import BaseModelAlias
12 | from fast_graphrag._prompt import PROMPTS
13 |
14 | T_model = TypeVar("T_model", bound=Union[BaseModel, BaseModelAlias])
15 | TOKEN_PATTERN = re.compile(r"\w+|[^\w\s]", re.UNICODE)
16 |
17 |
18 | async def format_and_send_prompt(
19 | prompt_key: str,
20 | llm: "BaseLLMService",
21 | format_kwargs: dict[str, Any],
22 | response_model: Type[T_model],
23 | **args: Any,
24 | ) -> Tuple[T_model, list[dict[str, str]]]:
25 | """Get a prompt, format it with the supplied args, and send it to the LLM.
26 |
27 | If a system prompt is provided (i.e. PROMPTS contains a key named
28 | '{prompt_key}_system'), it will use both the system and prompt entries:
29 | - System prompt: PROMPTS[prompt_key + '_system']
30 | - Message prompt: PROMPTS[prompt_key + '_prompt']
31 |
32 | Otherwise, it will default to using the single prompt defined by:
33 | - PROMPTS[prompt_key]
34 |
35 | Args:
36 | prompt_key (str): The key for the prompt in the PROMPTS dictionary.
37 | llm (BaseLLMService): The LLM service to use for sending the message.
38 |         format_kwargs (dict[str, Any]): Dictionary of arguments to format the prompt.
39 |         response_model (Type[T_model]): The expected response model.
40 | model (str | None): The model to use for the LLM. Defaults to None.
41 | max_tokens (int | None): The maximum number of tokens for the response. Defaults to None.
42 | **args (Any): Additional keyword arguments to pass to the LLM.
43 |
44 | Returns:
45 | Tuple[T_model, list[dict[str, str]]]: The response from the LLM.
46 | """
47 | system_key = prompt_key + "_system"
48 |
49 | if system_key in PROMPTS:
50 | # Use separate system and prompt entries
51 | system = PROMPTS[system_key]
52 | prompt = PROMPTS[prompt_key + "_prompt"]
53 | formatted_system = system.format(**format_kwargs)
54 | formatted_prompt = prompt.format(**format_kwargs)
55 | return await llm.send_message(
56 | system_prompt=formatted_system, prompt=formatted_prompt, response_model=response_model, **args
57 | )
58 | else:
59 | # Default: use the single prompt entry
60 | prompt = PROMPTS[prompt_key]
61 | formatted_prompt = prompt.format(**format_kwargs)
62 | return await llm.send_message(prompt=formatted_prompt, response_model=response_model, **args)
63 |
64 |
65 | @dataclass
66 | class BaseLLMService:
67 | """Base class for Language Model implementations."""
68 |
69 | model: str = field()
70 | base_url: Optional[str] = field(default=None)
71 | api_key: Optional[str] = field(default=None)
72 | llm_async_client: Any = field(init=False, default=None)
73 | max_requests_concurrent: int = field(default=int(os.getenv("CONCURRENT_TASK_LIMIT", 1024)))
74 | max_requests_per_minute: int = field(default=500)
75 | max_requests_per_second: int = field(default=60)
76 | rate_limit_concurrency: bool = field(default=True)
77 | rate_limit_per_minute: bool = field(default=False)
78 | rate_limit_per_second: bool = field(default=False)
79 |
80 | def count_tokens(self, text: str) -> int:
81 | """Returns the number of tokens for a given text using the encoding appropriate for the model."""
82 | return len(TOKEN_PATTERN.findall(text))
83 |
84 | def is_within_token_limit(self, text: str, token_limit: int):
85 | """Lightweight check to determine if `text` fits within `token_limit` tokens.
86 |
87 | Returns the token count (an int) if it is less than or equal to the limit,
88 | otherwise returns False.
89 | """
90 | token_count = self.count_tokens(text)
91 | return token_count if token_count <= token_limit else False
92 |
93 | async def send_message(
94 | self,
95 | prompt: str,
96 | system_prompt: str | None = None,
97 | history_messages: list[dict[str, str]] | None = None,
98 | response_model: Type[T_model] | None = None,
99 | **kwargs: Any,
100 | ) -> Tuple[T_model, list[dict[str, str]]]:
101 | """Send a message to the language model and receive a response.
102 |
103 | Args:
104 | prompt (str): The input message to send to the language model.
105 | model (str): The name of the model to use.
106 | system_prompt (str, optional): The system prompt to set the context for the conversation. Defaults to None.
107 | history_messages (list, optional): A list of previous messages in the conversation. Defaults to empty.
108 | response_model (Type[T], optional): The Pydantic model to parse the response. Defaults to None.
109 | **kwargs: Additional keyword arguments that may be required by specific LLM implementations.
110 |
111 | Returns:
112 |             Tuple[T_model, list[dict[str, str]]]: The parsed response and the message history.
113 | """
114 | raise NotImplementedError
115 |
116 |
117 | @dataclass
118 | class BaseEmbeddingService:
119 | """Base class for Language Model implementations."""
120 |
121 | embedding_dim: int = field(default=1536)
122 | model: Optional[str] = field(default="text-embedding-3-small")
123 | base_url: Optional[str] = field(default=None)
124 | api_key: Optional[str] = field(default=None)
125 | max_requests_concurrent: int = field(default=int(os.getenv("CONCURRENT_TASK_LIMIT", 1024)))
126 | max_requests_per_minute: int = field(default=500) # Tier 1 OpenAI RPM
127 | max_requests_per_second: int = field(default=100)
128 | rate_limit_concurrency: bool = field(default=True)
129 | rate_limit_per_minute: bool = field(default=True)
130 | rate_limit_per_second: bool = field(default=False)
131 |
132 | embedding_async_client: Any = field(init=False, default=None)
133 |
134 | async def encode(self, texts: list[str], model: Optional[str] = None) -> np.ndarray[Any, np.dtype[np.float32]]:
135 | """Get the embedding representation of the input text.
136 |
137 | Args:
138 |             texts (list[str]): The input texts to embed.
139 | model (str): The name of the model to use.
140 |
141 | Returns:
142 |             np.ndarray: The embedding vectors as a float32 array.
143 | """
144 | raise NotImplementedError
145 |
146 |
147 | class NoopAsyncContextManager:
148 | async def __aenter__(self):
149 | return self
150 |
151 | async def __aexit__(self, exc_type: Any, exc: Any, tb: Any):
152 | pass
153 |
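154 | # Notes (sketch): count_tokens is a regex-based approximation (words and
155 | # punctuation), not a model tokenizer, e.g.
156 | #   BaseLLMService(model="any").count_tokens("Hello, world!")  # -> 4
157 | # format_and_send_prompt prefers the "<key>_system"/"<key>_prompt" pair in
158 | # PROMPTS when present, otherwise it falls back to the single PROMPTS["<key>"].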
--------------------------------------------------------------------------------
/fast_graphrag/_llm/_default.py:
--------------------------------------------------------------------------------
1 | __all__ = ['DefaultLLMService', 'DefaultEmbeddingService']
2 |
3 | from ._llm_openai import OpenAIEmbeddingService, OpenAILLMService
4 |
5 |
6 | class DefaultLLMService(OpenAILLMService):
7 | pass
8 | class DefaultEmbeddingService(OpenAIEmbeddingService):
9 | pass
10 |
--------------------------------------------------------------------------------
/fast_graphrag/_llm/_llm_voyage.py:
--------------------------------------------------------------------------------
1 | """LLM Services module."""
2 |
3 | import asyncio
4 | import os
5 | from dataclasses import dataclass, field
6 | from itertools import chain
7 | from typing import Any, List, Optional
8 |
9 | import numpy as np
10 | from aiolimiter import AsyncLimiter
11 | from voyageai import client_async
12 | from voyageai.object.embeddings import EmbeddingsObject
13 |
14 | from fast_graphrag._utils import logger
15 |
16 | from ._base import BaseEmbeddingService, NoopAsyncContextManager
17 |
18 |
19 | @dataclass
20 | class VoyageAIEmbeddingService(BaseEmbeddingService):
21 | """Base class for VoyageAI embeddings implementations."""
22 |
23 | embedding_dim: int = field(default=1024)
24 | max_elements_per_request: int = field(default=128) # Max 128 elements per batch for Voyage API
25 | model: Optional[str] = field(default="voyage-3")
26 | api_version: Optional[str] = field(default=None)
27 | max_requests_concurrent: int = field(default=int(os.getenv("CONCURRENT_TASK_LIMIT", 1024)))
28 | max_requests_per_minute: int = field(default=1800)
29 | max_requests_per_second: int = field(default=100)
30 | rate_limit_per_second: bool = field(default=False)
31 |
32 | def __post_init__(self):
33 | self.embedding_max_requests_concurrent = (
34 | asyncio.Semaphore(self.max_requests_concurrent) if self.rate_limit_concurrency else NoopAsyncContextManager()
35 | )
36 | self.embedding_per_minute_limiter = (
37 | AsyncLimiter(self.max_requests_per_minute, 60) if self.rate_limit_per_minute else NoopAsyncContextManager()
38 | )
39 | self.embedding_per_second_limiter = (
40 | AsyncLimiter(self.max_requests_per_second, 1) if self.rate_limit_per_second else NoopAsyncContextManager()
41 | )
42 | self.embedding_async_client: client_async.AsyncClient = client_async.AsyncClient(
43 | api_key=self.api_key, max_retries=4
44 | )
45 | logger.debug("Initialized VoyageAIEmbeddingService.")
46 |
47 | async def encode(self, texts: list[str], model: Optional[str] = None) -> np.ndarray[Any, np.dtype[np.float32]]:
48 |         """Get the embedding representation of the input text.
49 |
50 |         Args:
51 |             texts (list[str]): The input texts to embed.
52 |             model (str, optional): The name of the model to use. Defaults to the model provided in the config.
53 |
54 |         Returns:
55 |             np.ndarray: The embedding vectors as a float32 numpy array.
56 |         """
57 |         try:
58 | logger.debug(f"Getting embedding for texts: {texts}")
59 | model = model or self.model
60 | if model is None:
61 | raise ValueError("Model name must be provided.")
62 |
63 | batched_texts = [
64 | texts[i * self.max_elements_per_request : (i + 1) * self.max_elements_per_request]
65 | for i in range((len(texts) + self.max_elements_per_request - 1) // self.max_elements_per_request)
66 | ]
67 | response = await asyncio.gather(*[self._embedding_request(b, model) for b in batched_texts])
68 |
69 | data = chain(*[r.embeddings for r in response])
70 | embeddings = np.array(list(data))
71 | logger.debug(f"Received embedding response: {len(embeddings)} embeddings")
72 |
73 | return embeddings
74 | except Exception:
75 | logger.exception("An error occurred:", exc_info=True)
76 | raise
77 |
78 | async def _embedding_request(self, input: List[str], model: str) -> EmbeddingsObject:
79 | async with self.embedding_max_requests_concurrent:
80 | async with self.embedding_per_minute_limiter:
81 | async with self.embedding_per_second_limiter:
82 | return await self.embedding_async_client.embed(model=model, texts=input, output_dimension=self.embedding_dim)
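83 |
84 | # Batching note (sketch): encode() splits `texts` into groups of at most
85 | # max_elements_per_request (128 for the Voyage API) and embeds the batches
86 | # concurrently, e.g. 300 texts -> 3 requests of sizes 128, 128 and 44.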
83 |
--------------------------------------------------------------------------------
/fast_graphrag/_models.py:
--------------------------------------------------------------------------------
1 | from itertools import chain
2 | from typing import Any, Dict, Iterable, List, Optional
3 |
4 | from pydantic import BaseModel, Field, field_validator
5 | from pydantic._internal import _model_construction
6 |
7 | ####################################################################################################
8 | # LLM Models
9 | ####################################################################################################
10 |
11 |
12 | def _json_schema_slim(schema: dict[str, Any]) -> None:
13 | schema.pop("required")
14 | for prop in schema.get("properties", {}).values():
15 | prop.pop("title", None)
16 |
17 |
18 | class _BaseModelAliasMeta(_model_construction.ModelMetaclass):
19 | def __new__(
20 | cls, name: str, bases: tuple[type[Any], ...], dct: Dict[str, Any], alias: Optional[str] = None, **kwargs: Any
21 | ) -> type:
22 | if alias:
23 | dct["__qualname__"] = alias
24 | name = alias
25 | return super().__new__(cls, name, bases, dct, json_schema_extra=_json_schema_slim, **kwargs)
26 |
27 |
28 | class BaseModelAlias:
29 | class Model(BaseModel, metaclass=_BaseModelAliasMeta):
30 | @staticmethod
31 | def to_dataclass(pydantic: Any) -> Any:
32 | raise NotImplementedError
33 |
34 | def to_str(self) -> str:
35 | raise NotImplementedError
36 |
37 |
38 | ####################################################################################################
39 | # LLM Dumping to strings
40 | ####################################################################################################
41 |
42 |
43 | def dump_to_csv(
44 | data: Iterable[object],
45 | fields: List[str],
46 | separator: str = "\t",
47 | with_header: bool = False,
48 | **values: Dict[str, List[Any]],
49 | ) -> List[str]:
50 | rows = list(
51 | chain(
52 | (separator.join(chain(fields, values.keys())),) if with_header else (),
53 | chain(
54 | separator.join(
55 | chain(
56 | (str(getattr(d, field)).replace("\n", " ").replace("\t", " ") for field in fields),
57 | (str(v).replace("\n", " ").replace("\t", " ") for v in vs),
58 | )
59 | )
60 | for d, *vs in zip(data, *values.values())
61 | ),
62 | )
63 | )
64 | return rows
65 |
66 |
67 | def dump_to_reference_list(data: Iterable[object], separator: str = "\n=====\n\n"):
68 | return [f"[{i + 1}] {d}{separator}" for i, d in enumerate(data)]
69 |
70 |
71 | ####################################################################################################
72 | # Response Models
73 | ####################################################################################################
74 |
75 |
76 | class TAnswer(BaseModel):
77 | answer: str
78 |
79 |
80 | class TEditRelation(BaseModel):
81 | ids: List[int] = Field(..., description="Ids of the facts that you are combining into one")
82 | description: str = Field(
83 | ..., description="Summarized description of the combined facts, in detail and comprehensive"
84 | )
85 |
86 |
87 | class TEditRelationList(BaseModel):
88 | groups: List[TEditRelation] = Field(
89 | ...,
90 | description="List of new fact groups; include only groups of more than one fact",
91 | alias="grouped_facts",
92 | )
93 |
94 |
95 | class TEntityDescription(BaseModel):
96 | description: str
97 |
98 |
99 | class TQueryEntities(BaseModel):
100 | named: List[str] = Field(
101 | ...,
102 | description=("List of named entities extracted from the query"),
103 | )
104 | generic: List[str] = Field(
105 | ...,
106 | description=("List of generic entities extracted from the query"),
107 | )
108 |
109 | @field_validator("named", mode="before")
110 | @classmethod
111 | def uppercase_named(cls, value: List[str]):
112 | return [e.upper() for e in value] if value else value
113 |
114 | # @field_validator("generic", mode="before")
115 | # @classmethod
116 | # def uppercase_generic(cls, value: List[str]):
117 | # return [e.upper() for e in value] if value else value
118 |
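119 | # Worked example for dump_to_csv (sketch, hypothetical Row dataclass):
120 | #   @dataclass
121 | #   class Row:
122 | #       name: str
123 | #       score: float
124 | #   dump_to_csv([Row("a", 1.0)], ["name", "score"], with_header=True)
125 | #   # -> ["name\tscore", "a\t1.0"]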
--------------------------------------------------------------------------------
/fast_graphrag/_policies/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/circlemind-ai/fast-graphrag/b370efe01ef836af292a3713d59b2ec23d2fe7c4/fast_graphrag/_policies/__init__.py
--------------------------------------------------------------------------------
/fast_graphrag/_policies/_base.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from typing import Any, Generic, Iterable, Tuple, Type
3 |
4 | from scipy.sparse import csr_matrix
5 |
6 | from fast_graphrag._llm._llm_openai import BaseLLMService
7 | from fast_graphrag._storage._base import BaseGraphStorage
8 | from fast_graphrag._types import GTEdge, GTId, GTNode, TIndex
9 |
10 |
11 | @dataclass
12 | class BasePolicy:
13 | config: Any = field()
14 |
15 |
16 | ####################################################################################################
17 | # GRAPH UPSERT POLICIES
18 | ####################################################################################################
19 |
20 |
21 | @dataclass
22 | class BaseNodeUpsertPolicy(BasePolicy, Generic[GTNode, GTId]):
23 | async def __call__(
24 | self, llm: BaseLLMService, target: BaseGraphStorage[GTNode, GTEdge, GTId], source_nodes: Iterable[GTNode]
25 | ) -> Tuple[BaseGraphStorage[GTNode, GTEdge, GTId], Iterable[Tuple[TIndex, GTNode]]]:
26 | raise NotImplementedError
27 |
28 |
29 | @dataclass
30 | class BaseEdgeUpsertPolicy(BasePolicy, Generic[GTEdge, GTId]):
31 | async def __call__(
32 | self, llm: BaseLLMService, target: BaseGraphStorage[GTNode, GTEdge, GTId], source_edges: Iterable[GTEdge]
33 | ) -> Tuple[BaseGraphStorage[GTNode, GTEdge, GTId], Iterable[Tuple[TIndex, GTEdge]]]:
34 | raise NotImplementedError
35 |
36 |
37 | @dataclass
38 | class BaseGraphUpsertPolicy(BasePolicy, Generic[GTNode, GTEdge, GTId]):
39 | nodes_upsert_cls: Type[BaseNodeUpsertPolicy[GTNode, GTId]] = field()
40 | edges_upsert_cls: Type[BaseEdgeUpsertPolicy[GTEdge, GTId]] = field()
41 | _nodes_upsert: BaseNodeUpsertPolicy[GTNode, GTId] = field(init=False)
42 | _edges_upsert: BaseEdgeUpsertPolicy[GTEdge, GTId] = field(init=False)
43 |
44 | def __post_init__(self):
45 | self._nodes_upsert = self.nodes_upsert_cls(self.config)
46 | self._edges_upsert = self.edges_upsert_cls(self.config)
47 |
48 | async def __call__(
49 | self,
50 | llm: BaseLLMService,
51 | target: BaseGraphStorage[GTNode, GTEdge, GTId],
52 | source_nodes: Iterable[GTNode],
53 | source_edges: Iterable[GTEdge],
54 | ) -> Tuple[
55 | BaseGraphStorage[GTNode, GTEdge, GTId],
56 | Iterable[Tuple[TIndex, GTNode]],
57 | Iterable[Tuple[TIndex, GTEdge]],
58 | ]:
59 | raise NotImplementedError
60 |
61 |
62 | ####################################################################################################
63 | # RANKING POLICIES
64 | ####################################################################################################
65 |
66 |
67 | class BaseRankingPolicy(BasePolicy):
68 | def __call__(self, scores: csr_matrix) -> csr_matrix:
69 |         assert scores.shape[0] == 1, "Ranking policies only support batch size of 1"
70 | return scores
71 |
--------------------------------------------------------------------------------
/fast_graphrag/_policies/_ranking.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 |
3 | import numpy as np
4 | from scipy.sparse import csr_matrix
5 |
6 | from ._base import BaseRankingPolicy
7 |
8 |
9 | class RankingPolicy_WithThreshold(BaseRankingPolicy): # noqa: N801
10 | @dataclass
11 | class Config:
12 | threshold: float = field(default=0.05)
13 | max_entities: int = field(default=128)
14 |
15 | config: Config = field()
16 |
17 | def __call__(self, scores: csr_matrix) -> csr_matrix:
18 | # Remove scores below threshold
19 | scores.data[scores.data < self.config.threshold] = 0
20 | if scores.nnz >= self.config.max_entities:
21 | smallest_indices = np.argpartition(scores.data, -self.config.max_entities)[:-self.config.max_entities]
22 | scores.data[smallest_indices] = 0
23 | scores.eliminate_zeros()
24 |
25 | return scores
26 |
27 |
28 | class RankingPolicy_TopK(BaseRankingPolicy): # noqa: N801
29 | @dataclass
30 | class Config:
31 | top_k: int = field(default=10)
32 |
33 |     config: Config = field()
34 |
35 | def __call__(self, scores: csr_matrix) -> csr_matrix:
36 | assert scores.shape[0] == 1, "TopK policy only supports batch size of 1"
37 | if scores.nnz <= self.config.top_k:
38 | return scores
39 |
40 | smallest_indices = np.argpartition(scores.data, -self.config.top_k)[:-self.config.top_k]
41 | scores.data[smallest_indices] = 0
42 | scores.eliminate_zeros()
43 |
44 | return scores
45 |
46 |
47 | class RankingPolicy_Elbow(BaseRankingPolicy): # noqa: N801
48 | def __call__(self, scores: csr_matrix) -> csr_matrix:
49 | assert scores.shape[0] == 1, "Elbow policy only supports batch size of 1"
50 | if scores.nnz <= 1:
51 | return scores
52 |
53 | sorted_scores = np.sort(scores.data)
54 |
55 | # Compute elbow
56 | diff = np.diff(sorted_scores)
57 | elbow = np.argmax(diff) + 1
58 |
59 | smallest_indices = np.argpartition(scores.data, elbow)[:elbow]
60 | scores.data[smallest_indices] = 0
61 | scores.eliminate_zeros()
62 |
63 | return scores
64 |
65 |
66 | class RankingPolicy_WithConfidence(BaseRankingPolicy): # noqa: N801
67 | def __call__(self, scores: csr_matrix) -> csr_matrix:
68 | raise NotImplementedError("Confidence policy is not supported yet.")
69 |
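70 | # Worked example for RankingPolicy_TopK (sketch):
71 | #   from scipy.sparse import csr_matrix
72 | #   policy = RankingPolicy_TopK(RankingPolicy_TopK.Config(top_k=2))
73 | #   scores = csr_matrix([[0.1, 0.5, 0.3, 0.0]])
74 | #   policy(scores).toarray()  # -> [[0. , 0.5, 0.3, 0. ]] (keeps the top 2)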
--------------------------------------------------------------------------------
/fast_graphrag/_services/__init__.py:
--------------------------------------------------------------------------------
1 | __all__ = [
2 | 'BaseChunkingService',
3 | 'BaseInformationExtractionService',
4 | 'BaseStateManagerService',
5 | 'DefaultChunkingService',
6 | 'DefaultInformationExtractionService',
7 | 'DefaultStateManagerService'
8 | ]
9 |
10 | from ._base import BaseChunkingService, BaseInformationExtractionService, BaseStateManagerService
11 | from ._chunk_extraction import DefaultChunkingService
12 | from ._information_extraction import DefaultInformationExtractionService
13 | from ._state_manager import DefaultStateManagerService
14 |
--------------------------------------------------------------------------------
/fast_graphrag/_services/_base.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | from dataclasses import dataclass, field
3 | from typing import Dict, Generic, Iterable, List, Optional, Type
4 |
5 | from scipy.sparse import csr_matrix
6 |
7 | from fast_graphrag._llm import BaseEmbeddingService, BaseLLMService
8 | from fast_graphrag._policies._base import (
9 | BaseEdgeUpsertPolicy,
10 | BaseGraphUpsertPolicy,
11 | BaseNodeUpsertPolicy,
12 | BaseRankingPolicy,
13 | )
14 | from fast_graphrag._storage import BaseBlobStorage, BaseGraphStorage, BaseIndexedKeyValueStorage, BaseVectorStorage
15 | from fast_graphrag._storage._namespace import Workspace
16 | from fast_graphrag._types import (
17 | GTChunk,
18 | GTEdge,
19 | GTEmbedding,
20 | GTHash,
21 | GTId,
22 | GTNode,
23 | TContext,
24 | TDocument,
25 | TIndex,
26 | )
27 |
28 |
29 | @dataclass
30 | class BaseChunkingService(Generic[GTChunk]):
31 | """Base class for chunk extractor."""
32 |
33 |     def __post_init__(self):
34 | pass
35 |
36 | async def extract(self, data: Iterable[TDocument]) -> Iterable[Iterable[GTChunk]]:
37 | """Extract unique chunks from the given data."""
38 | raise NotImplementedError
39 |
40 |
41 | @dataclass
42 | class BaseInformationExtractionService(Generic[GTChunk, GTNode, GTEdge, GTId]):
43 | """Base class for entity and relationship extractors."""
44 |
45 | graph_upsert: BaseGraphUpsertPolicy[GTNode, GTEdge, GTId]
46 | max_gleaning_steps: int = 0
47 |
48 | def extract(
49 | self,
50 | llm: BaseLLMService,
51 | documents: Iterable[Iterable[GTChunk]],
52 | prompt_kwargs: Dict[str, str],
53 | entity_types: List[str],
54 | ) -> List[asyncio.Future[Optional[BaseGraphStorage[GTNode, GTEdge, GTId]]]]:
55 | """Extract both entities and relationships from the given data."""
56 | raise NotImplementedError
57 |
58 | async def extract_entities_from_query(
59 | self, llm: BaseLLMService, query: str, prompt_kwargs: Dict[str, str]
60 | ) -> Dict[str, List[str]]:
61 | """Extract entities from the given query."""
62 | raise NotImplementedError
63 |
64 |
65 | @dataclass
66 | class BaseStateManagerService(Generic[GTNode, GTEdge, GTHash, GTChunk, GTId, GTEmbedding]):
67 | """A class for managing state operations."""
68 |
69 | workspace: Optional[Workspace] = field()
70 |
71 | graph_storage: BaseGraphStorage[GTNode, GTEdge, GTId] = field()
72 | entity_storage: BaseVectorStorage[TIndex, GTEmbedding] = field()
73 | chunk_storage: BaseIndexedKeyValueStorage[GTHash, GTChunk] = field()
74 |
75 | embedding_service: BaseEmbeddingService = field()
76 |
77 | node_upsert_policy: BaseNodeUpsertPolicy[GTNode, GTId] = field()
78 | edge_upsert_policy: BaseEdgeUpsertPolicy[GTEdge, GTId] = field()
79 |
80 | entity_ranking_policy: BaseRankingPolicy = field(default_factory=lambda: BaseRankingPolicy(None))
81 | relation_ranking_policy: BaseRankingPolicy = field(default_factory=lambda: BaseRankingPolicy(None))
82 | chunk_ranking_policy: BaseRankingPolicy = field(default_factory=lambda: BaseRankingPolicy(None))
83 |
84 | node_specificity: bool = field(default=False)
85 |
86 | blob_storage_cls: Type[BaseBlobStorage[csr_matrix]] = field(default=BaseBlobStorage)
87 |
88 | async def insert_start(self) -> None:
89 | """Prepare the storage for indexing before adding new data."""
90 | raise NotImplementedError
91 |
92 | async def insert_done(self) -> None:
93 | """Commit the storage operations after indexing."""
94 | raise NotImplementedError
95 |
96 | async def query_start(self) -> None:
97 | """Prepare the storage for indexing before adding new data."""
98 | raise NotImplementedError
99 |
100 | async def query_done(self) -> None:
101 | """Commit the storage operations after indexing."""
102 | raise NotImplementedError
103 |
104 | async def filter_new_chunks(self, chunks_per_data: Iterable[Iterable[GTChunk]]) -> List[List[GTChunk]]:
105 | """Filter the chunks to check for duplicates.
106 |
107 | This method takes a sequence of chunks and returns a sequence of new chunks
108 | that are not already present in the storage. It uses a hashing mechanism to
109 | efficiently identify duplicates.
110 |
111 | Args:
112 | chunks_per_data (Iterable[Iterable[TChunk]]): A sequence of chunks to be filtered.
113 |
114 | Returns:
115 | Iterable[Iterable[TChunk]]: A sequence of chunks that are not in the storage.
116 | """
117 | raise NotImplementedError
118 |
119 | async def upsert(
120 | self,
121 | llm: BaseLLMService,
122 | subgraphs: List[asyncio.Future[Optional[BaseGraphStorage[GTNode, GTEdge, GTId]]]],
123 | documents: Iterable[Iterable[GTChunk]],
124 | show_progress: bool = True
125 | ) -> None:
126 | """Clean and upsert entities, relationships, and chunks into the storage."""
127 | raise NotImplementedError
128 |
129 | async def get_context(
130 | self, query: str, entities: Dict[str, List[str]]
131 | ) -> Optional[TContext[GTNode, GTEdge, GTHash, GTChunk]]:
132 | """Retrieve relevant state from the storage."""
133 | raise NotImplementedError
134 |
135 | async def get_num_entities(self) -> int:
136 | """Get the number of entities in the storage."""
137 | raise NotImplementedError
138 |
139 | async def get_num_relations(self) -> int:
140 | """Get the number of relations in the storage."""
141 | raise NotImplementedError
142 |
143 | async def get_num_chunks(self) -> int:
144 | """Get the number of chunks in the storage."""
145 | raise NotImplementedError
146 |
147 | async def save_graphml(self, output_path: str) -> None:
148 | """Save the graph in GraphML format."""
149 | raise NotImplementedError
150 |
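151 | # Lifecycle sketch (assumed orchestration, inferred from the docstrings above):
152 | #   insert: insert_start() -> filter_new_chunks() / upsert() -> insert_done()
153 | #   query:  query_start() -> get_context() -> query_done()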
--------------------------------------------------------------------------------
/fast_graphrag/_services/_chunk_extraction.py:
--------------------------------------------------------------------------------
1 | import re
2 | from dataclasses import dataclass, field
3 | from itertools import chain
4 | from typing import Iterable, List, Set, Tuple
5 |
6 | import xxhash
7 |
8 | from fast_graphrag._types import TChunk, TDocument, THash
9 | from fast_graphrag._utils import TOKEN_TO_CHAR_RATIO
10 |
11 | from ._base import BaseChunkingService
12 |
13 | DEFAULT_SEPARATORS = [
14 | # Paragraph and page separators
15 | "\n\n\n",
16 | "\n\n",
17 | "\r\n\r\n",
18 | # Sentence ending punctuation
19 | "。", # Chinese period
20 | ".", # Full-width dot
21 | ".", # English period
22 | "!", # Chinese exclamation mark
23 | "!", # English exclamation mark
24 | "?", # Chinese question mark
25 | "?", # English question mark
26 | ]
27 |
28 |
29 | @dataclass
30 | class DefaultChunkingServiceConfig:
31 | separators: List[str] = field(default_factory=lambda: DEFAULT_SEPARATORS)
32 | chunk_token_size: int = field(default=800)
33 | chunk_token_overlap: int = field(default=100)
34 |
35 |
36 | @dataclass
37 | class DefaultChunkingService(BaseChunkingService[TChunk]):
38 | """Default class for chunk extractor."""
39 |
40 | config: DefaultChunkingServiceConfig = field(default_factory=DefaultChunkingServiceConfig)
41 |
42 | def __post_init__(self):
43 | self._split_re = re.compile(f"({'|'.join(re.escape(s) for s in self.config.separators or [])})")
44 | self._chunk_size = self.config.chunk_token_size * TOKEN_TO_CHAR_RATIO
45 | self._chunk_overlap = self.config.chunk_token_overlap * TOKEN_TO_CHAR_RATIO
46 |
47 | async def extract(self, data: Iterable[TDocument]) -> Iterable[Iterable[TChunk]]:
48 | """Extract unique chunks from the given data."""
49 | chunks_per_data: List[List[TChunk]] = []
50 |
51 | for d in data:
52 | unique_chunk_ids: Set[THash] = set()
53 | extracted_chunks = await self._extract_chunks(d)
54 | chunks: List[TChunk] = []
55 | for chunk in extracted_chunks:
56 | if chunk.id not in unique_chunk_ids:
57 | unique_chunk_ids.add(chunk.id)
58 | chunks.append(chunk)
59 | chunks_per_data.append(chunks)
60 |
61 | return chunks_per_data
62 |
63 | async def _extract_chunks(self, data: TDocument) -> List[TChunk]:
64 | # Sanitise input data:
65 | try:
66 | data.data = data.data.encode(errors="replace").decode()
67 | except UnicodeDecodeError:
68 | # Fall back to replacing control characters with a space
69 | data.data = re.sub(r"[\x00-\x09\x11-\x12\x14-\x1f]", " ", data.data)
70 |
71 | if len(data.data) <= self._chunk_size:
72 | chunks = [data.data]
73 | else:
74 | chunks = self._split_text(data.data)
75 |
76 | return [
77 | TChunk(
78 | id=THash(xxhash.xxh3_64_intdigest(chunk) // 2),
79 | content=chunk,
80 | metadata=data.metadata,
81 | )
82 | for chunk in chunks
83 | ]
84 |
85 | def _split_text(self, text: str) -> List[str]:
86 | return self._merge_splits(self._split_re.split(text))
87 |
88 | def _merge_splits(self, splits: List[str]) -> List[str]:
89 | if not splits:
90 | return []
91 |
92 | # Add empty string to the end to have a separator at the end of the last chunk
93 | splits.append("")
94 |
95 | merged_splits: List[List[Tuple[str, int]]] = []
96 | current_chunk: List[Tuple[str, int]] = []
97 | current_chunk_length: int = 0
98 |
99 | for i, split in enumerate(splits):
100 | split_length: int = len(split)
101 | # Separators (odd indices produced by re.split with a capturing group) always join the current chunk
102 | if (i % 2 == 1) or (
103 | current_chunk_length + split_length <= self._chunk_size - (self._chunk_overlap if i > 0 else 0)
104 | ):
105 | current_chunk.append((split, split_length))
106 | current_chunk_length += split_length
107 | else:
108 | merged_splits.append(current_chunk)
109 | current_chunk = [(split, split_length)]
110 | current_chunk_length = split_length
111 |
112 | merged_splits.append(current_chunk)
113 |
114 | if self._chunk_overlap > 0:
115 | return self._enforce_overlap(merged_splits)
116 | else:
117 | r = ["".join((c[0] for c in chunk)) for chunk in merged_splits]
118 |
119 | return r
120 |
121 | def _enforce_overlap(self, chunks: List[List[Tuple[str, int]]]) -> List[str]:
122 | result: List[str] = []
123 | for i, chunk in enumerate(chunks):
124 | if i == 0:
125 | result.append("".join((c[0] for c in chunk)))
126 | else:
127 | # Compute overlap
128 | overlap_length: int = 0
129 | overlap: List[str] = []
130 | for text, length in reversed(chunks[i - 1]):
131 | if overlap_length + length > self._chunk_overlap:
132 | break
133 | overlap_length += length
134 | overlap.append(text)
135 | result.append("".join(chain(reversed(overlap), (c[0] for c in chunk))))
136 | return result
137 |
--------------------------------------------------------------------------------
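A quick usage sketch for the chunking service (assumption: `TDocument` can be constructed directly with the `data` and `metadata` fields accessed above; the sample text is illustrative):

import asyncio

from fast_graphrag._services._chunk_extraction import (
    DefaultChunkingService,
    DefaultChunkingServiceConfig,
)
from fast_graphrag._types import TDocument

async def main():
    chunker = DefaultChunkingService(
        config=DefaultChunkingServiceConfig(chunk_token_size=400, chunk_token_overlap=50)
    )
    docs = [TDocument(data="First paragraph.\n\nSecond paragraph.", metadata={})]
    chunks_per_doc = await chunker.extract(docs)
    for chunks in chunks_per_doc:
        for chunk in chunks:
            print(chunk.id, len(chunk.content))  # deduplicated chunks per document

asyncio.run(main())
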
/fast_graphrag/_services/_information_extraction.py:
--------------------------------------------------------------------------------
1 | """Entity-Relationship extraction module."""
2 | import asyncio
3 | import re
4 | from dataclasses import dataclass
5 | from typing import Dict, Iterable, List, Literal, Optional
6 |
7 | from pydantic import BaseModel, Field
8 |
9 | from fast_graphrag._llm import BaseLLMService, format_and_send_prompt
10 | from fast_graphrag._models import TQueryEntities
11 | from fast_graphrag._storage._base import BaseGraphStorage
12 | from fast_graphrag._storage._gdb_igraph import IGraphStorage, IGraphStorageConfig
13 | from fast_graphrag._types import GTId, TChunk, TEntity, TGraph, TRelation
14 | from fast_graphrag._utils import logger
15 |
16 | from ._base import BaseInformationExtractionService
17 |
18 |
19 | class TGleaningStatus(BaseModel):
20 | status: Literal["done", "continue"] = Field(
21 | description="done if all entities and relationship have been extracted, continue otherwise"
22 | )
23 |
24 |
25 | @dataclass
26 | class DefaultInformationExtractionService(BaseInformationExtractionService[TChunk, TEntity, TRelation, GTId]):
27 | """Default entity and relationship extractor."""
28 |
29 | def extract(
30 | self,
31 | llm: BaseLLMService,
32 | documents: Iterable[Iterable[TChunk]],
33 | prompt_kwargs: Dict[str, str],
34 | entity_types: List[str],
35 | ) -> List[asyncio.Future[Optional[BaseGraphStorage[TEntity, TRelation, GTId]]]]:
36 | """Extract both entities and relationships from the given data."""
37 | return [
38 | asyncio.create_task(self._extract(llm, document, prompt_kwargs, entity_types)) for document in documents
39 | ]
40 |
41 | async def extract_entities_from_query(
42 | self, llm: BaseLLMService, query: str, prompt_kwargs: Dict[str, str]
43 | ) -> Dict[str, List[str]]:
44 | """Extract entities from the given query."""
45 | prompt_kwargs["query"] = query
46 | entities, _ = await format_and_send_prompt(
47 | prompt_key="entity_extraction_query",
48 | llm=llm,
49 | format_kwargs=prompt_kwargs,
50 | response_model=TQueryEntities,
51 | )
52 |
53 | return {
54 | "named": entities.named,
55 | "generic": entities.generic
56 | }
57 |
58 | async def _extract(
59 | self, llm: BaseLLMService, chunks: Iterable[TChunk], prompt_kwargs: Dict[str, str], entity_types: List[str]
60 | ) -> Optional[BaseGraphStorage[TEntity, TRelation, GTId]]:
61 | """Extract both entities and relationships from the given chunks."""
62 | # Extract entities and relationships from each chunk
63 | try:
64 | chunk_graphs = await asyncio.gather(
65 | *[self._extract_from_chunk(llm, chunk, prompt_kwargs, entity_types) for chunk in chunks]
66 | )
67 | if len(chunk_graphs) == 0:
68 | return None
69 |
70 | # Combine chunk graphs in document graph
71 | return await self._merge(llm, chunk_graphs)
72 | except Exception as e:
73 | logger.error(f"Error during information extraction from document: {e}")
74 | return None
75 |
76 | async def _gleaning(
77 | self, llm: BaseLLMService, initial_graph: TGraph, history: list[dict[str, str]]
78 | ) -> Optional[TGraph]:
79 | """Do gleaning steps until the llm says we are done or we reach the max gleaning steps."""
80 | # Prompts
81 | current_graph = initial_graph
82 |
83 | try:
84 | for gleaning_count in range(self.max_gleaning_steps):
85 | # Do gleaning step
86 | gleaning_result, history = await format_and_send_prompt(
87 | prompt_key="entity_relationship_continue_extraction",
88 | llm=llm,
89 | format_kwargs={},
90 | response_model=TGraph,
91 | history_messages=history,
92 | )
93 |
94 | # Combine new entities, relationships with previously obtained ones
95 | current_graph.entities.extend(gleaning_result.entities)
96 | current_graph.relationships.extend(gleaning_result.relationships)
97 |
98 | # Stop gleaning if we don't need to keep going
99 | if gleaning_count == self.max_gleaning_steps - 1:
100 | break
101 |
102 | # Ask llm if we are done extracting entities and relationships
103 | gleaning_status, _ = await format_and_send_prompt(
104 | prompt_key="entity_relationship_gleaning_done_extraction",
105 | llm=llm,
106 | format_kwargs={},
107 | response_model=TGleaningStatus,
108 | history_messages=history,
109 | )
110 |
111 | # If we are done parsing, stop gleaning
112 | if gleaning_status.status == "done":
113 | break
114 | except Exception as e:
115 | logger.error(f"Error during gleaning: {e}")
116 |
117 | return None
118 |
119 | return current_graph
120 |
121 | async def _extract_from_chunk(
122 | self, llm: BaseLLMService, chunk: TChunk, prompt_kwargs: Dict[str, str], entity_types: List[str]
123 | ) -> TGraph:
124 | """Extract entities and relationships from the given chunk."""
125 | prompt_kwargs["input_text"] = chunk.content
126 |
127 | chunk_graph, history = await format_and_send_prompt(
128 | prompt_key="entity_relationship_extraction",
129 | llm=llm,
130 | format_kwargs=prompt_kwargs,
131 | response_model=TGraph,
132 | )
133 |
134 | # Do gleaning
135 | chunk_graph_with_gleaning = await self._gleaning(llm, chunk_graph, history)
136 | if chunk_graph_with_gleaning:
137 | chunk_graph = chunk_graph_with_gleaning
138 |
139 | _clean_entity_types = [re.sub("[ _]", "", entity_type).upper() for entity_type in entity_types]
140 | for entity in chunk_graph.entities:
141 | if re.sub("[ _]", "", entity.type).upper() not in _clean_entity_types:
142 | entity.type = "UNKNOWN"
143 |
144 | # Assign chunk ids to relationships
145 | for relationship in chunk_graph.relationships:
146 | relationship.chunks = [chunk.id]
147 |
148 | return chunk_graph
149 |
150 | async def _merge(self, llm: BaseLLMService, graphs: List[TGraph]) -> BaseGraphStorage[TEntity, TRelation, GTId]:
151 | """Merge the given graphs into a single graph storage."""
152 | graph_storage = IGraphStorage[TEntity, TRelation, GTId](config=IGraphStorageConfig(TEntity, TRelation))
153 |
154 | await graph_storage.insert_start()
155 |
156 | try:
157 | # This loop is sequential: each subgraph is upserted into the shared graph storage so conflicts are resolved incrementally
158 | for graph in graphs:
159 | await self.graph_upsert(llm, graph_storage, graph.entities, graph.relationships)
160 | finally:
161 | await graph_storage.insert_done()
162 |
163 | return graph_storage
164 |
--------------------------------------------------------------------------------
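`extract` returns one task per document rather than awaiting them itself; a hedged sketch of the intended call pattern (`service`, `llm`, and `chunks_per_doc` are placeholders):

import asyncio

async def build_document_graphs(service, llm, chunks_per_doc, entity_types):
    futures = service.extract(
        llm=llm,
        documents=chunks_per_doc,
        prompt_kwargs={},
        entity_types=entity_types,
    )
    # Each task resolves to a per-document graph storage, or None on failure.
    return [g for g in await asyncio.gather(*futures) if g is not None]
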
/fast_graphrag/_storage/__init__.py:
--------------------------------------------------------------------------------
1 | __all__ = [
2 | 'Namespace',
3 | 'BaseBlobStorage',
4 | 'BaseIndexedKeyValueStorage',
5 | 'BaseVectorStorage',
6 | 'BaseGraphStorage',
7 | 'DefaultBlobStorage',
8 | 'DefaultIndexedKeyValueStorage',
9 | 'DefaultVectorStorage',
10 | 'DefaultGraphStorage',
11 | 'DefaultGraphStorageConfig',
12 | 'DefaultVectorStorageConfig',
13 | ]
14 |
15 | from ._base import BaseBlobStorage, BaseGraphStorage, BaseIndexedKeyValueStorage, BaseVectorStorage, Namespace
16 | from ._default import (
17 | DefaultBlobStorage,
18 | DefaultGraphStorage,
19 | DefaultGraphStorageConfig,
20 | DefaultIndexedKeyValueStorage,
21 | DefaultVectorStorage,
22 | DefaultVectorStorageConfig,
23 | )
24 |
--------------------------------------------------------------------------------
/fast_graphrag/_storage/_blob_pickle.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | from dataclasses import dataclass, field
3 | from typing import Optional
4 |
5 | from fast_graphrag._exceptions import InvalidStorageError
6 | from fast_graphrag._types import GTBlob
7 | from fast_graphrag._utils import logger
8 |
9 | from ._base import BaseBlobStorage
10 |
11 |
12 | @dataclass
13 | class PickleBlobStorage(BaseBlobStorage[GTBlob]):
14 | RESOURCE_NAME = "blob_data.pkl"
15 | _data: Optional[GTBlob] = field(init=False, default=None)
16 |
17 | async def get(self) -> Optional[GTBlob]:
18 | return self._data
19 |
20 | async def set(self, blob: GTBlob) -> None:
21 | self._data = blob
22 |
23 | async def _insert_start(self):
24 | if self.namespace:
25 | data_file_name = self.namespace.get_load_path(self.RESOURCE_NAME)
26 | if data_file_name:
27 | try:
28 | with open(data_file_name, "rb") as f:
29 | self._data = pickle.load(f)
30 | except Exception as e:
31 | t = f"Error loading data file for blob storage {data_file_name}: {e}"
32 | logger.error(t)
33 | raise InvalidStorageError(t) from e
34 | else:
35 | logger.info(f"No data file found for blob storage {data_file_name}. Loading empty storage.")
36 | self._data = None
37 | else:
38 | self._data = None
39 | logger.debug("Creating new volatile blob storage.")
40 |
41 | async def _insert_done(self):
42 | if self.namespace:
43 | data_file_name = self.namespace.get_save_path(self.RESOURCE_NAME)
44 | try:
45 | with open(data_file_name, "wb") as f:
46 | pickle.dump(self._data, f)
47 | logger.debug(
48 | f"Saving blob storage '{data_file_name}'."
49 | )
50 | except Exception as e:
51 | logger.error(f"Error saving data file for blob storage {data_file_name}: {e}")
52 |
53 | async def _query_start(self):
54 | assert self.namespace, "Loading a blob storage requires a namespace."
55 |
56 | data_file_name = self.namespace.get_load_path(self.RESOURCE_NAME)
57 | if data_file_name:
58 | try:
59 | with open(data_file_name, "rb") as f:
60 | self._data = pickle.load(f)
61 | except Exception as e:
62 | t = f"Error loading data file for blob storage {data_file_name}: {e}"
63 | logger.error(t)
64 | raise InvalidStorageError(t) from e
65 | else:
66 | logger.warning(f"No data file found for blob storage {data_file_name}. Loading empty blob.")
67 | self._data = None
68 |
69 | async def _query_done(self):
70 | pass
71 |
--------------------------------------------------------------------------------
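A small round-trip sketch (assumptions: the base storage accepts `namespace=None` for a volatile store and exposes public `insert_start()`/`insert_done()` wrappers, as the graph-storage calls in `_information_extraction.py` suggest):

import asyncio

from scipy.sparse import csr_matrix

from fast_graphrag._storage._blob_pickle import PickleBlobStorage

async def main():
    storage = PickleBlobStorage(namespace=None)  # volatile: nothing is persisted
    await storage.insert_start()
    await storage.set(csr_matrix((2, 2)))
    print((await storage.get()).shape)  # (2, 2)
    await storage.insert_done()  # no-op without a namespace

asyncio.run(main())
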
/fast_graphrag/_storage/_default.py:
--------------------------------------------------------------------------------
1 | __all__ = [
2 | "DefaultVectorStorage",
3 | "DefaultVectorStorageConfig",
4 | "DefaultBlobStorage",
5 | "DefaultIndexedKeyValueStorage",
6 | "DefaultGraphStorage",
7 | "DefaultGraphStorageConfig",
8 | ]
9 |
10 | from fast_graphrag._storage._blob_pickle import PickleBlobStorage
11 | from fast_graphrag._storage._gdb_igraph import IGraphStorage, IGraphStorageConfig
12 | from fast_graphrag._storage._ikv_pickle import PickleIndexedKeyValueStorage
13 | from fast_graphrag._storage._vdb_hnswlib import HNSWVectorStorage, HNSWVectorStorageConfig
14 | from fast_graphrag._types import GTBlob, GTEdge, GTEmbedding, GTId, GTKey, GTNode, GTValue
15 |
16 |
17 | # Storage
18 | class DefaultVectorStorage(HNSWVectorStorage[GTId, GTEmbedding]):
19 | pass
20 | class DefaultVectorStorageConfig(HNSWVectorStorageConfig):
21 | pass
22 | class DefaultBlobStorage(PickleBlobStorage[GTBlob]):
23 | pass
24 | class DefaultIndexedKeyValueStorage(PickleIndexedKeyValueStorage[GTKey, GTValue]):
25 | pass
26 | class DefaultGraphStorage(IGraphStorage[GTNode, GTEdge, GTId]):
27 | pass
28 | class DefaultGraphStorageConfig(IGraphStorageConfig[GTNode, GTEdge]):
29 | pass
30 |
--------------------------------------------------------------------------------
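Each `Default*` alias just pins the generic parameters to the library-wide types; a hedged construction sketch (assumption: the vector-storage base class takes `config`, `embedding_dim`, and an optional `namespace`, as the attribute use in `HNSWVectorStorage` suggests):

from fast_graphrag._storage import DefaultVectorStorage, DefaultVectorStorageConfig

vdb = DefaultVectorStorage(
    config=DefaultVectorStorageConfig(),  # HNSW defaults: ef_construction=128, M=64
    embedding_dim=1536,                   # assumption: inherited base-class field
    namespace=None,                       # volatile index
)
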
/fast_graphrag/_storage/_ikv_pickle.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | from dataclasses import dataclass, field
3 | from typing import Dict, Iterable, List, Optional, Union
4 |
5 | import numpy as np
6 | import numpy.typing as npt
7 |
8 | from fast_graphrag._exceptions import InvalidStorageError
9 | from fast_graphrag._types import GTKey, GTValue, TIndex
10 | from fast_graphrag._utils import logger
11 |
12 | from ._base import BaseIndexedKeyValueStorage
13 |
14 |
15 | @dataclass
16 | class PickleIndexedKeyValueStorage(BaseIndexedKeyValueStorage[GTKey, GTValue]):
17 | RESOURCE_NAME = "kv_data.pkl"
18 | _data: Dict[Union[None, TIndex], GTValue] = field(init=False, default_factory=dict)
19 | _key_to_index: Dict[GTKey, TIndex] = field(init=False, default_factory=dict)
20 | _free_indices: List[TIndex] = field(init=False, default_factory=list)
21 | _np_keys: Optional[npt.NDArray[np.object_]] = field(init=False, default=None)
22 |
23 | async def size(self) -> int:
24 | return len(self._data)
25 |
26 | async def get(self, keys: Iterable[GTKey]) -> Iterable[Optional[GTValue]]:
27 | return (self._data.get(self._key_to_index.get(key, None), None) for key in keys)
28 |
29 | async def get_by_index(self, indices: Iterable[TIndex]) -> Iterable[Optional[GTValue]]:
30 | return (self._data.get(index, None) for index in indices)
31 |
32 | async def get_index(self, keys: Iterable[GTKey]) -> Iterable[Optional[TIndex]]:
33 | return (self._key_to_index.get(key, None) for key in keys)
34 |
35 | async def upsert(self, keys: Iterable[GTKey], values: Iterable[GTValue]) -> None:
36 | for key, value in zip(keys, values):
37 | index = self._key_to_index.get(key, None)
38 | if index is None:
39 | if len(self._free_indices) > 0:
40 | index = self._free_indices.pop()
41 | else:
42 | index = TIndex(len(self._data))
43 | self._key_to_index[key] = index
44 |
45 | # Invalidate cache
46 | self._np_keys = None
47 | self._data[index] = value
48 |
49 | async def delete(self, keys: Iterable[GTKey]) -> None:
50 | for key in keys:
51 | index = self._key_to_index.pop(key, None)
52 | if index is not None:
53 | self._free_indices.append(index)
54 | self._data.pop(index, None)
55 |
56 | # Invalidate cache
57 | self._np_keys = None
58 | else:
59 | logger.warning(f"Key '{key}' not found in indexed key-value storage.")
60 |
61 | async def mask_new(self, keys: Iterable[GTKey]) -> Iterable[bool]:
62 | keys = list(keys)
63 |
64 | if len(keys) == 0:
65 | return np.array([], dtype=bool)
66 |
67 | if self._np_keys is None:
68 | self._np_keys = np.fromiter(
69 | self._key_to_index.keys(),
70 | count=len(self._key_to_index),
71 | dtype=type(keys[0]),
72 | )
73 | keys_array = np.array(keys, dtype=type(keys[0]))
74 |
75 | return ~np.isin(keys_array, self._np_keys)
76 |
77 | async def _insert_start(self):
78 | if self.namespace:
79 | data_file_name = self.namespace.get_load_path(self.RESOURCE_NAME)
80 |
81 | if data_file_name:
82 | try:
83 | with open(data_file_name, "rb") as f:
84 | self._data, self._free_indices, self._key_to_index = pickle.load(f)
85 | logger.debug(
86 | f"Loaded {len(self._data)} elements from indexed key-value storage '{data_file_name}'."
87 | )
88 | except Exception as e:
89 | t = f"Error loading data file for key-vector storage '{data_file_name}': {e}"
90 | logger.error(t)
91 | raise InvalidStorageError(t) from e
92 | else:
93 | logger.info(f"No data file found for key-vector storage '{data_file_name}'. Loading empty storage.")
94 | self._data = {}
95 | self._free_indices = []
96 | self._key_to_index = {}
97 | else:
98 | self._data = {}
99 | self._free_indices = []
100 | self._key_to_index = {}
101 | logger.debug("Creating new volatile indexed key-value storage.")
102 | self._np_keys = None
103 |
104 | async def _insert_done(self):
105 | if self.namespace:
106 | data_file_name = self.namespace.get_save_path(self.RESOURCE_NAME)
107 | try:
108 | with open(data_file_name, "wb") as f:
109 | pickle.dump((self._data, self._free_indices, self._key_to_index), f)
110 | logger.debug(f"Saving {len(self._data)} elements to indexed key-value storage '{data_file_name}'.")
111 | except Exception as e:
112 | t = f"Error saving data file for key-vector storage '{data_file_name}': {e}"
113 | logger.error(t)
114 | raise InvalidStorageError(t) from e
115 |
116 | async def _query_start(self):
117 | assert self.namespace, "Loading a kv storage requires a namespace."
118 | data_file_name = self.namespace.get_load_path(self.RESOURCE_NAME)
119 | if data_file_name:
120 | try:
121 | with open(data_file_name, "rb") as f:
122 | self._data, self._free_indices, self._key_to_index = pickle.load(f)
123 | logger.debug(
124 | f"Loaded {len(self._data)} elements from indexed key-value storage '{data_file_name}'."
125 | )
126 | except Exception as e:
127 | t = f"Error loading data file for key-vector storage {data_file_name}: {e}"
128 | logger.error(t)
129 | raise InvalidStorageError(t) from e
130 | else:
131 | logger.warning(f"No data file found for key-vector storage '{data_file_name}'. Loading empty storage.")
132 | self._data = {}
133 | self._free_indices = []
134 | self._key_to_index = {}
135 | self._np_keys = None
136 |
137 | async def _query_done(self):
138 | pass
139 |
--------------------------------------------------------------------------------
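A short sketch of the index-recycling behaviour (assumption: `namespace=None` yields a volatile in-memory store; integer keys mirror the `THash` ids used elsewhere):

import asyncio

from fast_graphrag._storage._ikv_pickle import PickleIndexedKeyValueStorage

async def main():
    kv = PickleIndexedKeyValueStorage(namespace=None)
    await kv.upsert(keys=[11, 22], values=["a", "b"])
    print(list(await kv.get([11, 22, 33])))  # ['a', 'b', None]
    await kv.delete([11])
    await kv.upsert(keys=[33], values=["c"])  # reuses the index freed by 11
    print(list(await kv.mask_new([22, 44])))  # [False, True]

asyncio.run(main())
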
/fast_graphrag/_storage/_namespace.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import time
4 | from typing import Any, Callable, List, Optional
5 |
6 | from fast_graphrag._exceptions import InvalidStorageError
7 | from fast_graphrag._utils import logger
8 |
9 |
10 | class Workspace:
11 | @staticmethod
12 | def new(working_dir: str, checkpoint: int = 0, keep_n: int = 0) -> "Workspace":
13 | return Workspace(working_dir, checkpoint, keep_n)
14 |
15 | @staticmethod
16 | def get_path(working_dir: str, checkpoint: Optional[int] = None) -> Optional[str]:
17 | if checkpoint is None:
18 | return None
19 | elif checkpoint == 0:
20 | return working_dir
21 | return os.path.join(working_dir, str(checkpoint))
22 |
23 | def __init__(self, working_dir: str, checkpoint: int = 0, keep_n: int = 0):
24 | self.working_dir: str = working_dir
25 | self.keep_n: int = keep_n
26 | if not os.path.exists(working_dir):
27 | os.makedirs(working_dir)
28 |
29 | self.checkpoints = sorted(
30 | (int(x.name) for x in os.scandir(self.working_dir) if x.is_dir() and not x.name.startswith("0__err_")),
31 | reverse=True,
32 | )
33 | if self.checkpoints:
34 | self.current_load_checkpoint = checkpoint if checkpoint else self.checkpoints[0]
35 | else:
36 | self.current_load_checkpoint = checkpoint
37 | self.save_checkpoint: Optional[int] = None
38 | self.failed_checkpoints: List[str] = []
39 |
40 | def __del__(self):
41 | for checkpoint in self.failed_checkpoints:
42 | old_path = os.path.join(self.working_dir, checkpoint)
43 | new_path = os.path.join(self.working_dir, f"0__err_{checkpoint}")
44 | os.rename(old_path, new_path)
45 |
46 | if self.keep_n > 0:
47 | checkpoints = sorted((x.name for x in os.scandir(self.working_dir) if x.is_dir()), reverse=True)
48 | for checkpoint in checkpoints[self.keep_n + 1 :]:
49 | shutil.rmtree(os.path.join(self.working_dir, str(checkpoint)))
50 |
51 | def make_for(self, namespace: str) -> "Namespace":
52 | return Namespace(self, namespace)
53 |
54 | def get_load_path(self) -> Optional[str]:
55 | load_path = self.get_path(self.working_dir, self.current_load_checkpoint)
56 | if load_path == self.working_dir and len([x for x in os.scandir(load_path) if x.is_file()]) == 0:
57 | return None
58 | return load_path
59 |
60 |
61 | def get_save_path(self) -> str:
62 | if self.save_checkpoint is None:
63 | if self.keep_n > 0:
64 | self.save_checkpoint = int(time.time())
65 | else:
66 | self.save_checkpoint = 0
67 | save_path = self.get_path(self.working_dir, self.save_checkpoint)
68 |
69 | assert save_path is not None, "Save path cannot be None."
70 |
71 | if not os.path.exists(save_path):
72 | os.makedirs(save_path)
73 | return os.path.join(save_path)
74 |
75 | def _rollback(self) -> bool:
76 | if self.current_load_checkpoint is None:
77 | return False
78 | # List all checkpoints in the working directory and select the
79 | # largest one smaller than the current load checkpoint.
80 | try:
81 | self.current_load_checkpoint = next(x for x in self.checkpoints if x < self.current_load_checkpoint)
82 | logger.warning("Rolling back to checkpoint: %s", self.current_load_checkpoint)
83 | except (StopIteration, ValueError):
84 | logger.warning("No checkpoints to roll back to. Last checkpoint tried: %s", self.current_load_checkpoint)
85 | self.current_load_checkpoint = None
86 |
87 | return True
88 |
89 | async def with_checkpoints(self, fn: Callable[[], Any]) -> Any:
90 | while True:
91 | try:
92 | return await fn()
93 | except Exception as e:
94 | logger.warning("Error occurred loading checkpoint: %s", e)
95 | if self.current_load_checkpoint is not None:
96 | self.failed_checkpoints.append(str(self.current_load_checkpoint))
97 | if self._rollback() is False:
98 | break
99 | raise InvalidStorageError("No valid checkpoints to load or default storages cannot be created.")
100 |
101 |
102 | class Namespace:
103 | def __init__(self, workspace: Workspace, namespace: Optional[str] = None):
104 | self.namespace = namespace
105 | self.workspace = workspace
106 |
107 | def get_load_path(self, resource_name: str) -> Optional[str]:
108 | assert self.namespace is not None, "Namespace must be set to get resource load path."
109 | load_path = self.workspace.get_load_path()
110 | if load_path is None:
111 | return None
112 | return os.path.join(load_path, f"{self.namespace}_{resource_name}")
113 |
114 | def get_save_path(self, resource_name: str) -> str:
115 | assert self.namespace is not None, "Namespace must be set to get resource save path."
116 | return os.path.join(self.workspace.get_save_path(), f"{self.namespace}_{resource_name}")
117 |
--------------------------------------------------------------------------------
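A sketch of the checkpoint layout this produces (paths are illustrative):

from fast_graphrag._storage._namespace import Workspace

# keep_n > 0 enables timestamped checkpoint directories; keep_n == 0 saves in place.
workspace = Workspace.new("./my_index", keep_n=3)
ns = workspace.make_for("graph")

save_path = ns.get_save_path("data.pkl")  # ./my_index/<timestamp>/graph_data.pkl
load_path = ns.get_load_path("data.pkl")  # latest checkpoint, or None on a fresh workspace
print(save_path, load_path)
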
/fast_graphrag/_storage/_vdb_hnswlib.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | from dataclasses import dataclass, field
3 | from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
4 |
5 | import hnswlib
6 | import numpy as np
7 | import numpy.typing as npt
8 | from scipy.sparse import csr_matrix
9 |
10 | from fast_graphrag._exceptions import InvalidStorageError
11 | from fast_graphrag._types import GTEmbedding, GTId, TScore
12 | from fast_graphrag._utils import logger
13 |
14 | from ._base import BaseVectorStorage
15 |
16 |
17 | @dataclass
18 | class HNSWVectorStorageConfig:
19 | ef_construction: int = field(default=128)
20 | M: int = field(default=64)
21 | ef_search: int = field(default=96)
22 | num_threads: int = field(default=-1)
23 |
24 |
25 | @dataclass
26 | class HNSWVectorStorage(BaseVectorStorage[GTId, GTEmbedding]):
27 | RESOURCE_NAME = "hnsw_index_{}.bin"
28 | RESOURCE_METADATA_NAME = "hnsw_metadata.pkl"
29 | INITIAL_MAX_ELEMENTS = 128000
30 | config: HNSWVectorStorageConfig = field() # type: ignore
31 | _index: Any = field(init=False, default=None) # type: ignore
32 | _metadata: Dict[GTId, Dict[str, Any]] = field(default_factory=dict)
33 |
34 | @property
35 | def size(self) -> int:
36 | return self._index.get_current_count()
37 |
38 | @property
39 | def max_size(self) -> int:
40 | return self._index.get_max_elements() or self.INITIAL_MAX_ELEMENTS
41 |
42 | async def upsert(
43 | self,
44 | ids: Iterable[GTId],
45 | embeddings: Iterable[GTEmbedding],
46 | metadata: Union[Iterable[Dict[str, Any]], None] = None,
47 | ) -> None:
48 | ids = list(ids)
49 | embeddings = np.array(list(embeddings), dtype=np.float32)
50 | metadata = list(metadata) if metadata else None
51 |
52 | assert (len(ids) == len(embeddings)) and (
53 | metadata is None or (len(metadata) == len(ids))
54 | ), "ids, embeddings, and metadata (if provided) must have the same length"
55 |
56 | if self.size + len(embeddings) >= self.max_size:
57 | new_size = self.max_size * 2
58 | while self.size + len(embeddings) >= new_size:
59 | new_size *= 2
60 | self._index.resize_index(new_size)
61 | logger.info("Resizing HNSW index.")
62 |
63 | if metadata:
64 | self._metadata.update(dict(zip(ids, metadata)))
65 | self._index.add_items(data=embeddings, ids=ids, num_threads=self.config.num_threads)
66 |
67 | async def get_knn(
68 | self, embeddings: Iterable[GTEmbedding], top_k: int
69 | ) -> Tuple[Iterable[Iterable[GTId]], npt.NDArray[TScore]]:
70 | if self.size == 0:
71 | empty_list: List[List[GTId]] = []
72 | logger.info("Querying knns in empty index.")
73 | return empty_list, np.array([], dtype=TScore)
74 |
75 | top_k = min(top_k, self.size)
76 |
77 | if top_k > self.config.ef_search:
78 | self._index.set_ef(top_k)
79 |
80 | # distances is [0, 2] (best, worst)
81 | ids, distances = self._index.knn_query(data=embeddings, k=top_k, num_threads=self.config.num_threads)
82 |
83 | return ids, 1.0 - np.array(distances, dtype=TScore) * 0.5
84 |
85 | async def score_all(
86 | self, embeddings: Iterable[GTEmbedding], top_k: int = 1, threshold: Optional[float] = None
87 | ) -> csr_matrix:
88 | if not isinstance(embeddings, np.ndarray):
89 | embeddings = np.array(list(embeddings), dtype=np.float32)
90 |
91 | if embeddings.size == 0 or self.size == 0:
92 | logger.warning(f"No provided embeddings ({embeddings.size}) or empty index ({self.size}).")
93 | return csr_matrix((0, self.size))
94 |
95 | top_k = min(top_k, self.size)
96 | if top_k > self.config.ef_search:
97 | self._index.set_ef(top_k)
98 |
99 | # distances is [0, 2] (best, worst)
100 | ids, distances = self._index.knn_query(data=embeddings, k=top_k, num_threads=self.config.num_threads)
101 |
102 | ids = np.array(ids)
103 | scores = 1.0 - np.array(distances, dtype=TScore) * 0.5
104 |
105 | if threshold is not None:
106 | scores[scores < threshold] = 0
107 |
108 | # Create sparse distance matrix with shape (#embeddings, #all_embeddings)
109 | flattened_ids = ids.ravel()
110 | flattened_scores = scores.ravel()
111 |
112 | scores = csr_matrix(
113 | (flattened_scores, (np.repeat(np.arange(len(ids)), top_k), flattened_ids)),
114 | shape=(len(ids), self.size),
115 | )
116 |
117 | return scores
118 |
119 | async def _insert_start(self):
120 | self._index = hnswlib.Index(space="cosine", dim=self.embedding_dim) # type: ignore
121 |
122 | if self.namespace:
123 | index_file_name = self.namespace.get_load_path(self.RESOURCE_NAME.format(self.embedding_dim))
124 | metadata_file_name = self.namespace.get_load_path(self.RESOURCE_METADATA_NAME)
125 |
126 | if index_file_name and metadata_file_name:
127 | try:
128 | self._index.load_index(index_file_name, allow_replace_deleted=True)
129 | with open(metadata_file_name, "rb") as f:
130 | self._metadata = pickle.load(f)
131 | logger.debug(
132 | f"Loaded {self.size} elements from vectordb storage '{index_file_name}'."
133 | )
134 | return # All good
135 | except Exception as e:
136 | t = f"Error loading metadata file for vectordb storage '{metadata_file_name}': {e}"
137 | logger.error(t)
138 | raise InvalidStorageError(t) from e
139 | else:
140 | logger.info(f"No data file found for vectordb storage '{index_file_name}'. Loading empty vectordb.")
141 | else:
142 | logger.debug("Creating new volatile vectordb storage.")
143 | self._index.init_index(
144 | max_elements=self.INITIAL_MAX_ELEMENTS,
145 | ef_construction=self.config.ef_construction,
146 | M=self.config.M,
147 | allow_replace_deleted=True
148 | )
149 | self._index.set_ef(self.config.ef_search)
150 | self._metadata = {}
151 |
152 | async def _insert_done(self):
153 | if self.namespace:
154 | index_file_name = self.namespace.get_save_path(self.RESOURCE_NAME.format(self.embedding_dim))
155 | metadata_file_name = self.namespace.get_save_path(self.RESOURCE_METADATA_NAME)
156 |
157 | try:
158 | self._index.save_index(index_file_name)
159 | with open(metadata_file_name, "wb") as f:
160 | pickle.dump(self._metadata, f)
161 | logger.debug(f"Saving {self.size} elements from vectordb storage '{index_file_name}'.")
162 | except Exception as e:
163 | t = f"Error saving vectordb storage from {index_file_name}: {e}"
164 | logger.error(t)
165 | raise InvalidStorageError(t) from e
166 |
167 | async def _query_start(self):
168 | assert self.namespace, "Loading a vectordb requires a namespace."
169 | self._index = hnswlib.Index(space="cosine", dim=self.embedding_dim) # type: ignore
170 |
171 | index_file_name = self.namespace.get_load_path(self.RESOURCE_NAME.format(self.embedding_dim))
172 | metadata_file_name = self.namespace.get_load_path(self.RESOURCE_METADATA_NAME)
173 | if index_file_name and metadata_file_name:
174 | try:
175 | self._index.load_index(index_file_name, allow_replace_deleted=True)
176 | with open(metadata_file_name, "rb") as f:
177 | self._metadata = pickle.load(f)
178 | logger.debug(f"Loaded {self.size} elements from vectordb storage '{index_file_name}'.")
179 |
180 | return # All good
181 | except Exception as e:
182 | t = f"Error loading vectordb storage from {index_file_name}: {e}"
183 | logger.error(t)
184 | raise InvalidStorageError(t) from e
185 | else:
186 | logger.warning(f"No data file found for vectordb storage '{index_file_name}'. Loading empty vectordb.")
187 | self._index.init_index(
188 | max_elements=self.INITIAL_MAX_ELEMENTS,
189 | ef_construction=self.config.ef_construction,
190 | M=self.config.M,
191 | allow_replace_deleted=True
192 | )
193 | self._index.set_ef(self.config.ef_search)
194 | self._metadata = {}
195 |
196 | async def _query_done(self):
197 | pass
198 |
--------------------------------------------------------------------------------
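A minimal end-to-end sketch (assumptions: the base class accepts `embedding_dim` and `namespace` at construction and exposes a public `insert_start()` wrapper around `_insert_start()`):

import asyncio

import numpy as np

from fast_graphrag._storage._vdb_hnswlib import HNSWVectorStorage, HNSWVectorStorageConfig

async def main():
    vdb = HNSWVectorStorage(config=HNSWVectorStorageConfig(), embedding_dim=8, namespace=None)
    await vdb.insert_start()  # builds an empty cosine index
    vectors = np.random.rand(4, 8).astype(np.float32)
    await vdb.upsert(ids=[0, 1, 2, 3], embeddings=vectors)
    ids, scores = await vdb.get_knn(vectors[:1], top_k=2)
    print(ids, scores)  # vector 0 is its own nearest neighbour (score ~1.0)

asyncio.run(main())
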
/fast_graphrag/_utils.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import logging
3 | import time
4 | from functools import wraps
5 | from typing import Any, Callable, List, Optional, Tuple, TypeVar, Union
6 |
7 | import numpy as np
8 | import numpy.typing as npt
9 | from scipy.sparse import csr_matrix
10 |
11 | from fast_graphrag._types import TIndex
12 |
13 | logger = logging.getLogger("graphrag")
14 | TOKEN_TO_CHAR_RATIO = 4
15 |
16 |
17 | def timeit(func: Callable[..., Any]):
18 | @wraps(func)
19 | async def wrapper(*args: Any, **kwargs: Any) -> Any:
20 | start = time.time()
21 | result = await func(*args, **kwargs)
22 | duration = time.time() - start
23 | wrapper.execution_times.append(duration) # type: ignore
24 | return result
25 |
26 | wrapper.execution_times = [] # type: ignore
27 | return wrapper
28 |
29 |
30 | def throttle_async_func_call(
31 | max_concurrent: int = 2048,
32 | stagger_time: Optional[float] = None,
33 | waiting_time: float = 0.001,
34 | ):
35 | _wrappedFn = TypeVar("_wrappedFn", bound=Callable[..., Any])
36 |
37 | def decorator(func: _wrappedFn) -> _wrappedFn:
38 | semaphore = asyncio.Semaphore(max_concurrent)
39 |
40 | @wraps(func)
41 | async def wait_func(*args: Any, **kwargs: Any) -> Any:
42 | async with semaphore:
43 | try:
44 | if stagger_time:
45 | await asyncio.sleep(stagger_time)
46 | return await func(*args, **kwargs)
47 | except Exception as e:
48 | logger.error(f"Error in throttled function {func.__name__}: {e}")
49 | raise e
50 |
51 | return wait_func # type: ignore
52 |
53 | return decorator
54 |
55 |
56 | def get_event_loop() -> asyncio.AbstractEventLoop:
57 | try:
58 | # If there is already an event loop, use it.
59 | loop = asyncio.get_event_loop()
60 | except RuntimeError:
61 | # If in a sub-thread, create a new event loop.
62 | loop = asyncio.new_event_loop()
63 | asyncio.set_event_loop(loop)
64 | return loop
65 |
66 |
67 | def extract_sorted_scores(
68 | row_vector: csr_matrix,
69 | ) -> Tuple[npt.NDArray[np.int64], npt.NDArray[np.float32]]:
70 | """Take a sparse row vector and return a list of non-zero (index, score) pairs sorted by score."""
71 | assert row_vector.shape[0] <= 1, "The input matrix must be a row vector."
72 | if row_vector.shape[0] == 0:
73 | return np.array([], dtype=np.int64), np.array([], dtype=np.float32)
74 |
75 | # Step 1: Get the indices of non-zero elements
76 | non_zero_indices = row_vector.nonzero()[1]
77 |
78 | # Step 2: Extract the probabilities of these indices
79 | probabilities = row_vector.data
80 |
81 | # Step 3: Use NumPy to create arrays for indices and probabilities
82 | indices_array = np.array(non_zero_indices)
83 | probabilities_array = np.array(probabilities)
84 |
85 | # Step 4: Sort the probabilities and get the sorted indices
86 | sorted_indices = np.argsort(probabilities_array)[::-1]
87 |
88 | # Step 5: Create sorted arrays for indices and probabilities
89 | sorted_indices_array = indices_array[sorted_indices]
90 | sorted_probabilities_array = probabilities_array[sorted_indices]
91 |
92 | return sorted_indices_array, sorted_probabilities_array
93 |
94 |
95 | def csr_from_indices_list(
96 | data: List[List[Union[int, TIndex]]], shape: Tuple[int, int]
97 | ) -> csr_matrix:
98 | """Create a CSR matrix from a list of lists."""
99 | num_rows = len(data)
100 |
101 | # Flatten the list of lists and create corresponding row indices
102 | row_indices = np.repeat(np.arange(num_rows), [len(row) for row in data])
103 | col_indices = np.concatenate(data) if num_rows > 0 else np.array([], dtype=np.int64)
104 |
105 | # Data values (all ones in this case)
106 | values = np.broadcast_to(1, len(row_indices))
107 |
108 | # Create the CSR matrix
109 | return csr_matrix((values, (row_indices, col_indices)), shape=shape)
110 |
--------------------------------------------------------------------------------
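A tiny worked example of the two sparse helpers above:

from scipy.sparse import csr_matrix

from fast_graphrag._utils import csr_from_indices_list, extract_sorted_scores

# A 2x4 binary matrix: row 0 selects columns 0 and 2, row 1 selects column 3.
m = csr_from_indices_list([[0, 2], [3]], shape=(2, 4))
print(m.toarray())  # [[1 0 1 0], [0 0 0 1]]

# Sort the non-zero entries of a row vector by score, descending.
indices, scores = extract_sorted_scores(csr_matrix([[0.0, 0.9, 0.3, 0.0]]))
print(indices, scores)  # [1 2] [0.9 0.3]
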
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "fast-graphrag"
3 | version = "0.0.5"
4 | description = ""
5 | authors = ["Luca Pinchetti ", "Antonio Vespoli ", "Yuhang Song "]
6 | packages = [{include = "fast_graphrag" }]
7 | readme = "README.md"
8 |
9 | [tool.poetry.dependencies]
10 | python = ">=3.10,<3.13"
11 | igraph = "^0.11.6"
12 | xxhash = "^3.5.0"
13 | pydantic = "^2.9.2"
14 | scipy = "^1.14.1"
15 | scikit-learn = "^1.5.2"
16 | tenacity = "^9.0.0"
17 | openai = "^1.52.1"
18 | scipy-stubs = "^1.14.1.5"
19 | hnswlib = "^0.8.0"
20 | instructor = "^1.6.3"
21 | requests = "^2.32.3"
22 | python-dotenv = "^1.0.1"
23 | tiktoken = "^0.8.0"
24 | aiolimiter = "^1.1.0"
25 | google-genai = "^1.3.0"
26 | vertexai = "^1.71.1"
27 | sentencepiece = "^0.2.0"
28 | json-repair = "^0.39.1"
29 | voyageai = "^0.3.2"
30 |
31 |
32 | [tool.poetry.group.dev.dependencies]
33 | ruff = "^0.7.0"
34 |
35 | [build-system]
36 | requires = ["poetry-core"]
37 | build-backend = "poetry.core.masonry.api"
38 |
39 | [tool.ruff]
40 | line-length = 120
41 | indent-width = 2
42 |
43 | [tool.ruff.lint]
44 | select = [
45 | "E", # pycodestyle errors
46 | "W", # pycodestyle warnings
47 | "F", # pyflakes
48 | "I", # isort
49 | "B", # flake8-bugbear
50 | "C4", # flake8-comprehensions
51 | "N", # PEP8 naming convetions
52 | "D" # pydocstyle
53 | ]
54 | ignore = [
55 | "C901", # too complex
56 | "W191", # indentation contains tabs
57 | "D401" # imperative mood
58 | ]
59 |
60 | [tool.ruff.lint.pydocstyle]
61 | convention = "google"
62 |
63 | [tool.ruff.lint.per-file-ignores]
64 | "_prompt.py" = ["E501"]
65 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | """Testing."""
2 |
--------------------------------------------------------------------------------
/tests/_graphrag_test.py:
--------------------------------------------------------------------------------
1 | # type: ignore
2 | import unittest
3 | from dataclasses import dataclass
4 | from unittest.mock import AsyncMock, MagicMock, patch
5 |
6 | from fast_graphrag._graphrag import BaseGraphRAG
7 | from fast_graphrag._models import TAnswer
8 | from fast_graphrag._types import TContext, TQueryResponse
9 |
10 |
11 | class TestBaseGraphRAG(unittest.IsolatedAsyncioTestCase):
12 | def setUp(self):
13 | self.llm_service = AsyncMock()
14 | self.chunking_service = AsyncMock()
15 | self.information_extraction_service = MagicMock()
16 | self.information_extraction_service.extract_entities_from_query = AsyncMock()
17 | self.state_manager = AsyncMock()
18 | self.state_manager.embedding_service.embedding_dim = self.state_manager.entity_storage.embedding_dim = 1
19 |
20 | @dataclass
21 | class BaseGraphRAGNoEmbeddingValidation(BaseGraphRAG):
22 | def __post_init__(self):
23 | pass
24 |
25 | self.graph_rag = BaseGraphRAGNoEmbeddingValidation(
26 | working_dir="test_dir",
27 | domain="test_domain",
28 | example_queries="test_query",
29 | entity_types=["type1", "type2"],
30 | )
31 | self.graph_rag.llm_service = self.llm_service
32 | self.graph_rag.chunking_service = self.chunking_service
33 | self.graph_rag.information_extraction_service = self.information_extraction_service
34 | self.graph_rag.state_manager = self.state_manager
35 |
36 | async def test_async_insert(self):
37 | self.chunking_service.extract = AsyncMock(return_value=["chunked_data"])
38 | self.state_manager.filter_new_chunks = AsyncMock(return_value=["new_chunks"])
39 | self.information_extraction_service.extract = MagicMock(return_value=["subgraph"])
40 | self.state_manager.upsert = AsyncMock()
41 |
42 | await self.graph_rag.async_insert("test_content", {"meta": "data"})
43 |
44 | self.chunking_service.extract.assert_called_once()
45 | self.state_manager.filter_new_chunks.assert_called_once()
46 | self.information_extraction_service.extract.assert_called_once()
47 | self.state_manager.upsert.assert_called_once()
48 |
49 | @patch("fast_graphrag._graphrag.format_and_send_prompt", new_callable=AsyncMock)
50 | async def test_async_query(self, format_and_send_prompt):
51 | self.information_extraction_service.extract_entities_from_query = AsyncMock(return_value=["entities"])
52 | self.state_manager.get_context = AsyncMock(return_value=TContext([], [], []))
53 | format_and_send_prompt.return_value=(TAnswer(answer="response"), None)
54 |
55 | response = await self.graph_rag.async_query("test_query")
56 |
57 | self.information_extraction_service.extract_entities_from_query.assert_called_once()
58 | self.state_manager.get_context.assert_called_once()
59 | format_and_send_prompt.assert_called_once()
60 | self.assertIsInstance(response, TQueryResponse)
61 |
62 |
63 | if __name__ == "__main__":
64 | unittest.main()
65 |
--------------------------------------------------------------------------------
/tests/_llm/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/circlemind-ai/fast-graphrag/b370efe01ef836af292a3713d59b2ec23d2fe7c4/tests/_llm/__init__.py
--------------------------------------------------------------------------------
/tests/_llm/_base_test.py:
--------------------------------------------------------------------------------
1 | # type: ignore
2 | import unittest
3 | from unittest.mock import AsyncMock, patch
4 |
5 | from pydantic import BaseModel
6 |
7 | from fast_graphrag._llm._base import BaseLLMService, format_and_send_prompt
8 |
9 | # Assuming these are defined somewhere in your codebase
10 | PROMPTS = {
11 | "example_prompt": "Hello, {name}!"
12 | }
13 |
14 | class TestModel(BaseModel):
15 | answer: str
16 |
17 | class TestFormatAndSendPrompt(unittest.IsolatedAsyncioTestCase):
18 |
19 | @patch("fast_graphrag._llm._base.PROMPTS", PROMPTS)
20 | async def test_format_and_send_prompt(self):
21 | mock_llm = AsyncMock(spec=BaseLLMService(model=""))
22 | answer = TestModel(answer="TEST")
23 | mock_response = (answer, [{"key": "value"}])
24 | mock_llm.send_message = AsyncMock(return_value=mock_response)
25 |
26 | result = await format_and_send_prompt(
27 | prompt_key="example_prompt",
28 | llm=mock_llm,
29 | format_kwargs={"name": "World"},
30 | response_model=TestModel
31 | )
32 |
33 | mock_llm.send_message.assert_called_once_with(
34 | prompt="Hello, World!",
35 | response_model=TestModel
36 | )
37 | self.assertEqual(result, mock_response)
38 |
39 | @patch("fast_graphrag._llm._base.PROMPTS", PROMPTS)
40 | async def test_format_and_send_prompt_with_additional_args(self):
41 | mock_llm = AsyncMock(spec=BaseLLMService(model=""))
42 | answer = TestModel(answer="TEST")
43 | mock_response = (answer, [{"key": "value"}])
44 | mock_llm.send_message = AsyncMock(return_value=mock_response)
45 |
46 | result = await format_and_send_prompt(
47 | prompt_key="example_prompt",
48 | llm=mock_llm,
49 | format_kwargs={"name": "World"},
50 | response_model=TestModel,
51 | model="test_model",
52 | max_tokens=100
53 | )
54 |
55 | mock_llm.send_message.assert_called_once_with(
56 | prompt="Hello, World!",
57 | response_model=TestModel,
58 | model="test_model",
59 | max_tokens=100
60 | )
61 | self.assertEqual(result, mock_response)
62 |
63 | if __name__ == "__main__":
64 | unittest.main()
65 |
--------------------------------------------------------------------------------
/tests/_llm/_llm_openai_test.py:
--------------------------------------------------------------------------------
1 | # type: ignore
2 | import os
3 | import unittest
4 | from unittest.mock import AsyncMock, MagicMock
5 |
6 | import instructor
7 | from openai import APIConnectionError, AsyncOpenAI, RateLimitError
8 | from tenacity import RetryError
9 |
10 | from fast_graphrag._exceptions import LLMServiceNoResponseError
11 | from fast_graphrag._llm._llm_openai import OpenAIEmbeddingService, OpenAILLMService
12 |
13 | os.environ["OPENAI_API_KEY"] = ""
14 |
15 |
16 | RateLimitError429 = RateLimitError(message="Rate limit exceeded", response=MagicMock(), body=None)
17 |
18 |
19 | class TestOpenAILLMService(unittest.IsolatedAsyncioTestCase):
20 | async def test_send_message_success(self):
21 | service = OpenAILLMService(api_key="test")
22 | mock_response = str("Hi!")
23 | service.llm_async_client = AsyncMock()
24 | service.llm_async_client.chat.completions.create = AsyncMock(return_value=mock_response)
25 |
26 | response, messages = await service.send_message(prompt="Hello")
27 |
28 | self.assertEqual(response, mock_response)
29 | self.assertEqual(messages[-1]["role"], "assistant")
30 |
31 | async def test_send_message_no_response(self):
32 | service = OpenAILLMService(api_key="test")
33 | service.llm_async_client = AsyncMock()
34 | service.llm_async_client.chat.completions.create.return_value = None
35 |
36 | with self.assertRaises(LLMServiceNoResponseError):
37 | await service.send_message(prompt="Hello")
38 |
39 | async def test_send_message_rate_limit_error(self):
40 | service = OpenAILLMService()
41 | mock_response = str("Hi!")
42 | async_open_ai = AsyncOpenAI(api_key="test")
43 | async_open_ai.chat.completions.create = AsyncMock(
44 | side_effect=(RateLimitError429, mock_response)
45 | )
46 | service.llm_async_client: instructor.AsyncInstructor = instructor.from_openai(
47 | async_open_ai
48 | )
49 |
50 | response, messages = await service.send_message(prompt="Hello", response_model=None)
51 |
52 | self.assertEqual(response, mock_response)
53 | self.assertEqual(messages[-1]["role"], "assistant")
54 |
55 | async def test_send_message_api_connection_error(self):
56 | service = OpenAILLMService()
57 | mock_response = str("Hi!")
58 | async_open_ai = AsyncOpenAI(api_key="test")
59 | async_open_ai.chat.completions.create = AsyncMock(
60 | side_effect=(APIConnectionError(request=MagicMock()), mock_response)
61 | )
62 | service.llm_async_client: instructor.AsyncInstructor = instructor.from_openai(
63 | async_open_ai
64 | )
65 |
66 | response, messages = await service.send_message(prompt="Hello")
67 |
68 | self.assertEqual(response, mock_response)
69 | self.assertEqual(messages[-1]["role"], "assistant")
70 |
71 | async def test_send_message_with_system_prompt(self):
72 | service = OpenAILLMService(api_key="test")
73 | mock_response = str("Hi!")
74 | service.llm_async_client = AsyncMock()
75 | service.llm_async_client.chat.completions.create = AsyncMock(return_value=mock_response)
76 |
77 | response, messages = await service.send_message(
78 | prompt="Hello", system_prompt="System prompt"
79 | )
80 |
81 | self.assertEqual(response, mock_response)
82 | self.assertEqual(messages[0]["role"], "system")
83 | self.assertEqual(messages[0]["content"], "System prompt")
84 |
85 | async def test_send_message_with_history(self):
86 | service = OpenAILLMService(api_key="test")
87 | mock_response = str("Hi!")
88 | service.llm_async_client = AsyncMock()
89 | service.llm_async_client.chat.completions.create = AsyncMock(return_value=mock_response)
90 |
91 | history = [{"role": "user", "content": "Previous message"}]
92 | response, messages = await service.send_message(prompt="Hello", history_messages=history)
93 |
94 | self.assertEqual(response, mock_response)
95 | self.assertEqual(messages[0]["role"], "user")
96 | self.assertEqual(messages[0]["content"], "Previous message")
97 |
98 |
99 | class TestOpenAIEmbeddingService(unittest.IsolatedAsyncioTestCase):
100 | async def test_get_embedding_success(self):
101 | service = OpenAIEmbeddingService(api_key="test")
102 | mock_response = AsyncMock()
103 | mock_response.data = [AsyncMock(embedding=[0.1, 0.2, 0.3])]
104 | service.embedding_async_client.embeddings.create = AsyncMock(return_value=mock_response)
105 |
106 | embeddings = await service.encode(texts=["test"], model="text-embedding-3-small")
107 |
108 | self.assertEqual(embeddings.shape, (1, 3))
109 | self.assertEqual(embeddings[0][0], 0.1)
110 |
111 | async def test_get_embedding_rate_limit_error(self):
112 | service = OpenAIEmbeddingService(api_key="test")
113 | mock_response = AsyncMock()
114 | mock_response.data = [AsyncMock(embedding=[0.1, 0.2, 0.3])]
115 | service.embedding_async_client.embeddings.create = AsyncMock(side_effect=(RateLimitError429, mock_response))
116 |
117 | embeddings = await service.encode(texts=["test"], model="text-embedding-3-small")
118 |
119 | self.assertEqual(embeddings.shape, (1, 3))
120 | self.assertEqual(embeddings[0][0], 0.1)
121 |
122 | async def test_get_embedding_api_connection_error(self):
123 | service = OpenAIEmbeddingService(api_key="test")
124 | mock_response = AsyncMock()
125 | mock_response.data = [AsyncMock(embedding=[0.1, 0.2, 0.3])]
126 | service.embedding_async_client.embeddings.create = AsyncMock(
127 | side_effect=(APIConnectionError(request=MagicMock()), mock_response)
128 | )
129 | embeddings = await service.encode(texts=["test"], model="text-embedding-3-small")
130 |
131 | self.assertEqual(embeddings.shape, (1, 3))
132 | self.assertEqual(embeddings[0][0], 0.1)
133 |
134 | async def test_get_embedding_retry_failure(self):
135 | service = OpenAIEmbeddingService(api_key="test")
136 | service.embedding_async_client.embeddings.create = AsyncMock(
137 | side_effect=RateLimitError429
138 | )
139 |
140 | with self.assertRaises(RetryError):
141 | await service.encode(texts=["test"], model="text-embedding-3-small")
142 |
143 | async def test_get_embedding_with_different_model(self):
144 | service = OpenAIEmbeddingService(api_key="test")
145 | mock_response = AsyncMock()
146 | mock_response.data = [AsyncMock(embedding=[0.4, 0.5, 0.6])]
147 | service.embedding_async_client.embeddings.create = AsyncMock(return_value=mock_response)
148 |
149 | embeddings = await service.encode(texts=["test"], model="text-embedding-3-large")
150 |
151 | self.assertEqual(embeddings.shape, (1, 3))
152 | self.assertEqual(embeddings[0][0], 0.4)
153 |
154 |
155 | if __name__ == "__main__":
156 | unittest.main()
157 |
--------------------------------------------------------------------------------
/tests/_models_test.py:
--------------------------------------------------------------------------------
1 | # type: ignore
2 | import unittest
3 |
4 | from pydantic import ValidationError
5 |
6 | from fast_graphrag._models import (
7 | TEditRelation,
8 | TEditRelationList,
9 | TQueryEntities,
10 | dump_to_csv,
11 | dump_to_reference_list,
12 | )
13 | from fast_graphrag._types import TEntity
14 |
15 |
16 | class TestModels(unittest.TestCase):
17 | def test_tqueryentities(self):
18 | query_entities = TQueryEntities(named=["Entity1", "Entity2"], generic=["Generic1", "Generic2"])
19 | self.assertEqual(query_entities.named, ["ENTITY1", "ENTITY2"])
20 | self.assertEqual(query_entities.generic, ["Generic1", "Generic2"])
21 |
22 | with self.assertRaises(ValidationError):
23 | TQueryEntities(entities=["Entity1", "Entity2"], n="two")
24 |
25 | def test_teditrelationship(self):
26 | edit_relationship = TEditRelation(ids=[1, 2], description="Combined relationship description")
27 | self.assertEqual(edit_relationship.ids, [1, 2])
28 | self.assertEqual(edit_relationship.description, "Combined relationship description")
29 |
30 | def test_teditrelationshiplist(self):
31 | edit_relationship = TEditRelation(ids=[1, 2], description="Combined relationship description")
32 | edit_relationship_list = TEditRelationList(grouped_facts=[edit_relationship])
33 | self.assertEqual(edit_relationship_list.groups, [edit_relationship])
34 |
35 | def test_dump_to_csv(self):
36 | data = [TEntity(name="Sample name", type="SAMPLE TYPE", description="Sample description")]
37 | fields = ["name", "type"]
38 | values = {"score": [0.9]}
39 | csv_output = dump_to_csv(data, fields, with_header=True, **values)
40 | expected_output = ["name\ttype\tscore", "Sample name\tSAMPLE TYPE\t0.9"]
41 | self.assertEqual(csv_output, expected_output)
42 |
43 |
44 | class TestDumpToReferenceList(unittest.TestCase):
45 | def test_empty_list(self):
46 | self.assertEqual(dump_to_reference_list([]), [])
47 |
48 | def test_single_element(self):
49 | self.assertEqual(dump_to_reference_list(["item"]), ["[1] item\n=====\n\n"])
50 |
51 | def test_multiple_elements(self):
52 | data = ["item1", "item2", "item3"]
53 | expected = [
54 | "[1] item1\n=====\n\n",
55 | "[2] item2\n=====\n\n",
56 | "[3] item3\n=====\n\n"
57 | ]
58 | self.assertEqual(dump_to_reference_list(data), expected)
59 |
60 | def test_custom_separator(self):
61 | data = ["item1", "item2"]
62 | separator = " | "
63 | expected = [
64 | "[1] item1 | ",
65 | "[2] item2 | "
66 | ]
67 | self.assertEqual(dump_to_reference_list(data, separator), expected)
68 |
69 |
70 | class TestDumpToCsv(unittest.TestCase):
71 | def test_empty_data(self):
72 | self.assertEqual(dump_to_csv([], ["field1", "field2"]), [])
73 |
74 | def test_single_element(self):
75 | class Data:
76 | def __init__(self, field1, field2):
77 | self.field1 = field1
78 | self.field2 = field2
79 |
80 | data = [Data("value1", "value2")]
81 | expected = ["value1\tvalue2"]
82 | self.assertEqual(dump_to_csv(data, ["field1", "field2"]), expected)
83 |
84 | def test_multiple_elements(self):
85 | class Data:
86 | def __init__(self, field1, field2):
87 | self.field1 = field1
88 | self.field2 = field2
89 |
90 | data = [Data("value1", "value2"), Data("value3", "value4")]
91 | expected = ["value1\tvalue2", "value3\tvalue4"]
92 | self.assertEqual(dump_to_csv(data, ["field1", "field2"]), expected)
93 |
94 | def test_with_header(self):
95 | class Data:
96 | def __init__(self, field1, field2):
97 | self.field1 = field1
98 | self.field2 = field2
99 |
100 | data = [Data("value1", "value2")]
101 | expected = ["field1\tfield2", "value1\tvalue2"]
102 | self.assertEqual(dump_to_csv(data, ["field1", "field2"], with_header=True), expected)
103 |
104 | def test_custom_separator(self):
105 | class Data:
106 | def __init__(self, field1, field2):
107 | self.field1 = field1
108 | self.field2 = field2
109 |
110 | data = [Data("value1", "value2")]
111 | expected = ["value1 | value2"]
112 | self.assertEqual(dump_to_csv(data, ["field1", "field2"], separator=" | "), expected)
113 |
114 | def test_additional_values(self):
115 | class Data:
116 | def __init__(self, field1, field2):
117 | self.field1 = field1
118 | self.field2 = field2
119 |
120 | data = [Data("value1", "value2")]
121 | expected = ["value1\tvalue2\tvalue3"]
122 | self.assertEqual(dump_to_csv(data, ["field1", "field2"], value3=["value3"]), expected)
123 |
124 |
125 | if __name__ == "__main__":
126 | unittest.main()
127 |
--------------------------------------------------------------------------------
/tests/_policies/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/circlemind-ai/fast-graphrag/b370efe01ef836af292a3713d59b2ec23d2fe7c4/tests/_policies/__init__.py
--------------------------------------------------------------------------------
/tests/_policies/_ranking_test.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import numpy as np
4 | from scipy.sparse import csr_matrix
5 |
6 | from fast_graphrag._policies._ranking import (
7 | RankingPolicy_Elbow,
8 | RankingPolicy_TopK,
9 | # RankingPolicy_WithConfidence,
10 | RankingPolicy_WithThreshold,
11 | )
12 |
13 |
14 | class TestRankingPolicyWithThreshold(unittest.TestCase):
15 | def test_threshold(self):
16 | policy = RankingPolicy_WithThreshold(RankingPolicy_WithThreshold.Config(0.1))
17 | scores = csr_matrix([0.05, 0.2, 0.15, 0.05])
18 | result = policy(scores)
19 | expected = csr_matrix([0, 0.2, 0.15, 0])
20 | np.testing.assert_array_equal(result.toarray(), expected.toarray())
21 |
22 | def test_all_below_threshold(self):
23 | policy = RankingPolicy_WithThreshold(RankingPolicy_WithThreshold.Config(0.1))
24 | scores = csr_matrix([0.05, 0.05, 0.05, 0.05])
25 | result = policy(scores)
26 | expected = csr_matrix([], shape=(1, 4))
27 | np.testing.assert_array_equal(result.toarray(), expected.toarray())
28 |
29 | def test_explicit_batch_size_1(self):
30 | policy = RankingPolicy_WithThreshold(RankingPolicy_WithThreshold.Config(0.1))
31 | scores = csr_matrix([[0.05, 0.2, 0.15, 0.05]])
32 | result = policy(scores)
33 | expected = csr_matrix([[0, 0.2, 0.15, 0]])
34 | np.testing.assert_array_equal(result.toarray(), expected.toarray())
35 |
36 | def test_all_above_threshold(self):
37 | policy = RankingPolicy_WithThreshold(RankingPolicy_WithThreshold.Config(0.1))
38 | scores = csr_matrix([0.15, 0.2, 0.25, 0.35])
39 | result = policy(scores)
40 | expected = csr_matrix([0.15, 0.2, 0.25, 0.35])
41 | np.testing.assert_array_equal(result.toarray(), expected.toarray())
42 |
43 |
44 | class TestRankingPolicyTopK(unittest.TestCase):
45 | def test_top_k(self):
46 | policy = RankingPolicy_TopK(RankingPolicy_TopK.Config(2))
47 | scores = csr_matrix([0.05, 0.05, 0.2, 0.15, 0.25])
48 | result = policy(scores)
49 | expected = csr_matrix([0, 0, 0.2, 0.0, 0.25])
50 | np.testing.assert_array_equal(result.toarray(), expected.toarray())
51 |
52 | policy = RankingPolicy_TopK(RankingPolicy_TopK.Config(1))
53 | result = policy(scores)
54 | expected = csr_matrix([0, 0, 0.0, 0.0, 0.25])
55 | np.testing.assert_array_equal(result.toarray(), expected.toarray())
56 |
57 | def test_top_k_less_than_k(self):
58 | policy = RankingPolicy_TopK(RankingPolicy_TopK.Config(5))
59 | scores = csr_matrix([0.05, 0.2, 0.0, 0.15])
60 | result = policy(scores)
61 | expected = csr_matrix([0.05, 0.2, 0.0, 0.15])
62 | np.testing.assert_array_equal(result.toarray(), expected.toarray())
63 |
64 | def test_top_k_is_zero(self):
65 | policy = RankingPolicy_TopK(RankingPolicy_TopK.Config(0))
66 | scores = csr_matrix([0.05, 0.2, 0.15, 0.25])
67 | result = policy(scores)
68 | expected = csr_matrix([0.05, 0.2, 0.15, 0.25])
69 | np.testing.assert_array_equal(result.toarray(), expected.toarray())
70 |
71 | def test_top_k_all_zero(self):
72 | policy = RankingPolicy_TopK(RankingPolicy_TopK.Config(2))
73 | scores = csr_matrix([0, 0, 0, 0, 0])
74 | result = policy(scores)
75 | expected = csr_matrix([0, 0, 0, 0, 0])
76 | np.testing.assert_array_equal(result.toarray(), expected.toarray())
77 |
78 |
79 | class TestRankingPolicyElbow(unittest.TestCase):
80 | def test_elbow(self):
81 | policy = RankingPolicy_Elbow(config=None)
82 | scores = csr_matrix([0.05, 0.2, 0.1, 0.25, 0.1])
83 | result = policy(scores)
84 | expected = csr_matrix([0, 0.2, 0.0, 0.25, 0])
85 | np.testing.assert_array_equal(result.toarray(), expected.toarray())
86 |
87 | def test_elbow_all_zero(self):
88 | policy = RankingPolicy_Elbow(config=None)
89 | scores = csr_matrix([0, 0, 0, 0, 0])
90 | result = policy(scores)
91 | expected = csr_matrix([0, 0, 0, 0, 0])
92 | np.testing.assert_array_equal(result.toarray(), expected.toarray())
93 |
94 | def test_elbow_all_same(self):
95 | policy = RankingPolicy_Elbow(config=None)
96 | scores = csr_matrix([0.05, 0.05, 0.05, 0.05, 0.05])
97 | result = policy(scores)
98 | expected = csr_matrix([0, 0.05, 0.05, 0.05, 0.05])
99 | np.testing.assert_array_equal(result.toarray(), expected.toarray())
100 |
101 | # class TestRankingPolicyWithConfidence(unittest.TestCase):
102 | # def test_not_implemented(self):
103 | # policy = RankingPolicy_WithConfidence()
104 | # scores = csr_matrix([0.05, 0.2, 0.15, 0.25, 0.1])
105 | # with self.assertRaises(NotImplementedError):
106 | # policy(scores)
107 |
108 | if __name__ == '__main__':
109 | unittest.main()
110 |
--------------------------------------------------------------------------------
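Taken together, these tests describe how a ranking policy is driven: construct it with its nested `Config`, then call it on a 1xN `csr_matrix` of scores. A short sketch of that call pattern, using only behavior asserted above:

```python
from scipy.sparse import csr_matrix

from fast_graphrag._policies._ranking import (
    RankingPolicy_TopK,
    RankingPolicy_WithThreshold,
)

scores = csr_matrix([[0.05, 0.2, 0.15, 0.25]])

# Zero out every score below the configured threshold (0.1 here).
policy = RankingPolicy_WithThreshold(RankingPolicy_WithThreshold.Config(0.1))
print(policy(scores).toarray())  # [[0.   0.2  0.15 0.25]]

# Keep only the k largest scores (per the tests, k=0 leaves the vector untouched).
policy = RankingPolicy_TopK(RankingPolicy_TopK.Config(2))
print(policy(scores).toarray())  # [[0.   0.2  0.   0.25]]
```
--------------------------------------------------------------------------------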
/tests/_services/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/circlemind-ai/fast-graphrag/b370efe01ef836af292a3713d59b2ec23d2fe7c4/tests/_services/__init__.py
--------------------------------------------------------------------------------
/tests/_services/_chunk_extraction_test.py:
--------------------------------------------------------------------------------
1 | # type: ignore
2 | import unittest
3 | from dataclasses import dataclass
4 | from typing import Any, Dict
5 | from unittest.mock import patch
6 |
7 | import xxhash
8 |
9 | from fast_graphrag._services._chunk_extraction import DefaultChunkingService
10 | from fast_graphrag._types import THash
11 |
12 |
13 | @dataclass
14 | class MockDocument:
15 | data: str
16 | metadata: Dict[str, Any]
17 |
18 |
19 | @dataclass
20 | class MockChunk:
21 | id: THash
22 | content: str
23 | metadata: Dict[str, Any]
24 |
25 |
26 | class TestDefaultChunkingService(unittest.IsolatedAsyncioTestCase):
27 | def setUp(self):
28 | self.chunking_service = DefaultChunkingService()
29 |
30 | async def test_extract(self):
31 | doc1 = MockDocument(data="test data 1", metadata={"meta": "data1"})
32 | doc2 = MockDocument(data="test data 2", metadata={"meta": "data2"})
33 | documents = [doc1, doc2]
34 |
35 | with patch.object(
36 | self.chunking_service,
37 | "_extract_chunks",
38 | return_value=[
39 | MockChunk(id=THash(xxhash.xxh3_64_intdigest(doc1.data) // 2), content=doc1.data, metadata=doc1.metadata)
40 | ],
41 | ) as mock_extract_chunks:
42 | chunks = await self.chunking_service.extract(documents)
43 |
44 | self.assertEqual(len(chunks), 2)
45 | self.assertEqual(len(chunks[0]), 1)
46 | self.assertEqual(chunks[0][0].content, "test data 1")
47 | self.assertEqual(chunks[0][0].metadata, {"meta": "data1"})
48 | mock_extract_chunks.assert_called()
49 |
50 | async def test_extract_with_duplicates(self):
51 | doc1 = MockDocument(data="test data 1", metadata={"meta": "data1"})
52 | doc2 = MockDocument(data="test data 1", metadata={"meta": "data1"})
53 | documents = [doc1, doc2]
54 |
55 | with patch.object(
56 | self.chunking_service,
57 | "_extract_chunks",
58 | return_value=[
59 | MockChunk(id=THash(xxhash.xxh3_64_intdigest(doc1.data) // 2), content=doc1.data, metadata=doc1.metadata)
60 | ],
61 | ) as mock_extract_chunks:
62 | chunks = await self.chunking_service.extract(documents)
63 |
64 | self.assertEqual(len(chunks), 2)
65 | self.assertEqual(len(chunks[0]), 1)
66 | self.assertEqual(len(chunks[1]), 1)
67 | self.assertEqual(chunks[0][0].content, "test data 1")
68 | self.assertEqual(chunks[0][0].metadata, {"meta": "data1"})
69 | self.assertEqual(chunks[1][0].content, "test data 1")
70 | self.assertEqual(chunks[1][0].metadata, {"meta": "data1"})
71 | mock_extract_chunks.assert_called()
72 |
73 | async def test_extract_chunks(self):
74 | doc = MockDocument(data="test data", metadata={"meta": "data"})
75 | chunk = MockChunk(id=THash(xxhash.xxh3_64_intdigest(doc.data) // 2), content=doc.data, metadata=doc.metadata)
76 |
77 | chunks = await self.chunking_service._extract_chunks(doc)
78 | self.assertEqual(len(chunks), 1)
79 | self.assertEqual(chunks[0].id, chunk.id)
80 | self.assertEqual(chunks[0].content, chunk.content)
81 | self.assertEqual(chunks[0].metadata, chunk.metadata)
82 |
83 |
84 | if __name__ == "__main__":
85 | unittest.main()
86 |
--------------------------------------------------------------------------------
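The chunking tests derive the expected chunk id from the document content itself. A sketch of that identity, matching the expression used throughout the tests; content-addressed ids are what make the duplicate-document case collapse to identical chunks:

```python
import xxhash

from fast_graphrag._types import THash


def expected_chunk_id(content: str) -> THash:
    # Same expression the tests use: 64-bit xxh3 of the content,
    # halved to stay within THash's range.
    return THash(xxhash.xxh3_64_intdigest(content) // 2)


# Identical content hashes to the same id, so duplicate documents
# produce chunks the service can deduplicate.
assert expected_chunk_id("test data 1") == expected_chunk_id("test data 1")
assert expected_chunk_id("test data 1") != expected_chunk_id("test data 2")
```
--------------------------------------------------------------------------------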
/tests/_services/_information_extraction_test.py:
--------------------------------------------------------------------------------
1 | # type: ignore
2 | import asyncio
3 | import unittest
4 | from unittest.mock import AsyncMock, MagicMock, patch
5 |
6 | from fast_graphrag._llm._base import BaseLLMService
7 | from fast_graphrag._models import TQueryEntities
8 | from fast_graphrag._policies._graph_upsert import BaseGraphUpsertPolicy
9 | from fast_graphrag._services import DefaultInformationExtractionService
10 | from fast_graphrag._storage._base import BaseGraphStorage
11 | from fast_graphrag._types import TGraph
12 |
13 |
14 | class TestDefaultInformationExtractionService(unittest.IsolatedAsyncioTestCase):
15 | def setUp(self):
16 | self.llm_service = MagicMock(spec=BaseLLMService)
17 | self.llm_service.send_message = AsyncMock()
18 | self.chunk = MagicMock()
19 | self.chunk.content = "test content"
20 | self.chunk.id = "chunk_id"
21 | self.document = [self.chunk]
22 | self.entity_types = ["entity_type"]
23 | self.prompt_kwargs = {"domain": "test_domain"}
24 | self.service = DefaultInformationExtractionService(graph_upsert=None)
25 | self.service.graph_upsert = AsyncMock(spec=BaseGraphUpsertPolicy)
26 |
27 | @patch("fast_graphrag._services._information_extraction.format_and_send_prompt", new_callable=AsyncMock)
28 | async def test_extract_entities_from_query(self, mock_format_and_send_prompt):
29 | mock_format_and_send_prompt.return_value = (
30 | TQueryEntities(named=["entity1", "entity2"], generic=["generic1"]),
31 | None,
32 | )
33 | r = await self.service.extract_entities_from_query(self.llm_service, "test query", self.prompt_kwargs)
34 | named, generic = r["named"], r["generic"]
35 | self.assertEqual(len(named), 2)
36 | self.assertEqual(named[0], "ENTITY1")
37 | self.assertEqual(named[1], "ENTITY2")
38 | self.assertEqual(len(generic), 1)
39 |
40 | @patch("fast_graphrag._services._information_extraction.format_and_send_prompt", new_callable=AsyncMock)
41 | async def test_extract(self, mock_format_and_send_prompt):
42 | mock_format_and_send_prompt.return_value = (TGraph(entities=[], relationships=[]), [])
43 | tasks = self.service.extract(self.llm_service, [self.document], self.prompt_kwargs, self.entity_types)
44 | results = await asyncio.gather(*tasks)
45 | self.assertEqual(len(results), 1)
46 | self.assertIsInstance(results[0], BaseGraphStorage)
47 |
48 |
49 | if __name__ == "__main__":
50 | unittest.main()
51 |
--------------------------------------------------------------------------------
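`test_extract` shows the fan-out contract: `extract()` is a plain call that hands back one awaitable per document rather than a single coroutine, and each awaitable resolves to a per-document graph storage (hence the `assertIsInstance(results[0], BaseGraphStorage)` check). A small caller-side sketch of that pattern; the helper name is illustrative:

```python
import asyncio
from typing import Any, Dict, List


async def extract_all(service: Any, llm: Any, documents: List[Any],
                      prompt_kwargs: Dict[str, str], entity_types: List[str]) -> List[Any]:
    # One awaitable per document; gathering them runs the per-document
    # extractions concurrently.
    tasks = service.extract(llm, documents, prompt_kwargs, entity_types)
    return await asyncio.gather(*tasks)
```
--------------------------------------------------------------------------------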
/tests/_storage/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/circlemind-ai/fast-graphrag/b370efe01ef836af292a3713d59b2ec23d2fe7c4/tests/_storage/__init__.py
--------------------------------------------------------------------------------
/tests/_storage/_base_test.py:
--------------------------------------------------------------------------------
1 | # type: ignore
2 | import unittest
3 | from unittest.mock import AsyncMock, patch
4 |
5 | from fast_graphrag._storage._base import BaseStorage
6 |
7 |
8 | class TestBaseStorage(unittest.IsolatedAsyncioTestCase):
9 |
10 | def setUp(self):
11 | self.storage = BaseStorage(config=None)
12 |
13 | @patch.object(BaseStorage, '_insert_start', new_callable=AsyncMock)
14 | @patch.object(BaseStorage, '_query_done', new_callable=AsyncMock)
15 | @patch("fast_graphrag._storage._base.logger")
16 | async def test_insert_start_from_query_mode(self, mock_logger, mock_query_done, mock_insert_start):
17 | self.storage._mode = "query"
18 | self.storage._in_progress = True
19 |
20 | await self.storage.insert_start()
21 |
22 | mock_query_done.assert_called_once()
23 | mock_insert_start.assert_called_once()
24 | mock_logger.error.assert_called_once()
25 | self.assertEqual(self.storage._mode, "insert")
26 | self.assertFalse(self.storage._in_progress)
27 |
28 | @patch.object(BaseStorage, '_insert_start', new_callable=AsyncMock)
29 | async def test_insert_start_from_none_mode(self, mock_insert_start):
30 | self.storage._mode = None
31 | self.storage._in_progress = False
32 |
33 | await self.storage.insert_start()
34 |
35 | mock_insert_start.assert_called_once()
36 | self.assertEqual(self.storage._mode, "insert")
37 | self.assertFalse(self.storage._in_progress)
38 |
39 | @patch.object(BaseStorage, '_insert_done', new_callable=AsyncMock)
40 | async def test_insert_done_in_insert_mode(self, mock_insert_done):
41 | self.storage._mode = "insert"
42 | self.storage._in_progress = True
43 |
44 | await self.storage.insert_done()
45 |
46 | mock_insert_done.assert_called_once()
47 |
48 | @patch("fast_graphrag._storage._base.logger")
49 | async def test_insert_done_in_query_mode(self, mock_logger):
50 | self.storage._mode = "query"
51 | self.storage._in_progress = True
52 |
53 | await self.storage.insert_done()
54 | mock_logger.error.assert_called_once()
55 |
56 | @patch.object(BaseStorage, '_query_start', new_callable=AsyncMock)
57 | @patch.object(BaseStorage, '_insert_done', new_callable=AsyncMock)
58 | @patch("fast_graphrag._storage._base.logger")
59 | async def test_query_start_from_insert_mode(self, mock_logger, mock_insert_done, mock_query_start):
60 | self.storage._mode = "insert"
61 | self.storage._in_progress = True
62 |
63 | await self.storage.query_start()
64 |
65 | mock_insert_done.assert_called_once()
66 | mock_query_start.assert_called_once()
67 | mock_logger.error.assert_called_once()
68 | self.assertEqual(self.storage._mode, "query")
69 | self.assertFalse(self.storage._in_progress)
70 |
71 | @patch.object(BaseStorage, '_query_start', new_callable=AsyncMock)
72 | async def test_query_start_from_none_mode(self, mock_query_start):
73 | self.storage._mode = None
74 | self.storage._in_progress = False
75 |
76 | await self.storage.query_start()
77 |
78 | mock_query_start.assert_called_once()
79 | self.assertEqual(self.storage._mode, "query")
80 |
81 | @patch.object(BaseStorage, '_query_done', new_callable=AsyncMock)
82 | async def test_query_done_in_query_mode(self, mock_query_done):
83 | self.storage._mode = "query"
84 | self.storage._in_progress = True
85 |
86 | await self.storage.query_done()
87 |
88 | mock_query_done.assert_called_once()
89 |
90 | @patch("fast_graphrag._storage._base.logger")
91 | async def test_query_done_in_insert_mode(self, mock_logger):
92 | self.storage._mode = "insert"
93 | self.storage._in_progress = True
94 |
95 | await self.storage.query_done()
96 | mock_logger.error.assert_called_once()
97 |
98 | if __name__ == "__main__":
99 | unittest.main()
100 |
--------------------------------------------------------------------------------
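The state machine these tests pin down: `BaseStorage` tracks a `_mode` of `None`, `"insert"`, or `"query"`, and starting one mode while the other is still in progress first finalizes the stale mode and logs an error. A sketch replaying the first test's scenario outside the test class, with the private hooks patched exactly as the tests do:

```python
import asyncio
from unittest.mock import AsyncMock, patch

from fast_graphrag._storage._base import BaseStorage


async def demo() -> None:
    storage = BaseStorage(config=None)
    with patch.object(BaseStorage, "_insert_start", new_callable=AsyncMock), \
         patch.object(BaseStorage, "_query_done", new_callable=AsyncMock):
        storage._mode = "query"       # pretend a query was left in progress
        storage._in_progress = True
        await storage.insert_start()  # finalizes the query, logs an error,
                                      # then switches to insert mode
        assert storage._mode == "insert"
        assert storage._in_progress is False


asyncio.run(demo())
```
--------------------------------------------------------------------------------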
/tests/_storage/_blob_pickle_test.py:
--------------------------------------------------------------------------------
1 | # type: ignore
2 | import pickle
3 | import unittest
4 | from unittest.mock import MagicMock, mock_open, patch
5 |
6 | from fast_graphrag._exceptions import InvalidStorageError
7 | from fast_graphrag._storage._blob_pickle import PickleBlobStorage
8 |
9 |
10 | class TestPickleBlobStorage(unittest.IsolatedAsyncioTestCase):
11 | async def asyncSetUp(self):
12 | self.namespace = MagicMock()
13 | self.namespace.get_load_path.return_value = "blob_data.pkl"
14 | self.namespace.get_save_path.return_value = "blob_data.pkl"
15 | self.storage = PickleBlobStorage(namespace=self.namespace, config=None)
16 |
17 | async def test_get(self):
18 | self.storage._data = {"key": "value"}
19 | result = await self.storage.get()
20 | self.assertEqual(result, {"key": "value"})
21 |
22 | async def test_set(self):
23 | blob = {"key": "value"}
24 | await self.storage.set(blob)
25 | self.assertEqual(self.storage._data, blob)
26 |
27 | @patch("builtins.open", new_callable=mock_open, read_data=pickle.dumps({"key": "value"}))
28 | async def test_insert_start_with_existing_file(self, mock_open):
29 | await self.storage._insert_start()
30 | self.assertEqual(self.storage._data, {"key": "value"})
31 | mock_open.assert_called_once_with("blob_data.pkl", "rb")
32 |
33 | @patch("os.path.exists", return_value=False)
34 | async def test_insert_start_without_existing_file(self, mock_exists):
35 | self.namespace.get_load_path.return_value = None
36 | await self.storage._insert_start()
37 | self.assertIsNone(self.storage._data)
38 |
39 | @patch("builtins.open", new_callable=mock_open)
40 | async def test_insert_done(self, mock_open):
41 | self.storage._data = {"key": "value"}
42 | await self.storage._insert_done()
43 | mock_open.assert_called_once_with("blob_data.pkl", "wb")
44 | mock_open().write.assert_called_once()
45 |
46 | @patch("builtins.open", new_callable=mock_open, read_data=pickle.dumps({"key": "value"}))
47 | async def test_query_start_with_existing_file(self, mock_open):
48 | await self.storage._query_start()
49 | self.assertEqual(self.storage._data, {"key": "value"})
50 | mock_open.assert_called_once_with("blob_data.pkl", "rb")
51 |
52 | @patch("fast_graphrag._storage._blob_pickle.logger")
53 | async def test_query_start_without_existing_file(self, mock_logger):
54 | self.namespace.get_load_path.return_value = None
55 | await self.storage._query_start()
56 | self.assertIsNone(self.storage._data)
57 | mock_logger.warning.assert_called_once()
58 |
59 | @patch("fast_graphrag._storage._blob_pickle.logger")
60 | async def test_insert_start_with_invalid_file(self, mock_logger):
61 | with self.assertRaises(InvalidStorageError):
62 | await self.storage._insert_start()
63 | mock_logger.error.assert_called_once()
64 |
65 | @patch("fast_graphrag._storage._blob_pickle.logger")
66 | async def test_query_start_with_invalid_file(self, mock_logger):
67 | with self.assertRaises(InvalidStorageError):
68 | await self.storage._query_start()
69 | mock_logger.error.assert_called_once()
70 |
71 | async def test_query_done(self):
72 | await self.storage._query_done() # Should not raise any exceptions
73 |
74 |
75 | if __name__ == "__main__":
76 | unittest.main()
77 |
--------------------------------------------------------------------------------
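As the tests show, `PickleBlobStorage` keeps its blob in memory: `set()` and `get()` never touch disk, loading happens only when the namespace resolves a load path, and persistence is deferred to `_insert_done()`. A volatile-usage sketch mirroring the test fixture:

```python
import asyncio
from unittest.mock import MagicMock

from fast_graphrag._storage._blob_pickle import PickleBlobStorage


async def demo() -> None:
    namespace = MagicMock()
    namespace.get_load_path.return_value = None  # nothing on disk to load
    storage = PickleBlobStorage(namespace=namespace, config=None)

    await storage.set({"key": "value"})          # held in memory only
    assert await storage.get() == {"key": "value"}


asyncio.run(demo())
```
--------------------------------------------------------------------------------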
/tests/_storage/_gdb_igraph_test.py:
--------------------------------------------------------------------------------
1 | # type: ignore
2 |
3 | import unittest
4 | from unittest.mock import AsyncMock, MagicMock, patch
5 |
6 | import numpy as np
7 |
8 | from fast_graphrag._storage._gdb_igraph import IGraphStorage, IGraphStorageConfig
9 | from fast_graphrag._types import TEntity, TRelation
10 |
11 |
12 | class TestIGraphStorage(unittest.IsolatedAsyncioTestCase):
13 | def setUp(self):
14 | self.config = IGraphStorageConfig(node_cls=TEntity, edge_cls=TRelation)
15 | self.storage = IGraphStorage(config=self.config)
16 | self.storage._graph = MagicMock()
17 |
18 | async def test_node_count(self):
19 | self.storage._graph.vcount.return_value = 10
20 | count = await self.storage.node_count()
21 | self.assertEqual(count, 10)
22 |
23 | async def test_edge_count(self):
24 | self.storage._graph.ecount.return_value = 20
25 | count = await self.storage.edge_count()
26 | self.assertEqual(count, 20)
27 |
28 | async def test_get_node(self):
29 | node = MagicMock()
30 | node.name = "node1"
31 | node.attributes.return_value = {"name": "foo", "description": "value", "type": ""}
32 | self.storage._graph.vs.find.return_value = node
33 |
34 | result = await self.storage.get_node("node1")
35 | self.assertEqual(result, (TEntity(**node.attributes()), node.index))
36 |
37 | async def test_get_node_not_found(self):
38 | self.storage._graph.vs.find.side_effect = ValueError
39 | result = await self.storage.get_node("node1")
40 | self.assertEqual(result, (None, None))
41 |
42 | async def test_get_edges(self):
43 | self.storage._get_edge_indices = AsyncMock(return_value=[0, 1])
44 | self.storage.get_edge_by_index = AsyncMock(
45 | side_effect=[TRelation(source="node1", target="node2", description="txt"), None]
46 | )
47 |
48 | edges = await self.storage.get_edges("node1", "node2")
49 | self.assertEqual(edges, [(TRelation(source="node1", target="node2", description="txt"), 0)])
50 |
51 | async def test_get_edge_indices(self):
52 | self.storage._graph.vs.find.side_effect = lambda name: MagicMock(index=name)
53 | self.storage._graph.es.select.return_value = [MagicMock(index=0), MagicMock(index=1)]
54 |
55 | indices = await self.storage._get_edge_indices("node1", "node2")
56 | self.assertEqual(list(indices), [0, 1])
57 |
58 | async def test_get_node_by_index(self):
59 | node = MagicMock()
60 | node.attributes.return_value = {"name": "foo", "description": "value", "type": "type"}
61 | self.storage._graph.vs.__getitem__.return_value = node
62 | self.storage._graph.vcount.return_value = 1
63 |
64 | result = await self.storage.get_node_by_index(0)
65 | self.assertEqual(result, TEntity(**node.attributes()))
66 |
67 | async def test_get_edge_by_index(self):
68 | edge = MagicMock()
69 | edge.source = "node0"
70 | edge.target = "node1"
71 | edge.attributes.return_value = {"description": "value"}
72 | self.storage._graph.es.__getitem__.return_value = edge
73 | self.storage._graph.vs.__getitem__.side_effect = lambda idx: {"name": idx}
74 | self.storage._graph.ecount.return_value = 1
75 |
76 | result = await self.storage.get_edge_by_index(0)
77 | self.assertEqual(result, TRelation(source="node0", target="node1", **edge.attributes()))
78 |
79 | async def test_upsert_node(self):
80 | node = TEntity(name="node1", description="value", type="type")
81 | self.storage._graph.vcount.return_value = 1
82 | self.storage._graph.vs.__getitem__.return_value = MagicMock(index=0)
83 |
84 | index = await self.storage.upsert_node(node, 0)
85 | self.assertEqual(index, 0)
86 |
87 | async def test_upsert_edge(self):
88 | edge = TRelation(source="node1", target="node2", description="desc", chunks=[])
89 | self.storage._graph.ecount.return_value = 1
90 | self.storage._graph.es.__getitem__.return_value = MagicMock(index=0)
91 |
92 | index = await self.storage.upsert_edge(edge, 0)
93 | self.assertEqual(index, 0)
94 |
95 | async def test_delete_edges_by_index(self):
96 | self.storage._graph.delete_edges = MagicMock()
97 | indices = [0, 1]
98 | await self.storage.delete_edges_by_index(indices)
99 | self.storage._graph.delete_edges.assert_called_with(indices)
100 |
101 | @patch("fast_graphrag._storage._gdb_igraph.logger")
102 | async def test_score_nodes_empty_graph(self, mock_logger):
103 | self.storage._graph.vcount.return_value = 0
104 | scores = await self.storage.score_nodes(None)
105 | self.assertEqual(scores.shape, (1, 0))
106 | mock_logger.info.assert_called_with("Trying to score nodes in an empty graph.")
107 |
108 | async def test_score_nodes(self):
109 | self.storage._graph.vcount.return_value = 3
110 | self.storage._graph.personalized_pagerank.return_value = [0.1, 0.2, 0.7]
111 |
112 | scores = await self.storage.score_nodes(None)
113 | self.assertTrue(np.array_equal(scores.toarray(), np.array([[0.1, 0.2, 0.7]], dtype=np.float32)))
114 |
115 | async def test_get_entities_to_relationships_map_empty_graph(self):
116 | self.storage._graph.vs = []
117 | result = await self.storage.get_entities_to_relationships_map()
118 | self.assertEqual(result.shape, (0, 0))
119 |
120 | @patch("fast_graphrag._storage._gdb_igraph.csr_from_indices_list")
121 | async def test_get_entities_to_relationships_map(self, mock_csr_from_indices_list):
122 | self.storage._graph.vs = [MagicMock(incident=lambda: [MagicMock(index=0), MagicMock(index=1)])]
123 | self.storage.node_count = AsyncMock(return_value=1)
124 | self.storage.edge_count = AsyncMock(return_value=2)
125 |
126 | await self.storage.get_entities_to_relationships_map()
127 | mock_csr_from_indices_list.assert_called_with([[0, 1]], shape=(1, 2))
128 |
129 | async def test_get_relationships_attrs_empty_graph(self):
130 | self.storage._graph.es = []
131 | result = await self.storage.get_relationships_attrs("key")
132 | self.assertEqual(result, [])
133 |
134 | async def test_get_relationships_attrs(self):
135 | self.storage._graph.es.__getitem__.return_value = [[1, 2], [3, 4]]
136 | self.storage._graph.es.__len__.return_value = 2
137 | result = await self.storage.get_relationships_attrs("key")
138 | self.assertEqual(result, [[1, 2], [3, 4]])
139 |
140 | @patch("igraph.Graph.Read_Picklez")
141 | @patch("fast_graphrag._storage._gdb_igraph.logger")
142 | async def test_insert_start_with_existing_file(self, mock_logger, mock_read_picklez):
143 | self.storage.namespace = MagicMock()
144 | self.storage.namespace.get_load_path.return_value = "dummy_path"
145 |
146 | await self.storage._insert_start()
147 |
148 | mock_read_picklez.assert_called_with("dummy_path")
149 | mock_logger.debug.assert_called_with("Loaded graph storage 'dummy_path'.")
150 |
151 | @patch("igraph.Graph")
152 | @patch("fast_graphrag._storage._gdb_igraph.logger")
153 | async def test_insert_start_with_no_file(self, mock_logger, mock_graph):
154 | self.storage.namespace = MagicMock()
155 | self.storage.namespace.get_load_path.return_value = None
156 |
157 | await self.storage._insert_start()
158 |
159 | mock_graph.assert_called_with(directed=False)
160 | mock_logger.info.assert_called_with("No data file found for graph storage 'None'. Loading empty graph.")
161 |
162 | @patch("igraph.Graph")
163 | @patch("fast_graphrag._storage._gdb_igraph.logger")
164 | async def test_insert_start_with_no_namespace(self, mock_logger, mock_graph):
165 | self.storage.namespace = None
166 |
167 | await self.storage._insert_start()
168 |
169 | mock_graph.assert_called_with(directed=False)
170 | mock_logger.debug.assert_called_with("Creating new volatile graphdb storage.")
171 |
172 | @patch("igraph.Graph.write_picklez")
173 | @patch("fast_graphrag._storage._gdb_igraph.logger")
174 | async def test_insert_done(self, mock_logger, mock_write_picklez):
175 | self.storage.namespace = MagicMock()
176 | self.storage.namespace.get_save_path.return_value = "dummy_path"
177 |
178 | await self.storage._insert_done()
179 |
180 | mock_write_picklez.assert_called_with(self.storage._graph, "dummy_path")
181 |
182 | @patch("igraph.Graph.Read_Picklez")
183 | @patch("fast_graphrag._storage._gdb_igraph.logger")
184 | async def test_query_start_with_existing_file(self, mock_logger, mock_read_picklez):
185 | self.storage.namespace = MagicMock()
186 | self.storage.namespace.get_load_path.return_value = "dummy_path"
187 |
188 | await self.storage._query_start()
189 |
190 | mock_read_picklez.assert_called_with("dummy_path")
191 | mock_logger.debug.assert_called_with("Loaded graph storage 'dummy_path'.")
192 |
193 | @patch("igraph.Graph")
194 | @patch("fast_graphrag._storage._gdb_igraph.logger")
195 | async def test_query_start_with_no_file(self, mock_logger, mock_graph):
196 | self.storage.namespace = MagicMock()
197 | self.storage.namespace.get_load_path.return_value = None
198 |
199 | await self.storage._query_start()
200 |
201 | mock_graph.assert_called_with(directed=False)
202 | mock_logger.warning.assert_called_with(
203 | "No data file found for graph storage 'None'. Loading empty graph."
204 | )
205 |
206 |
207 | if __name__ == "__main__":
208 | unittest.main()
209 |
--------------------------------------------------------------------------------
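The `_insert_start` tests document three boot paths for `IGraphStorage`: load a pickled graph when the namespace resolves a file, fall back to an empty undirected graph when it does not, and create a volatile graph when there is no namespace at all. A sketch of the volatile path, calling the private hook directly as the tests do:

```python
import asyncio

from fast_graphrag._storage._gdb_igraph import IGraphStorage, IGraphStorageConfig
from fast_graphrag._types import TEntity, TRelation


async def demo() -> None:
    storage = IGraphStorage(config=IGraphStorageConfig(node_cls=TEntity, edge_cls=TRelation))
    storage.namespace = None       # volatile: no files are read or written
    await storage._insert_start()  # creates an empty igraph.Graph(directed=False)

    print(await storage.node_count())  # 0
    print(await storage.edge_count())  # 0


asyncio.run(demo())
```
--------------------------------------------------------------------------------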
/tests/_storage/_ikv_pickle_test.py:
--------------------------------------------------------------------------------
1 | # type: ignore
2 | import pickle
3 | import unittest
4 | from unittest.mock import MagicMock, mock_open, patch
5 |
6 | import numpy as np
7 |
8 | from fast_graphrag._exceptions import InvalidStorageError
9 | from fast_graphrag._storage._ikv_pickle import PickleIndexedKeyValueStorage
10 |
11 |
12 | class TestPickleIndexedKeyValueStorage(unittest.IsolatedAsyncioTestCase):
13 | async def asyncSetUp(self):
14 | self.storage = PickleIndexedKeyValueStorage(namespace=None, config=None)
15 | await self.storage._insert_start()
16 |
17 | async def test_size(self):
18 | self.storage._data = {1: "value1", 2: "value2"}
19 | size = await self.storage.size()
20 | self.assertEqual(size, 2)
21 |
22 | async def test_get(self):
23 | self.storage._data = {1: "value1", 2: "value2"}
24 | self.storage._key_to_index = {"key1": 1, "key2": 2}
25 | result = await self.storage.get(["key1", "key2", "key3"])
26 | self.assertEqual(list(result), ["value1", "value2", None])
27 |
28 | async def test_get_by_index(self):
29 | self.storage._data = {1: "value1", 2: "value2"}
30 | result = await self.storage.get_by_index([1, 2, 3])
31 | self.assertEqual(list(result), ["value1", "value2", None])
32 |
33 | async def test_get_index(self):
34 | self.storage._key_to_index = {"key1": 1, "key2": 2}
35 | result = await self.storage.get_index(["key1", "key2", "key3"])
36 | self.assertEqual(list(result), [1, 2, None])
37 |
38 | async def test_upsert(self):
39 | await self.storage.upsert(["key1", "key2"], ["value1", "value2"])
40 | self.assertEqual(self.storage._data, {0: "value1", 1: "value2"})
41 | self.assertEqual(self.storage._key_to_index, {"key1": 0, "key2": 1})
42 |
43 | async def test_delete(self):
44 | self.storage._data = {0: "value1", 1: "value2"}
45 | self.storage._key_to_index = {"key1": 0, "key2": 1}
46 | await self.storage.delete(["key1"])
47 | self.assertEqual(self.storage._data, {1: "value2"})
48 | self.assertEqual(self.storage._key_to_index, {"key2": 1})
49 | self.assertEqual(self.storage._free_indices, [0])
50 |
51 | async def test_mask_new(self):
52 | self.storage._key_to_index = {"key1": 0, "key2": 1}
53 | result = await self.storage.mask_new([["key1", "key3"]])
54 | self.assertTrue(np.array_equal(result, [[False, True]]))
55 |
56 | @patch("builtins.open", new_callable=mock_open, read_data=pickle.dumps(({0: "value"}, [1, 2, 3], {"key": 0})))
57 | @patch("os.path.exists", return_value=True)
58 | @patch("fast_graphrag._storage._ikv_pickle.logger")
59 | async def test_insert_start_with_existing_file(self, mock_logger, mock_exists, mock_open):
60 | self.storage.namespace = MagicMock()
61 | self.storage.namespace.get_load_path.return_value = "dummy_path"
62 |
63 | # Call the function
64 | await self.storage._insert_start()
65 |
66 | # Check if data was loaded correctly
67 | self.assertEqual(self.storage._data, {0: "value"})
68 | self.assertEqual(self.storage._free_indices, [1, 2, 3])
69 | self.assertEqual(self.storage._key_to_index, {"key": 0})
70 | mock_logger.debug.assert_called_once()
71 |
72 | @patch("fast_graphrag._storage._ikv_pickle.logger")
73 | async def test_insert_start_with_no_file(self, mock_logger):
74 | self.storage.namespace = MagicMock()
75 | self.storage.namespace.get_load_path.return_value = "dummy_path"
76 |
77 | # Call the function
78 | with self.assertRaises(InvalidStorageError):
79 | await self.storage._insert_start()
80 |
81 | @patch("fast_graphrag._storage._ikv_pickle.logger")
82 | async def test_insert_start_with_no_namespace(self, mock_logger):
83 | self.storage.namespace = None
84 |
85 | # Call the function
86 | await self.storage._insert_start()
87 |
88 | # Check if data was initialized correctly
89 | self.assertEqual(self.storage._data, {})
90 | self.assertEqual(self.storage._free_indices, [])
91 | mock_logger.debug.assert_called_with("Creating new volatile indexed key-value storage.")
92 |
93 | @patch("builtins.open", new_callable=mock_open)
94 | @patch("fast_graphrag._storage._ikv_pickle.logger")
95 | async def test_insert_done(self, mock_logger, mock_open):
96 | self.storage.namespace = MagicMock()
97 | self.storage.namespace.get_save_path.return_value = "dummy_path"
98 | self.storage._data = {0: "value"}
99 | self.storage._free_indices = [1, 2, 3]
100 | self.storage._key_to_index = {"key": 0}
101 |
102 | # Call the function
103 | await self.storage._insert_done()
104 |
105 | # Check if data was saved correctly
106 | mock_open.assert_called_with("dummy_path", "wb")
107 | mock_logger.debug.assert_called_with("Saving 1 elements to indexed key-value storage 'dummy_path'.")
108 |
109 | @patch("builtins.open", new_callable=mock_open, read_data=pickle.dumps(({0: "value"}, [1, 2, 3], {"key": 0})))
110 | @patch("os.path.exists", return_value=True)
111 | @patch("fast_graphrag._storage._ikv_pickle.logger")
112 | async def test_query_start_with_existing_file(self, mock_logger, mock_exists, mock_open):
113 | self.storage.namespace = MagicMock()
114 | self.storage.namespace.get_load_path.return_value = "dummy_path"
115 |
116 | # Call the function
117 | await self.storage._query_start()
118 |
119 | # Check if data was loaded correctly
120 | self.assertEqual(self.storage._data, {0: "value"})
121 | self.assertEqual(self.storage._free_indices, [1, 2, 3])
122 | self.assertEqual(self.storage._key_to_index, {"key": 0})
123 | mock_logger.debug.assert_called_with("Loaded 1 elements from indexed key-value storage 'dummy_path'.")
124 |
125 | @patch("os.path.exists", return_value=False)
126 | @patch("fast_graphrag._storage._ikv_pickle.logger")
127 | async def test_query_start_with_no_file(self, mock_logger, mock_exists):
128 | self.storage.namespace = MagicMock()
129 | self.storage.namespace.get_load_path.return_value = None
130 |
131 | # Call the function
132 | await self.storage._query_start()
133 |
134 | # Check if data was initialized correctly
135 | self.assertEqual(self.storage._data, {})
136 | self.assertEqual(self.storage._free_indices, [])
137 | mock_logger.warning.assert_called_with(
138 | "No data file found for key-vector storage 'None'. Loading empty storage."
139 | )
140 |
141 |
142 | if __name__ == "__main__":
143 | unittest.main()
144 |
--------------------------------------------------------------------------------
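A sketch of the indexed key-value lifecycle these tests walk through: `upsert` assigns consecutive integer indices, `delete` returns an index to the free list for reuse, and `mask_new` flags which keys are not yet stored. With no namespace the store is purely in-memory:

```python
import asyncio

from fast_graphrag._storage._ikv_pickle import PickleIndexedKeyValueStorage


async def demo() -> None:
    storage = PickleIndexedKeyValueStorage(namespace=None, config=None)
    await storage._insert_start()  # volatile, starts empty

    await storage.upsert(["key1", "key2"], ["value1", "value2"])
    print(list(await storage.get(["key1", "missing"])))  # ['value1', None]

    await storage.delete(["key1"])
    print(storage._free_indices)                       # [0] -- reusable slot
    print(await storage.mask_new([["key1", "key2"]]))  # [[ True False]]


asyncio.run(demo())
```
--------------------------------------------------------------------------------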
/tests/_storage/_namespace_test.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import os
3 | import shutil
4 | import unittest
5 | from typing import cast
6 |
7 | from fast_graphrag._exceptions import InvalidStorageError
8 | from fast_graphrag._storage._namespace import Namespace, Workspace
9 |
10 |
11 | class TestWorkspace(unittest.IsolatedAsyncioTestCase):
12 | def setUp(self):
13 | def _(self: Workspace) -> None:
14 | pass
15 |
16 | Workspace.__del__ = _
17 | self.test_dir = "test_workspace"
18 | self.workspace = Workspace(self.test_dir)
19 |
20 | def tearDown(self):
21 | if os.path.exists(self.test_dir):
22 | shutil.rmtree(self.test_dir)
23 |
24 | def test_new_workspace(self):
25 | ws = Workspace.new(self.test_dir)
26 | self.assertIsInstance(ws, Workspace)
27 | self.assertEqual(ws.working_dir, self.test_dir)
28 |
29 | def test_get_load_path_no_checkpoint(self):
30 | self.assertEqual(self.workspace.get_load_path(), None)
31 |
32 | def test_get_save_path_creates_directory(self):
33 | save_path = self.workspace.get_save_path()
34 | self.assertTrue(os.path.exists(save_path))
35 |
36 | async def test_with_checkpoint_failures(self):
37 | for checkpoint in [1, 2, 3]:
38 | os.makedirs(os.path.join(self.test_dir, str(checkpoint)))
39 | self.workspace = Workspace(self.test_dir)
40 |
41 | async def sample_fn():
42 | if "1" not in cast(str, self.workspace.get_load_path()):
43 | raise Exception("Checkpoint not loaded")
44 | return "success"
45 |
46 | result = await self.workspace.with_checkpoints(sample_fn)
47 | self.assertEqual(result, "success")
48 | self.assertEqual(self.workspace.current_load_checkpoint, 1)
49 | self.assertEqual(self.workspace.failed_checkpoints, ["3", "2"])
50 |
51 | async def test_with_checkpoint_no_failure(self):
52 | for checkpoint in [1, 2, 3]:
53 | os.makedirs(os.path.join(self.test_dir, str(checkpoint)))
54 | self.workspace = Workspace(self.test_dir)
55 |
56 | async def sample_fn():
57 | return "success"
58 |
59 | result = await self.workspace.with_checkpoints(sample_fn)
60 | self.assertEqual(result, "success")
61 | self.assertEqual(self.workspace.current_load_checkpoint, 3)
62 | self.assertEqual(self.workspace.failed_checkpoints, [])
63 |
64 | async def test_with_checkpoint_all_failures(self):
65 | for checkpoint in [1, 2, 3]:
66 | os.makedirs(os.path.join(self.test_dir, str(checkpoint)))
67 | self.workspace = Workspace(self.test_dir)
68 |
69 | async def sample_fn():
70 | raise Exception("Checkpoint not loaded")
71 |
72 | with self.assertRaises(InvalidStorageError):
73 | await self.workspace.with_checkpoints(sample_fn)
74 | self.assertEqual(self.workspace.current_load_checkpoint, None)
75 | self.assertEqual(self.workspace.failed_checkpoints, ["3", "2", "1"])
76 |
77 | async def test_with_checkpoint_all_failures_accept_none(self):
78 | for checkpoint in [1, 2, 3]:
79 | os.makedirs(os.path.join(self.test_dir, str(checkpoint)))
80 | self.workspace = Workspace(self.test_dir)
81 |
82 | async def sample_fn():
83 | if self.workspace.get_load_path() is not None:
84 | raise Exception("Checkpoint not loaded")
85 |
86 | result = await self.workspace.with_checkpoints(sample_fn)
87 | self.assertEqual(result, None)
88 | self.assertEqual(self.workspace.current_load_checkpoint, None)
89 | self.assertEqual(self.workspace.failed_checkpoints, ["3", "2", "1"])
90 |
91 |
92 | class TestNamespace(unittest.TestCase):
93 | def setUp(self):
94 | def _(self: Workspace) -> None:
95 | pass
96 |
97 | Workspace.__del__ = _
98 |
99 | def test_get_load_path_no_checkpoint_no_file(self):
100 | self.test_dir = "test_workspace"
101 | self.workspace = Workspace(self.test_dir)
102 | self.workspace.__del__ = lambda: None
103 | self.namespace = Namespace(self.workspace, "test_namespace")
104 | self.assertEqual(None, self.namespace.get_load_path("resource"))
105 | del self.workspace
106 | gc.collect()
107 | if os.path.exists(self.test_dir):
108 | shutil.rmtree(self.test_dir)
109 |
110 | def test_get_load_path_no_checkpoint_with_file(self):
111 | self.test_dir = "test_workspace"
112 | self.workspace = Workspace(self.test_dir)
113 | self.workspace.__del__ = lambda: None
114 | self.namespace = Namespace(self.workspace, "test_namespace")
115 |
116 | with open(os.path.join(self.test_dir, "test_namespace_resource"), "w") as f:
117 | f.write("test")
118 | self.assertEqual(
119 | os.path.join("test_workspace", "test_namespace_resource"), self.namespace.get_load_path("resource")
120 | )
121 | del self.workspace
122 | gc.collect()
123 | if os.path.exists(self.test_dir):
124 | shutil.rmtree(self.test_dir)
125 |
126 | def test_get_load_path_with_checkpoint(self):
127 | self.test_dir = "test_workspace"
128 | self.workspace = Workspace(self.test_dir)
129 | self.namespace = Namespace(self.workspace, "test_namespace")
130 | self.workspace.current_load_checkpoint = 1
131 | load_path = self.namespace.get_load_path("resource")
132 | self.assertEqual(load_path, os.path.join(self.test_dir, "1", "test_namespace_resource"))
133 |
134 | del self.workspace
135 | gc.collect()
136 | if os.path.exists(self.test_dir):
137 | shutil.rmtree(self.test_dir)
138 |
139 | def test_get_save_path_creates_directory(self):
140 | self.test_dir = "test_workspace"
141 | self.workspace = Workspace(self.test_dir)
142 | self.namespace = Namespace(self.workspace, "test_namespace")
143 | save_path = self.namespace.get_save_path("resource")
144 | self.assertTrue(os.path.exists(os.path.join(*os.path.split(save_path)[:-1])))
145 |
146 | del self.workspace
147 | gc.collect()
148 | if os.path.exists(self.test_dir):
149 | shutil.rmtree(self.test_dir)
150 |
151 |
152 | if __name__ == "__main__":
153 | unittest.main()
154 |
--------------------------------------------------------------------------------
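These tests also explain why both fixtures neuter `Workspace.__del__` in `setUp`: the destructor has filesystem side effects, so disabling it keeps teardown deterministic. Outside of tests, the basic path resolution looks like the sketch below; file names follow the `<namespace>_<resource>` pattern asserted above, and the exact save location depends on checkpoint state:

```python
import shutil

from fast_graphrag._storage._namespace import Namespace, Workspace

workspace = Workspace.new("demo_workspace")
namespace = Namespace(workspace, "graph")

# Nothing has been saved and no checkpoint exists, so there is no load path.
print(namespace.get_load_path("data"))  # None

# get_save_path ensures the target directory exists before returning it,
# e.g. demo_workspace/graph_data when no checkpoints are in play.
print(namespace.get_save_path("data"))

shutil.rmtree("demo_workspace", ignore_errors=True)
```
--------------------------------------------------------------------------------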
/tests/_types_test.py:
--------------------------------------------------------------------------------
1 | # type: ignore
2 | import re
3 | import unittest
4 | from dataclasses import asdict
5 |
6 | from fast_graphrag._types import (
7 | TChunk,
8 | TContext,
9 | TDocument,
10 | TEntity,
11 | TGraph,
12 | TQueryResponse,
13 | TRelation,
14 | TScore,
15 | )
16 |
17 |
18 | class TestTypes(unittest.TestCase):
19 | def test_tdocument(self):
20 | doc = TDocument(data="Sample data", metadata={"key": "value"})
21 | self.assertEqual(doc.data, "Sample data")
22 | self.assertEqual(doc.metadata, {"key": "value"})
23 |
24 | def test_tchunk(self):
25 | chunk = TChunk(id=123, content="Sample content", metadata={"key": "value"})
26 | self.assertEqual(chunk.id, 123)
27 | self.assertEqual(chunk.content, "Sample content")
28 | self.assertEqual(chunk.metadata, {"key": "value"})
29 |
30 | def test_tentity(self):
31 | entity = TEntity(name="Entity1", type="Type1", description="Description1")
32 | self.assertEqual(entity.name, "Entity1")
33 | self.assertEqual(entity.type, "Type1")
34 | self.assertEqual(entity.description, "Description1")
35 |
36 | pydantic_entity = TEntity.Model(name="Entity1", type="Type1", desc="Description1")
37 | entity.name = entity.name.upper()
38 | entity.type = entity.type.upper()
39 | self.assertEqual(asdict(entity), asdict(pydantic_entity.to_dataclass(pydantic_entity)))
40 |
41 | def test_trelation(self):
42 | relation = TRelation(source="Entity1", target="Entity2", description="Relation description")
43 | self.assertEqual(relation.source, "Entity1")
44 | self.assertEqual(relation.target, "Entity2")
45 | self.assertEqual(relation.description, "Relation description")
46 |
47 | pydantic_relation = TRelation.Model(source="Entity1", target="Entity2", desc="Relation description")
48 |
49 | relation.source = relation.source.upper()
50 | relation.target = relation.target.upper()
51 | self.assertEqual(asdict(relation), asdict(pydantic_relation.to_dataclass(pydantic_relation)))
52 |
53 | def test_tgraph(self):
54 | entity = TEntity(name="Entity1", type="Type1", description="Description1")
55 | relation = TRelation(source="Entity1", target="Entity2", description="Relation description")
56 | graph = TGraph(entities=[entity], relationships=[relation])
57 | self.assertEqual(graph.entities, [entity])
58 | self.assertEqual(graph.relationships, [relation])
59 |
60 | pydantic_graph = TGraph.Model(
61 | entities=[TEntity.Model(name="Entity1", type="Type1", desc="Description1")],
62 | relationships=[TRelation.Model(source="Entity1", target="Entity2", desc="Relation description")],
63 | other_relationships=[],
64 | )
65 |
66 | for entity in graph.entities:
67 | entity.name = entity.name.upper()
68 | entity.type = entity.type.upper()
69 | for relation in graph.relationships:
70 | relation.source = relation.source.upper()
71 | relation.target = relation.target.upper()
72 | self.assertEqual(asdict(graph), asdict(pydantic_graph.to_dataclass(pydantic_graph)))
73 |
74 | def test_tcontext(self):
75 | entities = [TEntity(name="Entity1", type="Type1", description="Sample description 1")] * 8 + [
76 | TEntity(name="Entity2", type="Type2", description="Sample description 2")
77 | ] * 8
78 | relationships = [
79 | TRelation(source="Entity1", target="Entity2", description="Relation description 12")
80 | ] * 8 + [
81 | TRelation(source="Entity2", target="Entity1", description="Relation description 21")
82 | ] * 8
83 | chunks = [
84 | TChunk(id=i, content=f"Long and repeated chunk content {i}" * 4, metadata={"key": f"value {i}"})
85 | for i in range(16)
86 | ]
87 |
88 | for r, c in zip(relationships, chunks):
89 | r.chunks = [c.id]
90 | context = TContext(
91 | entities=[(e, TScore(0.9)) for e in entities],
92 | relations=[(r, TScore(0.8)) for r in relationships],
93 | chunks=[(c, TScore(0.7)) for c in chunks],
94 | )
95 | max_chars = {"entities": 128, "relations": 128, "chunks": 512}
96 | csv = context.truncate(max_chars.copy(), True)
97 |
98 | csv_entities = re.findall(r"## Entities\n```csv\n(.*?)\n```", csv, re.DOTALL)
99 | csv_relationships = re.findall(r"## Relationships\n```csv\n(.*?)\n```", csv, re.DOTALL)
100 | csv_chunks = re.findall(r"## Sources\n.*=====", csv, re.DOTALL)
101 |
102 | self.assertEqual(len(csv_entities), 1)
103 | self.assertEqual(len(csv_relationships), 1)
104 | self.assertEqual(len(csv_chunks), 1)
105 |
106 | self.assertGreaterEqual(
107 | sum(max_chars.values()) + 16, len(csv_entities[0]) + len(csv_relationships[0]) + len(csv_chunks[0])
108 | )
109 |
110 | def test_tqueryresponse(self):
111 | context = TContext(
112 | entities=[("Entity1", TScore(0.9))],
113 | relations=[("Relation1", TScore(0.8))],
114 | chunks=[("Chunk1", TScore(0.7))],
115 | )
116 | query_response = TQueryResponse(response="Sample response", context=context)
117 | self.assertEqual(query_response.response, "Sample response")
118 | self.assertEqual(query_response.context, context)
119 |
120 |
121 | if __name__ == "__main__":
122 | unittest.main()
123 |
--------------------------------------------------------------------------------
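`test_tentity` and its siblings pin down the `Model` round-trip: the pydantic-side `desc` field maps back to the dataclass's `description`, and `to_dataclass` uppercases entity names and types. A minimal sketch of that conversion:

```python
from dataclasses import asdict

from fast_graphrag._types import TEntity

model = TEntity.Model(name="Entity1", type="Type1", desc="Description1")
entity = model.to_dataclass(model)

# Matches the assertions above: name/type come back uppercased and the
# pydantic 'desc' field is stored as 'description' on the dataclass.
print(asdict(entity))  # {'name': 'ENTITY1', 'type': 'TYPE1', 'description': 'Description1', ...}
```
--------------------------------------------------------------------------------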
/tests/_utils_test.py:
--------------------------------------------------------------------------------
1 | # tests/_utils_test.py
2 |
3 | import asyncio
4 | import threading
5 | import unittest
6 | from typing import List
7 |
8 | import numpy as np
9 | from scipy.sparse import csr_matrix
10 |
11 | from fast_graphrag._utils import csr_from_indices_list, extract_sorted_scores, get_event_loop
12 |
13 |
14 | class TestGetEventLoop(unittest.TestCase):
15 | def test_get_existing_event_loop(self):
16 | loop = asyncio.new_event_loop()
17 | asyncio.set_event_loop(loop)
18 | self.assertEqual(get_event_loop(), loop)
19 | loop.close()
20 |
21 | def test_get_event_loop_in_sub_thread(self):
22 | def target():
23 | loop = get_event_loop()
24 | self.assertIsInstance(loop, asyncio.AbstractEventLoop)
25 | loop.close()
26 |
27 | thread = threading.Thread(target=target)
28 | thread.start()
29 | thread.join()
30 |
31 |
32 | # Not checked: the expectations below mirror observed behavior and have not been independently verified
33 | class TestExtractSortedScores(unittest.TestCase):
34 | def test_non_zero_elements(self):
35 | row_vector = csr_matrix([[0, 0.1, 0, 0.7, 0.5, 0]])
36 | indices, scores = extract_sorted_scores(row_vector)
37 | np.testing.assert_array_equal(indices, np.array([3, 4, 1]))
38 | np.testing.assert_array_equal(scores, np.array([0.7, 0.5, 0.1]))
39 |
40 | def test_empty(self):
41 | row_vector = csr_matrix((0, 0))
42 | indices, scores = extract_sorted_scores(row_vector)
43 | np.testing.assert_array_equal(indices, np.array([], dtype=np.int64))
44 | np.testing.assert_array_equal(scores, np.array([], dtype=np.float32))
45 |
46 | def test_empty_row_vector(self):
47 | row_vector = csr_matrix([[]])
48 | indices, scores = extract_sorted_scores(row_vector)
49 | np.testing.assert_array_equal(indices, np.array([], dtype=np.int64))
50 | np.testing.assert_array_equal(scores, np.array([], dtype=np.float32))
51 |
52 | def test_single_element(self):
53 | row_vector = csr_matrix([[0.5]])
54 | indices, scores = extract_sorted_scores(row_vector)
55 | np.testing.assert_array_equal(indices, np.array([0]))
56 | np.testing.assert_array_equal(scores, np.array([0.5]))
57 |
58 | def test_all_zero_elements(self):
59 | row_vector = csr_matrix([[0, 0, 0, 0, 0]])
60 | indices, scores = extract_sorted_scores(row_vector)
61 | np.testing.assert_array_equal(indices, np.array([], dtype=np.int64))
62 | np.testing.assert_array_equal(scores, np.array([], dtype=np.float32))
63 |
64 | def test_duplicate_elements(self):
65 | row_vector = csr_matrix([[0, 0.1, 0, 0.7, 0.5, 0.7]])
66 | indices, scores = extract_sorted_scores(row_vector)
67 | expected_indices_1 = np.array([5, 3, 4, 1])
68 | expected_indices_2 = np.array([3, 5, 4, 1])
69 | self.assertTrue(
70 | np.array_equal(indices, expected_indices_1) or np.array_equal(indices, expected_indices_2),
71 | f"indices {indices} do not match either {expected_indices_1} or {expected_indices_2}"
72 | )
73 | np.testing.assert_array_equal(scores, np.array([0.7, 0.7, 0.5, 0.1]))
74 |
75 |
76 | class TestCsrFromListOfLists(unittest.TestCase):
77 | def test_repeated_elements(self):
78 | data: List[List[int]] = [[0, 0], [], []]
79 | expected_matrix = csr_matrix(([1, 1, 0], ([0, 0, 0], [0, 0, 0])), shape=(3, 3))
80 | result_matrix = csr_from_indices_list(data, shape=(3, 3))
81 | np.testing.assert_array_equal(result_matrix.toarray(), expected_matrix.toarray())
82 |
83 | def test_non_zero_elements(self):
84 | data = [[0, 1, 2], [2, 3], [0, 3]]
85 | expected_matrix = csr_matrix([[1, 1, 1, 0, 0], [0, 0, 1, 1, 0], [1, 0, 0, 1, 0]], shape=(3, 5))
86 | result_matrix = csr_from_indices_list(data, shape=(3, 5))
87 | np.testing.assert_array_equal(result_matrix.toarray(), expected_matrix.toarray())
88 |
89 | def test_empty_list_of_lists(self):
90 | data: List[List[int]] = []
91 | expected_matrix = csr_matrix((0, 0))
92 | result_matrix = csr_from_indices_list(data, shape=(0, 0))
93 | np.testing.assert_array_equal(result_matrix.toarray(), expected_matrix.toarray())
94 |
95 | def test_empty_list_of_lists_with_unempty_shape(self):
96 | data: List[List[int]] = []
97 | expected_matrix = csr_matrix((1, 1))
98 | result_matrix = csr_from_indices_list(data, shape=(1, 1))
99 | np.testing.assert_array_equal(result_matrix.toarray(), expected_matrix.toarray())
100 |
101 | def test_list_with_empty_sublists(self):
102 | data: List[List[int]] = [[], [], []]
103 | expected_matrix = csr_matrix((3, 0))
104 | result_matrix = csr_from_indices_list(data, shape=(3, 0))
105 | np.testing.assert_array_equal(result_matrix.toarray(), expected_matrix.toarray())
106 |
107 |
108 | if __name__ == "__main__":
109 | unittest.main()
110 |
--------------------------------------------------------------------------------
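A closing sketch for the two sparse helpers exercised above: `csr_from_indices_list` builds a rows-by-columns indicator matrix from per-row column indices (with duplicate indices summed, per `test_repeated_elements`), and `extract_sorted_scores` returns the nonzero `(indices, scores)` of a row vector sorted by descending score:

```python
from scipy.sparse import csr_matrix

from fast_graphrag._utils import csr_from_indices_list, extract_sorted_scores

matrix = csr_from_indices_list([[0, 1, 2], [2, 3]], shape=(2, 5))
print(matrix.toarray())
# [[1 1 1 0 0]
#  [0 0 1 1 0]]

indices, scores = extract_sorted_scores(csr_matrix([[0, 0.1, 0, 0.7, 0.5, 0]]))
print(indices)  # [3 4 1]
print(scores)   # [0.7 0.5 0.1]
```
--------------------------------------------------------------------------------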