├── .flake8 ├── .github └── workflows │ └── main.yml ├── .gitignore ├── CHANGELOG.md ├── CONTRIBUTING.md ├── DS_Store ├── LICENCE.md ├── README.md ├── docs ├── api.md ├── index.md ├── installation.md ├── milvus.md ├── pinecone.md └── tutorial.md ├── examples ├── create_index.py ├── demo.ipynb ├── milvus_tutorial.py ├── qdrant │ ├── create_collection.py │ └── query.py └── querying.py ├── mkdocs.yml ├── pyproject.toml ├── src ├── DS_Store └── whyhow_rbr │ ├── __init__.py │ ├── embedding.py │ ├── exceptions.py │ ├── processing.py │ ├── rag.py │ ├── rag_milvus.py │ └── rag_qdrant.py └── tests ├── conftest.py ├── test_dummy.py ├── test_embedding.py ├── test_processing.py ├── test_qdrant_rag.py └── test_rag.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | docstring-convention = numpy 3 | max-line-length = 79 4 | ignore = 5 | # slice notation whitespace, invalid 6 | E203 7 | # import at top, too many circular import fixes 8 | E402 9 | # line length, handled by bugbear B950 10 | E501 11 | # bare except, handled by bugbear B001 12 | E722 13 | # bin op line break, invalid 14 | W503 15 | # bin op line break, invalid 16 | per-file-ignores = 17 | tests/*:D 18 | -------------------------------------------------------------------------------- /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | name: all 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: [main] 7 | 8 | jobs: 9 | build: 10 | 11 | runs-on: ${{ matrix.os }} 12 | strategy: 13 | matrix: 14 | os: [ubuntu-latest] 15 | python-version: ['3.10'] 16 | 17 | steps: 18 | - uses: actions/checkout@v2 19 | 20 | - name: Set up Python ${{ matrix.python-version }} 21 | uses: actions/setup-python@v2 22 | with: 23 | python-version: ${{ matrix.python-version }} 24 | 25 | - name: Install Python dependencies 26 | run: | 27 | python -m pip install --upgrade pip 28 | pip install -e .[dev] 29 | 30 | - name: Lint with flake8 31 | run: | 32 | flake8 src tests examples 33 | 34 | - name: Check style with black 35 | run: | 36 | black src tests examples 37 | 38 | - name: Run security check 39 | run: | 40 | bandit -qr -c pyproject.toml src examples 41 | 42 | - name: Run import check 43 | run: | 44 | isort --check src tests examples 45 | 46 | - name: Run mypy 47 | run: | 48 | mypy src 49 | 50 | - name: Test with pytest 51 | run: | 52 | pytest --color=yes 53 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | .idea/ 161 | .DS_Store 162 | 163 | data/ 164 | .python-version 165 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | All notable changes to this project will be documented in this file. 4 | 5 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), 6 | and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 7 | 8 | ## [Unreleased] 9 | 10 | ### Added 11 | - First version of the docs 12 | - Change default examples to go in the prompt 13 | - Implement `query` method with the possibility of adding rules 14 | - Implement `upload_documents` method 15 | - Implement `get_index`, `create_index`, `__init__` methods and custom exceptions 16 | - Minimal package structure + CI 17 | - Installation instructions in the README 18 | 19 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to Rule-based Retrieval 2 | 3 | Thank you for your interest in contributing to the Rule-based Retrieval! We welcome contributions from the community to help improve and expand the capabilities of this project. 4 | 5 | ## How to Contribute 6 | 7 | ### Reporting Issues 8 | 9 | If you encounter any bugs, have feature requests, or want to suggest improvements, please [open an issue](https://github.com/whyhow-ai/rule-based-retrieval/issues) on the GitHub repository. When creating an issue, please provide as much detail as possible, including steps to reproduce the problem, expected behavior, and any relevant code snippets or screenshots. 10 | 11 | ### Submitting Pull Requests 12 | 13 | We encourage you to submit pull requests for bug fixes, new features, or improvements to the existing codebase. To submit a pull request, please follow these steps: 14 | 15 | 1. Fork the repository and create a new branch for your feature or bug fix. 16 | 2. Make your changes in the new branch, ensuring that your code follows the project's coding style and conventions. 17 | 3. Write appropriate tests for your changes and ensure that all existing tests pass. 18 | 4. Update the documentation, including README.md and API references, if necessary. 19 | 5. Commit your changes with descriptive commit messages. 20 | 6. Push your changes to your forked repository. 21 | 7. Open a pull request from your branch to the main repository's `main` branch. 22 | 8. Provide a clear and detailed description of your changes in the pull request. 23 | 24 | ### Code Style and Conventions 25 | 26 | To maintain a consistent codebase, please adhere to the following guidelines: 27 | 28 | - Follow the [PEP 8](https://www.python.org/dev/peps/pep-0008/) style guide for Python code. 29 | - Use meaningful variable and function names that accurately describe their purpose. 30 | - Write docstrings for all public functions, classes, and modules following the [Google Python Style Guide](https://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings). 31 | - Ensure that your code is properly formatted and passes the project's linting checks. 32 | 33 | ### Testing 34 | 35 | The Rule-based Retrieval package uses [pytest](https://docs.pytest.org/) for testing. When contributing, please make sure to: 36 | 37 | - Write unit tests for any new functionality or bug fixes. 
38 | - Ensure that all existing tests pass before submitting a pull request. 39 | - Add integration tests if your changes involve multiple components or complex scenarios. 40 | 41 | To run the tests, use the following command: 42 | 43 | ```shell 44 | pytest tests/ 45 | ``` 46 | 47 | ### Documentation 48 | Keeping the documentation up to date is crucial for the usability and adoption of the package. If your contributions involve changes to the public API or introduce new features, please update the relevant documentation files, including: 49 | 50 | * README.md: Update the main README file with any necessary changes or additions. 51 | * docs/: Update the corresponding files in the docs/ directory to reflect your changes. 52 | 53 | ### Code of Conduct 54 | Please note that this project adheres to the Contributor Covenant Code of Conduct. By participating in this project, you are expected to uphold this code. Please report any unacceptable behavior to the project maintainers. 55 | 56 | ### License 57 | By contributing to this project, you agree that your contributions will be licensed under the MIT License. 58 | 59 | Thank you for your contributions and happy coding! -------------------------------------------------------------------------------- /DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whyhow-ai/rule-based-retrieval/91701f45822823d6c54cac3b526e43cdb409e4e3/DS_Store -------------------------------------------------------------------------------- /LICENCE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 WhyHow.ai 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Rule-based Retrieval 2 | 3 | [![Python Version](https://img.shields.io/badge/python-3.10%2B-blue)](https://www.python.org/downloads/) 4 | [![License](https://img.shields.io/badge/license-MIT-green)](https://opensource.org/licenses/MIT) 5 | [![PyPI Version](https://img.shields.io/pypi/v/rule-based-retrieval)](https://pypi.org/project/rule-based-retrieval/) 6 | [![Code Style: Black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) 7 | [![Checked with mypy](https://img.shields.io/badge/mypy-checked-blue)](https://mypy-lang.org/) 8 | [![Whyhow Discord](https://dcbadge.vercel.app/api/server/PAgGMxfhKd?compact=true&style=flat)](https://discord.gg/PAgGMxfhKd) 9 | 10 | The Rule-based Retrieval package is a Python package that enables you to create and manage Retrieval Augmented Generation (RAG) applications with advanced filtering capabilities. It seamlessly integrates with OpenAI for text generation and Pinecone for efficient vector database management. 11 | 12 | # Installation 13 | 14 | ### Prerequisites 15 | 16 | - Python 3.10 or higher 17 | - OpenAI API key 18 | - Pinecone, Milvus or Qdrant credentials 19 | 20 | ### Install from PyPI 21 | 22 | You can install the package directly from PyPI using pip: 23 | 24 | ```shell 25 | pip install rule-based-retrieval 26 | ``` 27 | 28 | ### Install from GitHub 29 | 30 | Alternatively, you can clone the repo and install the package: 31 | 32 | ```shell 33 | git clone git@github.com:whyhow-ai/rule-based-retrieval.git 34 | cd rule-based-retrieval 35 | pip install . 36 | ``` 37 | 38 | ### Developer Install 39 | 40 | For a developer installation, use an editable install and include the development dependencies: 41 | 42 | ```shell 43 | pip install -e .[dev] 44 | ``` 45 | 46 | For ZSH: 47 | 48 | ```shell 49 | pip install -e ".[dev]" 50 | ``` 51 | 52 | If you want to install the package directly without explicitly cloning yourself 53 | run 54 | 55 | ```shell 56 | pip install git+ssh://git@github.com/whyhow-ai/rule-based-retrieval 57 | ``` 58 | 59 | # Documentation 60 | 61 | Documentation can be found [here](https://whyhow-ai.github.io/rule-based-retrieval/). 62 | 63 | To serve the docs locally run 64 | 65 | ```shell 66 | pip install -e .[docs] 67 | mkdocs serve 68 | ``` 69 | 70 | For ZSH: 71 | 72 | ```shell 73 | pip install -e ".[docs]" 74 | mkdocs serve 75 | ``` 76 | 77 | Navigate to http://127.0.0.1:8000/ in your browser to view the documentation. 78 | 79 | # Examples 80 | 81 | Check out the `examples/` directory for sample scripts demonstrating how to use the Rule-based Retrieval package. 82 | 83 | # How to 84 | 85 | ### [Demo](https://www.loom.com/share/089101b455b34701875b9f362ba16b89) 86 | `whyhow_rbr` offers different ways to implement Rule-based Retrieval through two databases and down below are the documentations(tutorial and example) for each implementation: 87 | 88 | - [Milvus](docs/milvus.md) 89 | - [Pinecone](docs/pinecone.md) 90 | - [Qdrant](docs/qdrant.md) 91 | 92 | # Contributing 93 | 94 | We welcome contributions to improve the Rule-based Retrieval package! If you have any ideas, bug reports, or feature requests, please open an issue on the GitHub repository. 95 | 96 | If you'd like to contribute code, please follow these steps: 97 | 98 | 1. Fork the repository 99 | 2. Create a new branch for your feature or bug fix 100 | 3. 
Make your changes and commit them with descriptive messages 101 | 4. Push your changes to your forked repository 102 | 5. Open a pull request to the main repository 103 | 104 | ### License 105 | 106 | This project is licensed under the MIT License. 107 | 108 | ### Support 109 | 110 | WhyHow.AI is building tools to help developers bring more determinism and control to their RAG pipelines using graph structures. If you're thinking about, in the process of, or have already incorporated knowledge graphs in RAG, we’d love to chat at team@whyhow.ai, or follow our newsletter at [WhyHow.AI](https://www.whyhow.ai/). Join our discussions about rules, determinism and knowledge graphs in RAG on our newly-created [Discord](https://discord.com/invite/9bWqrsxgHr). 111 | -------------------------------------------------------------------------------- /docs/api.md: -------------------------------------------------------------------------------- 1 | # Reference 2 | 3 | ## `whyhow.embedding` module 4 | 5 | ### ::: whyhow_rbr.embedding.generate_embeddings 6 | 7 | ## `whyhow.exceptions` module 8 | 9 | ### ::: whyhow_rbr.exceptions.IndexAlreadyExistsException 10 | 11 | ### ::: whyhow_rbr.exceptions.IndexNotFoundException 12 | 13 | ### ::: whyhow_rbr.exceptions.OpenAIException 14 | 15 | ## `whyhow.processing` module 16 | 17 | ### ::: whyhow_rbr.processing.parse_and_split 18 | 19 | ### ::: whyhow_rbr.processing.clean_chunks 20 | 21 | ## `whyhow.rag` module 22 | 23 | ### ::: whyhow_rbr.rag.Client 24 | 25 | ### ::: whyhow_rbr.rag.Rule 26 | 27 | ### ::: whyhow_rbr.rag.PineconeMetadata 28 | 29 | ### ::: whyhow_rbr.rag.PineconeDocument 30 | 31 | ### ::: whyhow_rbr.rag.PineconeMatch 32 | 33 | ### ::: whyhow_rbr.rag.Input 34 | 35 | ### ::: whyhow_rbr.rag.Output 36 | 37 | #### ::: whyhow_rbr.rag.Client.query 38 | 39 | :docstring: 40 | :members: 41 | :undoc-members: 42 | :show-inheritance: 43 | :special-members: __init__ 44 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # Welcome to Rule-based Retrieval Documentation 2 | 3 | ![Python 3.10](https://img.shields.io/badge/python-3.10-blue.svg) 4 | ![Python 3.11](https://img.shields.io/badge/python-3.11-blue.svg) 5 | 6 | The Rule-based Retrieval package is a Python package for creating Retrieval Augmented Generation (RAG) applications with filtering capabilities. It leverages OpenAI for text generation and Pinecone for vector database management. 7 | 8 | ## Key Features 9 | 10 | - Easy-to-use API for creating and managing Pinecone indexes 11 | - Uploading and processing documents (currently supports PDF files) 12 | - Generating embeddings using OpenAI models 13 | - Querying the index with custom filtering rules 14 | - Retrieval Augmented Generation for question answering 15 | - Querying the index with custom filtering rules, including processing rules separately and triggering rules based on keywords 16 | 17 | ## Getting Started 18 | 19 | 1. Install the package by following the [Installation Guide](installation.md) 20 | 2. Set up your OpenAI and Pinecone API keys as environment variables 21 | 3. Create an index and upload your documents using the `Client` class 22 | 4. Query the index with custom rules to retrieve relevant documents, optionally processing rules separately or triggering rules based on keywords 23 | 5. 
Use the retrieved documents to generate answers to your questions 24 | 25 | For a detailed walkthrough and code examples, check out the [Tutorial](tutorial.md). 26 | 27 | ## Architecture Overview 28 | 29 | The Rule-based Retrieval package consists of the following main components: 30 | 31 | - `Client`: The central class for managing resources and performing RAG-related tasks 32 | - `Rule`: Allows defining custom filtering rules for retrieving documents 33 | - `PineconeMetadata` and `PineconeDocument`: Classes for representing and storing document metadata and embeddings in Pinecone 34 | - `embedding`, `processing`, and `exceptions` modules: Utility functions and custom exceptions -------------------------------------------------------------------------------- /docs/installation.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | To install the Rule-based Retrieval package, follow these steps: 4 | 5 | ## Prerequisites 6 | 7 | - Python 3.10 or higher 8 | - OpenAI API key 9 | - Pinecone API key 10 | 11 | ## Install from GitHub 12 | 13 | Clone the repository: 14 | 15 | ```shell 16 | git clone git@github.com:whyhow-ai/rule-based-retrieval.git 17 | cd rule-based-retrieval 18 | ``` 19 | 20 | Install the packages: 21 | 22 | ```shell 23 | pip install . 24 | ``` 25 | 26 | Set the required environment variables: 27 | 28 | ```shell 29 | export OPENAI_API_KEY= 30 | export PINECONE_API_KEY= 31 | ``` 32 | 33 | ## Developer Installation 34 | 35 | For a developer installation, use an editable install and include the development dependencies: 36 | 37 | ```shell 38 | pip install -e .[dev] 39 | ``` 40 | 41 | For ZSH: 42 | 43 | ```shell 44 | pip install -e ".[dev]" 45 | ``` 46 | 47 | ## Install Documentation Dependencies 48 | 49 | To build and serve the documentation locally, install the documentation dependencies: 50 | 51 | ```shell 52 | pip install -e .[docs] 53 | ``` 54 | 55 | For ZSH: 56 | 57 | ```shell 58 | pip install -e ".[docs]" 59 | ``` 60 | 61 | Then, use mkdocs to serve the documentation: 62 | 63 | ```shell 64 | mkdocs serve 65 | ``` 66 | 67 | Navigate to http://127.0.0.1:8000/ in your browser to view the documentation. 68 | 69 | ## Troubleshooting 70 | 71 | If you encounter any issues during installation, please check the following: 72 | 73 | - Ensure that you have Python 3.10 or higher installed. You can check your Python version by running `python --version` in your terminal. 74 | - Make sure that you have correctly set the `OPENAI_API_KEY` and `PINECONE_API_KEY` environment variables with your respective API keys. 75 | - If you are installing from GitHub, ensure that you have cloned the repository correctly and are in the right directory. 76 | - If you are using a virtual environment, make sure that it is activated before running the installation commands. 77 | - If you still face problems, please open an issue on the GitHub repository with detailed information about the error and your environment setup. 78 | -------------------------------------------------------------------------------- /docs/milvus.md: -------------------------------------------------------------------------------- 1 | # Tutorial of Rule-based Retrieval through Milvus 2 | 3 | The `whyhow_rbr` package helps create customized RAG pipelines. 
It is built on top
4 | of the following technologies (and their respective Python SDKs)
5 | 
6 | - **OpenAI** - text generation
7 | - **Milvus** - vector database
8 | 
9 | ## Initialization
10 | 
11 | Please import the essential packages:
12 | ```python
13 | from pymilvus import DataType
14 | 
15 | from src.whyhow_rbr.rag_milvus import ClientMilvus
16 | ```
17 | 
18 | ## Client
19 | 
20 | The central object is a `ClientMilvus`. It manages all necessary resources
21 | and provides a simple interface for all the RAG related tasks.
22 | 
23 | First of all, to instantiate it one needs to provide the following
24 | credentials:
25 | 
26 | - `OPENAI_API_KEY`
27 | - `Milvus_URI`
28 | - `Milvus_API_TOKEN`
29 | 
30 | Initialize the ClientMilvus like this:
31 | 
32 | ```python
33 | # Set up your Milvus Cloud information
34 | YOUR_MILVUS_CLOUD_END_POINT="YOUR_MILVUS_CLOUD_END_POINT"
35 | YOUR_MILVUS_CLOUD_TOKEN="YOUR_MILVUS_CLOUD_TOKEN"
36 | 
37 | # Initialize the ClientMilvus
38 | milvus_client = ClientMilvus(
39 |     milvus_uri=YOUR_MILVUS_CLOUD_END_POINT,
40 |     milvus_token=YOUR_MILVUS_CLOUD_TOKEN
41 | )
42 | ```
43 | 
44 | ## Vector database operations
45 | 
46 | In this tutorial, `whyhow_rbr` uses Milvus for everything related to vector databases.
47 | 
48 | ### Defining necessary variables
49 | 
50 | ```python
51 | # Define collection name
52 | COLLECTION_NAME="YOUR_COLLECTION_NAME" # choose your own collection name
53 | 
54 | # Define vector dimension size
55 | DIMENSION=1536 # determined by the embedding model you use
56 | ```
57 | 
58 | ### Add schema
59 | 
60 | Before inserting any data into the Milvus database, we first need to define the data fields, which together form the schema. By creating a `CollectionSchema` object and adding fields with `add_field()`, we can control the data types and their characteristics. This step is required.
61 | 
62 | ```python
63 | schema = milvus_client.create_schema(auto_id=True) # Enable id matching
64 | 
65 | schema = milvus_client.add_field(schema=schema, field_name="id", datatype=DataType.INT64, is_primary=True)
66 | schema = milvus_client.add_field(schema=schema, field_name="embedding", datatype=DataType.FLOAT_VECTOR, dim=DIMENSION)
67 | ```
68 | We only define `id` and `embedding` here because every collection needs a primary field, and the embedding field needs an explicit dimension. We enable `enable_dynamic_field`, which lets Milvus accept fields that are not declared in the schema, but we still encourage you to declare the schema yourself. This method is a thin wrapper around the official Milvus implementation ([official docs](https://milvus.io/api-reference/pymilvus/v2.4.x/MilvusClient/Collections/create_schema.md)).
69 | 
70 | ### Creating an index
71 | 
72 | For each schema, it is better to have an index so that querying is much more efficient. To create an index, we first obtain an `IndexParams` object and then register the fields to index on it.
73 | ```python
74 | # Start indexing the data fields
75 | index_params = milvus_client.prepare_index_params()
76 | index_params = milvus_client.add_index(
77 |     index_params=index_params, # pass in the index_params object
78 |     field_name="embedding",
79 |     index_type="AUTOINDEX", # use AUTOINDEX instead of a more complex indexing method
80 |     metric_type="COSINE", # L2, COSINE, or IP
81 | )
82 | ```
83 | This method is a thin wrapper around the official Milvus implementation ([official docs](https://milvus.io/api-reference/pymilvus/v2.4.x/MilvusClient/Management/add_index.md)).
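Since both the schema helpers and `add_index` are described above as thin wrappers, it can help to see the same setup expressed directly against the official `pymilvus` client. The sketch below is illustrative only: it assumes `pymilvus` >= 2.4, reuses the `YOUR_MILVUS_CLOUD_END_POINT`, `YOUR_MILVUS_CLOUD_TOKEN`, and `DIMENSION` values defined earlier, and the `raw_client` name is ours, not part of `whyhow_rbr`.

```python
from pymilvus import DataType, MilvusClient

# Connect straight to Milvus with the official client (illustrative sketch)
raw_client = MilvusClient(
    uri=YOUR_MILVUS_CLOUD_END_POINT,
    token=YOUR_MILVUS_CLOUD_TOKEN,
)

# Same schema: an auto-generated primary key plus a 1536-dim embedding field
schema = raw_client.create_schema(auto_id=True, enable_dynamic_field=True)
schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True)
schema.add_field(field_name="embedding", datatype=DataType.FLOAT_VECTOR, dim=DIMENSION)

# Same index: AUTOINDEX with cosine similarity on the embedding field
index_params = raw_client.prepare_index_params()
index_params.add_index(
    field_name="embedding",
    index_type="AUTOINDEX",
    metric_type="COSINE",
)
```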
84 | 
85 | ### Create Collection
86 | 
87 | After defining all the data fields and indexing them, we now need to create the collection itself so that we can access our data quickly and precisely. Note that we initialized `enable_dynamic_field` to true so that you can upload any data freely; the trade-off is that querying may be less efficient.
88 | ```python
89 | # Create Collection
90 | milvus_client.create_collection(
91 |     collection_name=COLLECTION_NAME,
92 |     schema=schema,
93 |     index_params=index_params
94 | )
95 | ```
96 | 
97 | ## Uploading documents
98 | 
99 | After creating a collection, we are ready to populate it with documents. In
100 | `whyhow_rbr` this is done using the `upload_documents` method of the `ClientMilvus`.
101 | It performs the following steps under the hood:
102 | 
103 | - **Preprocessing**: Reading and splitting the provided PDF files into chunks
104 | - **Embedding**: Embedding all the chunks using an OpenAI model
105 | - **Inserting**: Uploading both the embeddings and the metadata to a Milvus collection
106 | 
107 | See below for an example of how to use it.
108 | 
109 | ```python
110 | # get pdfs
111 | pdfs = ["harry-potter.pdf", "game-of-thrones.pdf"] # replace with the paths to your PDFs
112 | 
113 | # Upload the PDF documents
114 | milvus_client.upload_documents(
115 |     collection_name=COLLECTION_NAME,
116 |     documents=pdfs
117 | )
118 | ```
119 | ## Question answering
120 | 
121 | Now we can finally move to retrieval augmented generation.
122 | 
123 | In `whyhow_rbr` with Milvus, it can be done via the `search` method.
124 | 
125 | 1. Simple example:
126 | 
127 | ```python
128 | # Search data and implement RAG!
129 | res = milvus_client.search(
130 |     question='What food does Harry Potter like to eat?',
131 |     collection_name=COLLECTION_NAME,
132 |     anns_field='embedding',
133 |     output_fields='text'
134 | )
135 | print(res['answer'])
136 | print(res['matches'])
137 | ```
138 | 
139 | The result `res` is a dictionary that has the following keys:
140 | 
141 | - `answer` - the answer to the question
142 | - `matches` - the `limit` most relevant documents from the index
143 | 
144 | Note that the number of matches will in general be equal to `limit`, which
145 | can be specified as a parameter.
146 | 
147 | ### Clean up
148 | 
149 | Finally, after running all of the instructions above, you can clean up the database
150 | by calling `drop_collection()`.
151 | ```python
152 | # Clean up
153 | milvus_client.drop_collection(
154 |     collection_name=COLLECTION_NAME
155 | )
156 | ```
157 | 
158 | ### Rules
159 | 
160 | In the previous example, every single document in our index was considered.
161 | However, sometimes it might be beneficial to only retrieve documents satisfying some
162 | predefined conditions (e.g. `filename=harry-potter.pdf`). In `whyhow_rbr` through Milvus, this
163 | can be done by adjusting the search parameters.
164 | 
165 | A rule can control the following metadata attributes:
166 | 
167 | - `filename` - name of the file
168 | - `page_numbers` - list of integers corresponding to page numbers (0-based indexing)
169 | - `id` - unique identifier of a chunk (this is the most "extreme" filter)
170 | - Other rules based on [Boolean Expressions](https://milvus.io/docs/boolean.md)
171 | 
172 | Rules Example:
173 | 
174 | ```python
175 | # RULES (search the book harry-potter, on page 8):
176 | PARTITION_NAME='harry-potter' # search within this book's partition
177 | page_number='page_number == 8'
178 | 
179 | # first create a partition to store the book, then search within this specific partition:
180 | milvus_client.create_partition(
181 |     collection_name=COLLECTION_NAME,
182 |     partition_name=PARTITION_NAME # separate partitions based on your PDF type
183 | )
184 | 
185 | # search with rules
186 | res = milvus_client.search(
187 |     question='Tell me about the greedy method',
188 |     collection_name=COLLECTION_NAME,
189 |     partition_names=PARTITION_NAME,
190 |     filter=page_number, # append any rules following the boolean expression syntax
191 |     anns_field='embedding',
192 |     output_fields='text'
193 | )
194 | print(res['answer'])
195 | print(res['matches'])
196 | ```
197 | 
198 | In this example, we first create a partition that stores the Harry Potter-related PDFs, and by searching within this partition we can get the most relevant information.
199 | We also apply the page number as a filter to specify the exact page we wish to search.
200 | Remember, the `filter` parameter needs to follow the [boolean expression syntax](https://milvus.io/docs/boolean.md).
201 | 
202 | That's all for the Milvus implementation of Rule-based Retrieval.
--------------------------------------------------------------------------------
/docs/pinecone.md:
--------------------------------------------------------------------------------
1 | # How to do Rule-based Retrieval through Pinecone
2 | 
3 | Here is a brief introduction to the main functions of Rule-based Retrieval. For a more detailed Pinecone tutorial, see [here](tutorial.md).
4 | 
5 | ## Set up the environment
6 | 
7 | ```shell
8 | export OPENAI_API_KEY=
9 | export PINECONE_API_KEY=
10 | ```
11 | 
12 | ## Create index & upload
13 | 
14 | ```python
15 | from whyhow_rbr import Client
16 | 
17 | # Configure parameters
18 | index_name = "whyhow-demo"
19 | namespace = "demo"
20 | pdfs = ["harry_potter_book_1.pdf"]
21 | 
22 | # Initialize client
23 | client = Client()
24 | 
25 | # Get the index
26 | index = client.get_index(index_name)
27 | 
28 | # Upload, split, chunk, and vectorize documents in Pinecone
29 | client.upload_documents(index=index, documents=pdfs, namespace=namespace)
30 | ```
31 | 
32 | ## Query with rules
33 | 
34 | ```python
35 | from whyhow_rbr import Client, Rule
36 | 
37 | # Configure query parameters
38 | index_name = "whyhow-demo"
39 | namespace = "demo"
40 | question = "What does Harry wear?"
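# top_k caps how many of the most relevant chunks are retrieved from Pinecone
# and passed to the LLM as context when answering the question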
41 | top_k = 5
42 | 
43 | # Initialize client and get the index
44 | client = Client()
index = client.get_index(index_name)
45 | 
46 | # Create rules
47 | rules = [
48 |     Rule(
49 |         filename="harry_potter_book_1.pdf",
50 |         page_numbers=[21, 22, 23]
51 |     ),
52 |     Rule(
53 |         filename="harry_potter_book_1.pdf",
54 |         page_numbers=[151, 152, 153, 154]
55 |     )
56 | ]
57 | 
58 | # Run query
59 | result = client.query(
60 |     question=question,
61 |     index=index,
62 |     namespace=namespace,
63 |     rules=rules,
64 |     top_k=top_k,
65 | )
66 | 
67 | answer = result["answer"]
68 | used_contexts = [
69 |     result["matches"][i]["metadata"]["text"] for i in result["used_contexts"]
70 | ]
71 | print(f"Answer: {answer}")
72 | print(
73 |     f"The model used {len(used_contexts)} chunk(s) from the DB to answer the question"
74 | )
75 | ```
76 | 
77 | ## Query with keywords
78 | 
79 | ```python
80 | from whyhow_rbr import Client, Rule
81 | 
82 | client = Client()
83 | 
84 | index = client.get_index("amazing-index")
85 | namespace = "books"
86 | 
87 | question = "What does Harry Potter like to eat?"
88 | 
89 | rule = Rule(
90 |     filename="harry-potter.pdf",
91 |     keywords=["food", "favorite", "likes to eat"]
92 | )
93 | 
94 | result = client.query(
95 |     question=question,
96 |     index=index,
97 |     namespace=namespace,
98 |     rules=[rule],
99 |     keyword_trigger=True
100 | )
101 | 
102 | print(result["answer"])
103 | print(result["matches"])
104 | print(result["used_contexts"])
105 | ```
106 | 
107 | ## Query each rule separately
108 | 
109 | ```python
110 | from whyhow_rbr import Client, Rule
111 | 
112 | client = Client()
113 | 
114 | index = client.get_index("amazing-index")
115 | namespace = "books"
116 | 
117 | question = "What is Harry Potter's favorite food?"
118 | 
119 | rule_1 = Rule(
120 |     filename="harry-potter.pdf",
121 |     page_numbers=[120, 121, 150]
122 | )
123 | 
124 | rule_2 = Rule(
125 |     filename="harry-potter-volume-2.pdf",
126 |     page_numbers=[80, 81, 82]
127 | )
128 | 
129 | result = client.query(
130 |     question=question,
131 |     index=index,
132 |     namespace=namespace,
133 |     rules=[rule_1, rule_2],
134 |     process_rules_separately=True
135 | )
136 | 
137 | print(result["answer"])
138 | print(result["matches"])
139 | print(result["used_contexts"])
140 | ```
--------------------------------------------------------------------------------
/docs/tutorial.md:
--------------------------------------------------------------------------------
1 | # Tutorial
2 | 
3 | The `whyhow_rbr` package helps create customized RAG pipelines. It is built on top
4 | of the following technologies (and their respective Python SDKs)
5 | 
6 | - **OpenAI** - text generation
7 | - **Pinecone** - vector database
8 | 
9 | ## Client
10 | 
11 | The central object is a `Client`. It manages all necessary resources
12 | and provides a simple interface for all the RAG related tasks.
13 | 
14 | First of all, to instantiate it one needs to provide the following
15 | API keys:
16 | 
17 | - `OPENAI_API_KEY`
18 | - `PINECONE_API_KEY`
19 | 
20 | One can either define the corresponding environment variables
21 | 
22 | ```shell
23 | export OPENAI_API_KEY=...
24 | export PINECONE_API_KEY=...
25 | ```
26 | 
27 | and then instantiate the client without any arguments.
28 | 
29 | ```python title="getting_started.py"
30 | from whyhow_rbr import Client
31 | 
32 | client = Client()
33 | 
34 | ```
35 | 
36 | ```shell
37 | python getting_started.py
38 | ```
39 | 
40 | An alternative approach is to manually pass the keys when the client is
41 | being constructed
42 | 
43 | ```python title="getting_started.py"
44 | from whyhow_rbr import Client
45 | 
46 | client = Client(
47 |     openai_api_key="...",
48 |     pinecone_api_key="..."
49 | 
50 | )
51 | ```
52 | 
53 | ```shell
54 | python getting_started.py
55 | ```
56 | 
57 | ## Vector database operations
58 | 
59 | `whyhow_rbr` uses Pinecone for everything related to vector databases.
60 | 
61 | ### Creating an index
62 | 
63 | If you don't have a Pinecone index yet, you can create it using the
64 | `create_index` method of the `Client`. This method
65 | is a thin wrapper around the Pinecone SDK ([official docs](https://docs.pinecone.io/docs/create-an-index)).
66 | 
67 | First of all, you need to provide a specification. There are two types:
68 | 
69 | - **Serverless**
70 | - **Pod-based**
71 | 
72 | #### Serverless
73 | 
74 | To create a serverless index you can use
75 | 
76 | ```python
77 | # Code above omitted 👆
78 | 
79 | from pinecone import ServerlessSpec
80 | 
81 | spec = ServerlessSpec(
82 |     cloud="aws",
83 |     region="us-west-2"
84 | )
85 | 
86 | index = client.create_index(
87 |     name="great-index", # the only required argument
88 |     dimension=1536,
89 |     metric="cosine",
90 |     spec=spec
91 | )
92 | ```
93 | 
94 | ??? note "Full code"
95 | 
96 | ```python
97 | from pinecone import ServerlessSpec
98 | 
99 | from whyhow_rbr import Client
100 | 
101 | client = Client()
102 | 
103 | spec = ServerlessSpec(
104 |     cloud="aws",
105 |     region="us-west-2"
106 | )
107 | 
108 | index = client.create_index(
109 |     name="great-index", # the only required argument
110 |     dimension=1536,
111 |     metric="cosine",
112 |     spec=spec
113 | )
114 | 
115 | ```
116 | 
117 | #### Pod-based
118 | 
119 | To create a pod-based index you can use
120 | 
121 | ```python
122 | # Code above omitted 👆
123 | 
124 | from pinecone import PodSpec
125 | 
126 | spec = PodSpec(
127 |     environment="gcp-starter"
128 | )
129 | 
130 | index = client.create_index(
131 |     name="amazing-index", # the only required argument
132 |     dimension=1536,
133 |     metric="cosine",
134 |     spec=spec
135 | )
136 | ```
137 | 
138 | ??? note "Full code"
139 | 
140 | ```python
141 | from pinecone import PodSpec
142 | 
143 | from whyhow_rbr import Client
144 | 
145 | client = Client()
146 | 
147 | spec = PodSpec(
148 |     environment="gcp-starter"
149 | )
150 | 
151 | index = client.create_index(
152 |     name="amazing-index", # the only required argument
153 |     dimension=1536,
154 |     metric="cosine",
155 |     spec=spec
156 | )
157 | 
158 | ```
159 | 
160 | !!! info
161 | For detailed information on what all of the parameters mean
162 | please refer to [(Pinecone) Understanding indices](https://docs.pinecone.io/docs/indexes)
163 | 
164 | ### Getting an existing index
165 | 
166 | If your index already exists, you can use the `get_index` method to get it.
167 | 
168 | ```python
169 | # Code above omitted 👆
170 | 
171 | index = client.get_index("amazing-index")
172 | 
173 | ```
174 | 
175 | ??? note "Full code"
176 | 
177 | ```python
178 | from pinecone import PodSpec
179 | 
180 | from whyhow_rbr import Client
181 | 
182 | client = Client()
183 | 
184 | index = client.get_index("amazing-index")
185 | 
186 | ```
187 | 
188 | ### Index operations
189 | 
190 | Both `create_index` and `get_index` return an instance of `pinecone.Index`.
191 | It offers multiple convenience methods. See below for a few examples.
192 | 
193 | #### `describe_index_stats`
194 | 
195 | Shows useful information about the index.
196 | 
197 | ```python
198 | index.describe_index_stats()
199 | ```
200 | 
201 | Example output:
202 | 
203 | ```python
204 | {'dimension': 1536,
205 |  'index_fullness': 0.00448,
206 |  'namespaces': {'A': {'vector_count': 11},
207 |                 'B': {'vector_count': 11},
208 |                 'C': {'vector_count': 62},
209 |                 'D': {'vector_count': 82},
210 |                 'E': {'vector_count': 282}},
211 |  'total_vector_count': 448}
212 | 
213 | ```
214 | 
215 | #### `fetch`
216 | 
217 | [Fetch (Pinecone docs)](https://docs.pinecone.io/docs/fetch-data)
218 | 
219 | #### `upsert`
220 | 
221 | [Upsert (Pinecone docs)](https://docs.pinecone.io/docs/upsert-data)
222 | 
223 | #### `query`
224 | 
225 | [Query (Pinecone docs)](https://docs.pinecone.io/docs/query-data)
226 | 
227 | #### `delete`
228 | 
229 | [Delete (Pinecone docs)](https://docs.pinecone.io/docs/delete-data)
230 | 
231 | #### `update`
232 | 
233 | [Update (Pinecone docs)](https://docs.pinecone.io/docs/update-data)
234 | 
235 | ## Uploading documents
236 | 
237 | After creating an index, we are ready to populate it with documents. In
238 | `whyhow_rbr` this is done using the `upload_documents` method of the `Client`.
239 | It performs the following steps under the hood:
240 | 
241 | - **Preprocessing**: Reading and splitting the provided PDF files into chunks
242 | - **Embedding**: Embedding all the chunks using an OpenAI model
243 | - **Upserting**: Uploading both the embeddings and the metadata to a Pinecone index
244 | 
245 | See below for an example of how to use it.
246 | 
247 | ```python
248 | # Code above omitted 👆
249 | 
250 | namespace = "books"
251 | pdfs = ["harry-potter.pdf", "game-of-thrones.pdf"]
252 | 
253 | client.upload_documents(
254 |     index=index,
255 |     documents=pdfs,
256 |     namespace=namespace
257 | )
258 | 
259 | ```
260 | 
261 | ??? note "Full code"
262 | 
263 | ```python
264 | from whyhow_rbr import Client
265 | 
266 | client = Client()
267 | 
268 | index = client.get_index("amazing-index")
269 | 
270 | namespace = "books"
271 | pdfs = ["harry-potter.pdf", "game-of-thrones.pdf"]
272 | 
273 | client.upload_documents(
274 |     index=index,
275 |     documents=pdfs,
276 |     namespace=namespace
277 | )
278 | 
279 | ```
280 | 
281 | !!! warning
282 | 
283 | The above example assumes you have two PDFs on your disk.
284 | 
285 | * `harry-potter.pdf`
286 | * `game-of-thrones.pdf`
287 | 
288 | However, feel free to provide different documents.
289 | 
290 | !!! info
291 | 
292 | The `upload_documents` method does not return anything. If you want to
293 | get some information about what is going on you can activate logging.
294 | 
295 | 
296 | ```python
297 | import logging
298 | 
299 | logging.basicConfig(
300 |     level=logging.WARNING,
301 |     format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
302 | )
303 | ```
304 | Note that the above affects the root logger; however, you can also
305 | just customize the `whyhow_rbr` logger.
306 | 
307 | Navigate to [upload_documents (API docs)](./api.md#whyhow.rag.Client.upload_documents)
308 | if you want to get more information on the parameters.
309 | 
310 | ### Index schema
311 | 
312 | While Pinecone does not require each document in an index to have the same schema,
313 | all the documents uploaded via the `upload_documents` method will have a fixed schema.
314 | This schema is defined in [PineconeDocument (API docs)](./api.md#whyhow_rbr.rag.PineconeDocument).
315 | This is done in order to have a predictable set of attributes that
316 | can be used to perform advanced filtering (via rules).
317 | 
318 | ## Question answering
319 | 
320 | In previous sections we discussed how to create an index and
321 | populate it with documents. Now we can finally move to retrieval augmented generation.
322 | 
323 | In `whyhow_rbr`, it can be done via the `query` method.
324 | 
325 | 1. Simple example:
326 | 
327 | ```python
328 | 
329 | from whyhow_rbr import Client, Rule
330 | 
331 | client = Client()
332 | 
333 | index = client.get_index("amazing-index")
334 | namespace = "books"
335 | 
336 | question = "What is Harry Potter's favorite food?"
337 | 
338 | rule = Rule(
339 |     filename="harry-potter.pdf",
340 |     page_numbers=[120, 121, 150]
341 | )
342 | 
343 | result = client.query(
344 |     question=question,
345 |     index=index,
346 |     namespace=namespace,
347 |     rules=[rule]
348 | )
349 | 
350 | print(result["answer"])
351 | print(result["matches"])
352 | print(result["used_contexts"])
353 | 
354 | ```
355 | 
356 | The `result` is a dictionary that has the following three keys:
357 | 
358 | - `answer` - the answer to the question
359 | - `matches` - the `top_k` most relevant documents from the index
360 | - `used_contexts` - the matches (or more precisely just the texts/contexts) that
361 |   the LLM used to answer the question.
362 | 
363 | ```python
364 | print(result["answer"])
365 | ```
366 | 
367 | ```python
368 | 'Treacle tart'
369 | ```
370 | 
371 | ```python
372 | print(result["matches"])
373 | ```
374 | 
375 | ```python
376 | [{'id': 'harry-potter.pdf-120-5',
377 |   'metadata': {'chunk_number': 5,
378 |                'filename': 'harry-potter.pdf',
379 |                'page_number': 120,
380 |                'text': 'Harry loves the treacle tart.',
381 |                'uuid': '86314e32-7d88-475c-b950-f8c156ebf259'},
382 |   'score': 0.826438308},
383 |  {'id': 'game-of-thrones.pdf-75-1',
384 |   'metadata': {'chunk_number': 1,
385 |                'filename': 'game-of-thrones.pdf',
386 |                'page_number': 75,
387 |                'text': 'Harry Strickland was the head of the exiled House Strickland.'
388 |                        'He enjoys eating roasted beef.',
389 |                'uuid': '684a978b-e6e7-45e2-8ba4-5c5019c7c676'},
390 |   'score': 0.2052352},
391 |  ...
392 | ]
393 | ```
394 | 
395 | Note that the number of matches will in general be equal to `top_k`, which
396 | can be specified as a parameter. Also, each match has a fixed schema -
397 | it is a dump of [PineconeMatch (API docs)](./api.md#whyhow_rbr.rag.PineconeMatch).
398 | 
399 | ```python
400 | print(result["used_contexts"])
401 | ```
402 | 
403 | ```python
404 | [0]
405 | ```
406 | 
407 | The OpenAI model only used the context from the first match when answering the question.
408 | 
409 | ??? note "Full code"
410 | 
411 | ```python
412 | from whyhow_rbr import Client
413 | 
414 | client = Client()
415 | 
416 | index = client.get_index("amazing-index")
417 | 
418 | namespace = "books"
419 | 
420 | question = "What is Harry Potter's favourite food?"
421 | 
422 | result = client.query(
423 |     question=question,
424 |     index=index,
425 |     namespace=namespace
426 | )
427 | 
428 | print(result["answer"])
429 | print(result["matches"])
430 | print(result["used_contexts"])
431 | 
432 | ```
433 | 
434 | Navigate to [query (API docs)](./api.md#whyhow_rbr.rag.Client.query)
435 | if you want to get more information on the parameters.
436 | 
437 | ### Rules
438 | 
439 | In the previous example, every single document in our index was considered.
440 | However, sometimes it might be beneficial to only retrieve documents satisfying some
441 | predefined conditions (e.g.
`filename=harry-potter.pdf`). In `whyhow_rbr` this
442 | can be done via the `Rule` class.
443 | 
444 | A rule can control the following metadata attributes:
445 | 
446 | - `filename` - name of the file
447 | - `page_numbers` - list of integers corresponding to page numbers (0-based indexing)
448 | - `uuid` - unique identifier of a chunk (this is the most "extreme" filter)
449 | - `keywords` - list of keywords to trigger the rule
450 | 
451 | 2. Keyword example:
452 | 
453 | ```python
454 | # Code above omitted 👆
455 | 
456 | from whyhow_rbr import Rule
457 | 
458 | question = "What is Harry Potter's favourite food?"
459 | 
460 | rule = Rule(
461 |     filename="harry-potter.pdf",
462 |     page_numbers=[120, 121, 150],
463 |     keywords=["food", "favorite", "likes to eat"]
464 | )
465 | result = client.query(
466 |     question=question,
467 |     index=index,
468 |     namespace=namespace,
469 |     rules=[rule],
470 |     keyword_trigger=True
471 | )
472 | 
473 | ```
474 | 
475 | In this example, the `keyword_trigger` parameter is set to `True`, and the rule includes keywords. Only the rules whose keywords match the words in the question will be applied.
476 | 
477 | ??? note "Full code"
478 | 
479 | ```python
480 | from whyhow_rbr import Client, Rule
481 | 
482 | client = Client()
483 | 
484 | index = client.get_index("amazing-index")
485 | namespace = "books"
486 | 
487 | question = "What does Harry Potter like to eat?"
488 | 
489 | rule = Rule(
490 |     filename="harry-potter.pdf",
491 |     keywords=["food", "favorite", "likes to eat"]
492 | )
493 | 
494 | result = client.query(
495 |     question=question,
496 |     index=index,
497 |     namespace=namespace,
498 |     rules=[rule],
499 |     keyword_trigger=True
500 | )
501 | 
502 | print(result["answer"])
503 | print(result["matches"])
504 | print(result["used_contexts"])
505 | ```
506 | 
507 | 3. Process rules separately example:
508 | 
509 | Lastly, you can specify multiple rules at the same time. They
510 | will be evaluated using the `OR` logical operator.
511 | 
512 | ```python
513 | # Code above omitted 👆
514 | 
515 | from whyhow_rbr import Rule
516 | 
517 | question = "What is Harry Potter's favorite food?"
518 | 
519 | rule_1 = Rule(
520 |     filename="harry-potter.pdf",
521 |     page_numbers=[120, 121, 150]
522 | )
523 | 
524 | rule_2 = Rule(
525 |     filename="harry-potter-volume-2.pdf",
526 |     page_numbers=[80, 81, 82]
527 | )
528 | 
529 | result = client.query(
530 |     question=question,
531 |     index=index,
532 |     namespace=namespace,
533 |     rules=[rule_1, rule_2],
534 |     process_rules_separately=True
535 | )
536 | 
537 | ```
538 | 
539 | In this example, the `process_rules_separately` parameter is set to `True`. This means that each rule (`rule_1` and `rule_2`) will be processed independently, ensuring that both rules contribute to the final result set.
540 | 
541 | By default, all rules are run as one joined query, which means that one rule can dominate the others, and given the limit set by `top_k`, a lower-priority rule might not return any results. However, by setting `process_rules_separately` to `True`, each rule will be processed independently, ensuring that every rule returns results, and the results will be combined at the end.
542 | 
543 | Depending on the number of rules you use in your query, you may return more chunks than your LLM’s context window can handle. Be mindful of your model’s token limits and adjust your `top_k` and rule count accordingly.
544 | 
545 | ???
note "Full code" 546 | 547 | ```python 548 | from whyhow_rbr import Client, Rule 549 | 550 | client = Client() 551 | 552 | index = client.get_index("amazing-index") 553 | namespace = "books" 554 | 555 | question = "What is Harry Potter's favorite food?" 556 | 557 | rule_1 = Rule( 558 | filename="harry-potter.pdf", 559 | page_numbers=[120, 121, 150] 560 | ) 561 | 562 | rule_2 = Rule( 563 | filename="harry-potter-volume-2.pdf", 564 | page_numbers=[80, 81, 82] 565 | ) 566 | 567 | result = client.query( 568 | question=question, 569 | index=index, 570 | namespace=namespace, 571 | rules=[rule_1, rule_2], 572 | process_rules_separately=True 573 | ) 574 | 575 | print(result["answer"]) 576 | print(result["matches"]) 577 | print(result["used_contexts"]) 578 | ``` 579 | 580 | Navigate to [Rule (API docs)](./api.md#whyhow_rbr.rag.Rule) 581 | if you want to get more information on the parameters. 582 | -------------------------------------------------------------------------------- /examples/create_index.py: -------------------------------------------------------------------------------- 1 | """Example of creating a Pinecone index and uploading documents to it.""" 2 | 3 | import logging 4 | 5 | from pinecone import PodSpec 6 | 7 | from whyhow_rbr import Client, IndexNotFoundException 8 | 9 | # Parameters 10 | index_name = "" # Replace with your index name 11 | namespace = "" # Replace with your namespace name 12 | pdfs = [ 13 | "", 14 | "", 15 | ] # Replace with the paths to your PDFs, e.g. ["path/to/pdf1.pdf", "path/to/pdf2.pdf 16 | logging_level = logging.INFO 17 | 18 | # Logging 19 | logging.basicConfig( 20 | level=logging.WARNING, 21 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 22 | ) 23 | logger = logging.getLogger("create_index") 24 | logger.setLevel(logging_level) 25 | 26 | # Define OPENAI_API_KEY and PINECONE_API_KEY as environment variables 27 | client = Client() 28 | 29 | try: 30 | index = client.get_index(index_name) 31 | logger.info(f"Index {index_name} already exists, reusing it") 32 | except IndexNotFoundException: 33 | spec = PodSpec(environment="gcp-starter") 34 | index = client.create_index(index_name, spec=spec) 35 | logger.info(f"Index {index_name} created") 36 | 37 | client.upload_documents(index=index, documents=pdfs, namespace=namespace) 38 | -------------------------------------------------------------------------------- /examples/demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## RULE BASED RETRIEVAL " 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "### aka why won't my LLM do what I tell it to when I tell it to " 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "### SETUP" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 1, 27 | "metadata": {}, 28 | "outputs": [ 29 | { 30 | "name": "stderr", 31 | "output_type": "stream", 32 | "text": [ 33 | "/Users/tomsmoker/Projects/whyhow/rule-based-retrieval/venv/lib/python3.10/site-packages/pinecone/data/index.py:1: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 34 | " from tqdm.autonotebook import tqdm\n" 35 | ] 36 | } 37 | ], 38 | "source": [ 39 | "import logging\n", 40 | "\n", 41 | "from pinecone import PodSpec\n", 42 | "\n", 43 | "from whyhow_rbr import Client, Rule, IndexNotFoundException" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 2, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "# Configure parameters\n", 53 | "index_name = \"whyhow-demo\"\n", 54 | "namespace = \"BC-CS688\"\n", 55 | "pdfs = [\"../data/full_book_one.pdf\"]" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 3, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "# Logging\n", 65 | "logging_level = logging.INFO\n", 66 | "\n", 67 | "logging.basicConfig(\n", 68 | " level=logging.WARNING,\n", 69 | " format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n", 70 | ")\n", 71 | "logger = logging.getLogger(\"create_index\")\n", 72 | "logger.setLevel(logging_level)" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": 4, 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "# Initialize client\n", 82 | "client = Client()" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": 5, 88 | "metadata": {}, 89 | "outputs": [ 90 | { 91 | "name": "stderr", 92 | "output_type": "stream", 93 | "text": [ 94 | "2024-04-01 16:10:36,135 - INFO - create_index - Index whyhow-demo already exists, reusing it\n" 95 | ] 96 | } 97 | ], 98 | "source": [ 99 | "try:\n", 100 | " index = client.get_index(index_name)\n", 101 | " logger.info(f\"Index {index_name} already exists, reusing it\")\n", 102 | "except IndexNotFoundException:\n", 103 | " spec = PodSpec(environment=\"gcp-starter\")\n", 104 | " index = client.create_index(index_name, spec=spec)\n", 105 | " logger.info(f\"Index {index_name} created\")" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 6, 111 | "metadata": {}, 112 | "outputs": [ 113 | { 114 | "name": "stderr", 115 | "output_type": "stream", 116 | "text": [ 117 | "Upserted vectors: 100%|██████████| 1156/1156 [00:08<00:00, 133.76it/s]\n" 118 | ] 119 | } 120 | ], 121 | "source": [ 122 | "# Upload, split, chunk, and vectorize documents in Pinecone\n", 123 | "client.upload_documents(index=index, documents=pdfs, namespace=namespace)" 124 | ] 125 | }, 126 | { 127 | "cell_type": "markdown", 128 | "metadata": {}, 129 | "source": [ 130 | "### RULES" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": 29, 136 | "metadata": {}, 137 | "outputs": [], 138 | "source": [ 139 | "rules = [\n", 140 | " Rule(\n", 141 | " # Replace with your filename\n", 142 | " filename=\"full_book_one.pdf\",\n", 143 | " page_numbers=[40],\n", 144 | " keywords=['friends']\n", 145 | " ),\n", 146 | " Rule(\n", 147 | " # Replace with your filename\n", 148 | " filename=\"doc2.pdf\",\n", 149 | " page_numbers=[2],\n", 150 | " keywords=[],\n", 151 | " )\n", 152 | "]" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": 35, 158 | "metadata": {}, 159 | "outputs": [], 160 | "source": [ 161 | "question = \"Who does Harry know? 
Like who are his friends?\"\n", 162 | "top_k = 5" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": 36, 168 | "metadata": {}, 169 | "outputs": [], 170 | "source": [ 171 | "result = client.query(\n", 172 | " question=question,\n", 173 | " index=index,\n", 174 | " namespace=namespace,\n", 175 | " rules=rules,\n", 176 | " top_k=top_k,\n", 177 | " process_rules_separately=False,\n", 178 | " keyword_trigger=False\n", 179 | ")" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": 37, 185 | "metadata": {}, 186 | "outputs": [ 187 | { 188 | "name": "stderr", 189 | "output_type": "stream", 190 | "text": [ 191 | "2024-04-01 16:18:22,626 - INFO - create_index - Answer: I don't have the context documents to answer who Harry's friends are. Please provide the relevant context or specify which Harry you are referring to.\n" 192 | ] 193 | } 194 | ], 195 | "source": [ 196 | "answer = result[\"answer\"]\n", 197 | "\n", 198 | "logger.info(f\"Answer: {answer}\")" 199 | ] 200 | }, 201 | { 202 | "cell_type": "markdown", 203 | "metadata": {}, 204 | "source": [ 205 | "### WHAT IF I WANT IT TO FIND KEYWORDS " 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": 38, 211 | "metadata": {}, 212 | "outputs": [], 213 | "source": [ 214 | "question = \"What does Harry Potter like to eat?\"" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": 39, 220 | "metadata": {}, 221 | "outputs": [], 222 | "source": [ 223 | "rule = Rule(\n", 224 | " filename=\"../data/full_book_one.pdf\",\n", 225 | " page_numbers=[15, 30, 45],\n", 226 | " keywords=[\"food\", \"favorite\"]\n", 227 | ")" 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": 40, 233 | "metadata": {}, 234 | "outputs": [ 235 | { 236 | "name": "stdout", 237 | "output_type": "stream", 238 | "text": [ 239 | "Harry Potter likes to eat roast beef, roast chicken, pork chops, lamb chops, sausages, bacon, steak, boiled potatoes, roast potatoes, chips, Yorkshire pudding, peas, carrots, gravy, ketchup, chocolate éclairs, jam doughnuts, trifle, strawberries, jelly, rice pudding, treacle tart, and Bertie Bott's Every-Flavour Beans, but not mint humbugs.\n", 240 | "[{'id': '../data/full_book_one.pdf-85-1', 'score': 0.6526559, 'metadata': {'text': 'piled with food. He had never seen so many things he liked to eat on one table: roast beef, roast chicken, pork chops and lamb chops, sausages, bacon and steak, boiled potatoes, roast potatoes, chips, Yorkshire pudding, peas, carrots, gravy , ketchup and, for some strange reason, mint humbugs. The Dursleys had never exactly starved Harry , but he’d never been allowed to eat as much as he liked. Dudley had always taken anything that Harry really wanted, even if it made him sick. Harry', 'page_number': 85, 'chunk_number': 1, 'filename': '../data/full_book_one.pdf', 'uuid': '5361d3da-7ea6-457c-80d4-54f599976ffe'}}, {'id': '../data/full_book_one.pdf-85-2', 'score': 0.599429369, 'metadata': {'text': 'anything that Harry really wanted, even if it made him sick. Harry piled his plate with a bit of everything except the humbugs and began to eat. It was all delicious. ‘That does look good,’ said the ghost in the ruff sadly , watching Harry cut up his steak. ‘Can’t you –?’ ‘I haven’t eaten for nearly five hundred years,’ said the ghost. ‘I don’t need to, of course, but one does miss it. I don’t think I’ve introduced myself? 
Sir Nicholas de Mimsy-Porpington at your', 'page_number': 85, 'chunk_number': 2, 'filename': '../data/full_book_one.pdf', 'uuid': 'e90dff94-e1f5-4d02-9cbb-8d2474cd44b0'}}, {'id': '../data/full_book_one.pdf-86-2', 'score': 0.585694432, 'metadata': {'text': 'chocolate éclairs and jam doughnuts, trifle, strawberries, jelly , rice pudding ... As Harry helped himself to a treacle tart, the talk turned to their families. ‘I’m half and half,’ said Seamus. ‘Me dad’s a Muggle. Mam didn’t tell him she was a witch ’til after they were married. Bit of a nasty shock for him.’ The others laughed. ‘What about you, Neville?’ said Ron. ‘Well, my gran brought me up and she’s a witch,’ said Neville, ‘but the family thought I was all Muggle for ages. My great-uncle', 'page_number': 86, 'chunk_number': 2, 'filename': '../data/full_book_one.pdf', 'uuid': '56c894ab-3d76-4e5c-a8bf-c5eba50d4e3c'}}, {'id': '../data/full_book_one.pdf-71-0', 'score': 0.58425, 'metadata': {'text': '78 H ARRY POTTER eating the frogs than looking at the Famous Witches and Wizards cards, but Harry couldn’t keep his eyes off them. Soon he had not only Dumbledore and Morgana, but Hengist of Woodcraft, Alberic Grunnion, Circe, Paracelsus and Merlin. He finally tore his eyes away from the druidess Cliodna, who was scratching her nose, to open a bag of Bertie Bott’s Every-Flavour Beans. ‘You want to be careful with those,’ Ron warned Harry . ‘When', 'page_number': 71, 'chunk_number': 0, 'filename': '../data/full_book_one.pdf', 'uuid': '81533a65-f364-46ce-ac84-8624d4e55c83'}}, {'id': '../data/full_book_one.pdf-85-0', 'score': 0.575736284, 'metadata': {'text': '92 H ARRY POTTER here they are: Nitwit! Blubber! Oddment! T weak! ‘Thank you!’ He sat back down. Everybody clapped and cheered. Harry didn’t know whether to laugh or not. ‘Is he – a bit mad?’ he asked Percy uncertainly . ‘Mad?’ said Percy airily . ‘He’s a genius! Best wizard in the world! But he is a bit mad, yes. Potatoes, Harry?’ Harry’s mouth fell open. The dishes in front of him were now piled with food. 
He had never seen so many things he liked to eat', 'page_number': 85, 'chunk_number': 0, 'filename': '../data/full_book_one.pdf', 'uuid': '0087d8d3-2e7e-4a0d-ba33-fe9bf4d73161'}}]\n", 241 | "[0, 1, 2, 3, 4]\n" 242 | ] 243 | } 244 | ], 245 | "source": [ 246 | "result = client.query(\n", 247 | " question=question,\n", 248 | " index=index,\n", 249 | " namespace=namespace,\n", 250 | " rules=[rule],\n", 251 | " keyword_trigger=True\n", 252 | ")\n", 253 | "\n", 254 | "print(result[\"answer\"])\n", 255 | "print(result[\"matches\"])\n", 256 | "print(result[\"used_contexts\"])" 257 | ] 258 | }, 259 | { 260 | "cell_type": "markdown", 261 | "metadata": {}, 262 | "source": [ 263 | "### WHAT IF WE WANT IT TO RUN EACH RULE IN A ROW" 264 | ] 265 | }, 266 | { 267 | "cell_type": "code", 268 | "execution_count": 41, 269 | "metadata": {}, 270 | "outputs": [], 271 | "source": [ 272 | "question = \"What is Harry Potter's favorite food?\"" 273 | ] 274 | }, 275 | { 276 | "cell_type": "code", 277 | "execution_count": 43, 278 | "metadata": {}, 279 | "outputs": [], 280 | "source": [ 281 | "rule_1 = Rule(\n", 282 | " filename=\"data/full_book_one.pdf\",\n", 283 | " page_numbers=[120, 121, 150]\n", 284 | ")\n", 285 | "\n", 286 | "rule_2 = Rule(\n", 287 | " filename=\"data/full_book_one.pdf\",\n", 288 | " page_numbers=[80, 81, 82]\n", 289 | ")\n", 290 | "\n", 291 | "result = client.query(\n", 292 | " question=question,\n", 293 | " index=index,\n", 294 | " namespace=namespace,\n", 295 | " rules=[rule_1, rule_2],\n", 296 | " process_rules_separately=True\n", 297 | ")" 298 | ] 299 | }, 300 | { 301 | "cell_type": "code", 302 | "execution_count": null, 303 | "metadata": {}, 304 | "outputs": [], 305 | "source": [] 306 | } 307 | ], 308 | "metadata": { 309 | "kernelspec": { 310 | "display_name": "venv", 311 | "language": "python", 312 | "name": "python3" 313 | }, 314 | "language_info": { 315 | "codemirror_mode": { 316 | "name": "ipython", 317 | "version": 3 318 | }, 319 | "file_extension": ".py", 320 | "mimetype": "text/x-python", 321 | "name": "python", 322 | "nbconvert_exporter": "python", 323 | "pygments_lexer": "ipython3", 324 | "version": "3.10.13" 325 | } 326 | }, 327 | "nbformat": 4, 328 | "nbformat_minor": 2 329 | } 330 | -------------------------------------------------------------------------------- /examples/milvus_tutorial.py: -------------------------------------------------------------------------------- 1 | """Script that demonstrates how to use the RAG model with Milvus to implement rule-based retrieval.""" 2 | 3 | import os 4 | 5 | from pymilvus import DataType 6 | 7 | from src.whyhow_rbr.rag_milvus import ClientMilvus 8 | 9 | # Set up your Milvus Cloud information 10 | YOUR_MILVUS_CLOUD_END_POINT = os.getenv("YOUR_MILVUS_CLOUD_END_POINT") 11 | YOUR_MILVUS_CLOUD_TOKEN = os.getenv("YOUR_MILVUS_CLOUD_TOKEN") 12 | 13 | # Initialize the ClientMilvus 14 | milvus_client = ClientMilvus( 15 | milvus_uri=YOUR_MILVUS_CLOUD_END_POINT, 16 | milvus_token=YOUR_MILVUS_CLOUD_TOKEN, 17 | ) 18 | 19 | 20 | # Define collection name 21 | COLLECTION_NAME = "YOUR_COLLECTION_NAME" # take your own collection name 22 | 23 | 24 | # Create necessary schema to store data 25 | DIMENSION = 1536 # decide by the model you use 26 | 27 | schema = milvus_client.create_schema(auto_id=True) # Enable id matching 28 | 29 | schema = milvus_client.add_field( 30 | schema=schema, field_name="id", datatype=DataType.INT64, is_primary=True 31 | ) 32 | schema = milvus_client.add_field( 33 | schema=schema, 34 | field_name="embedding", 35 | 
datatype=DataType.FLOAT_VECTOR, 36 |     dim=DIMENSION, 37 | ) 38 | 39 | 40 | # Start indexing the data field 41 | index_params = milvus_client.prepare_index_params() 42 | index_params = milvus_client.add_index( 43 |     index_params=index_params,  # pass in the index_params object 44 |     field_name="embedding", 45 |     index_type="AUTOINDEX",  # use AUTOINDEX instead of a more complex indexing method 46 |     metric_type="COSINE",  # L2, COSINE, or IP 47 | ) 48 | 49 | 50 | # Create Collection 51 | milvus_client.create_collection( 52 |     collection_name=COLLECTION_NAME, schema=schema, index_params=index_params 53 | ) 54 | 55 | 56 | # Create a partition, then list all partitions 57 | milvus_client.crate_partition( 58 |     collection_name=COLLECTION_NAME, 59 |     partition_name="xxx",  # Use your own partition name, ideally one that fits the documents you upload 60 | ) 61 | 62 | partitions = milvus_client.list_partition(collection_name=COLLECTION_NAME) 63 | print(partitions) 64 | 65 | 66 | # Upload the PDF documents 67 | # get pdfs 68 | pdfs = ["harry-potter.pdf", "game-of-thrones.pdf"] # replace with your PDF paths 69 | 70 | milvus_client.upload_documents( 71 |     collection_name=COLLECTION_NAME, partition_name="xxx", documents=pdfs 72 | ) 73 | 74 | 75 | # add your rules: 76 | filter = "" 77 | partition_names = None 78 | 79 | 80 | # Search data and implement RAG! 81 | res = milvus_client.search( 82 |     question="Tell me about the greedy method", 83 |     collection_name=COLLECTION_NAME, 84 |     filter=filter, 85 |     partition_names=None, 86 |     anns_field="embedding", 87 |     output_fields="text", 88 | ) 89 | print(res["answer"]) 90 | print(res["matches"]) 91 | 92 | 93 | # Clean up 94 | milvus_client.drop_collection(collection_name=COLLECTION_NAME) 95 | -------------------------------------------------------------------------------- /examples/qdrant/create_collection.py: -------------------------------------------------------------------------------- 1 | """Example of creating a Qdrant collection and uploading documents to it.""" 2 | 3 | import logging 4 | 5 | from openai import OpenAI 6 | from qdrant_client import QdrantClient 7 | 8 | from src.whyhow_rbr.rag_qdrant import Client 9 | 10 | # Parameters 11 | collection_name = ""  # Replace with your collection name 12 | pdfs = ( 13 |     [] 14 | )  # Replace with the paths to your PDFs, e.g. 
["path/to/pdf1.pdf", "path/to/pdf2.pdf 15 | logging_level = logging.INFO 16 | 17 | # Logging 18 | logging.basicConfig( 19 | level=logging.WARNING, 20 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 21 | ) 22 | logger = logging.getLogger("create_index") 23 | logger.setLevel(logging_level) 24 | 25 | 26 | client = Client( 27 | OpenAI(), # Set OPENAI_API_KEY environment variable 28 | QdrantClient(url="http://localhost:6333"), 29 | ) 30 | 31 | client.create_collection(collection_name) 32 | client.upload_documents(collection_name, documents=pdfs) 33 | -------------------------------------------------------------------------------- /examples/qdrant/query.py: -------------------------------------------------------------------------------- 1 | """Example demonostating how to perform RAG.""" 2 | 3 | import logging 4 | 5 | from openai import OpenAI 6 | from qdrant_client import QdrantClient 7 | 8 | from src.whyhow_rbr.rag_qdrant import Client, Rule 9 | 10 | # Parameters 11 | collection_name = "" 12 | question = "" # Replace with your question 13 | logging_level = logging.INFO # Set to logging.DEBUG for more verbosity 14 | top_k = 5 15 | 16 | # Logging 17 | logging.basicConfig( 18 | level=logging.WARNING, 19 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 20 | ) 21 | logger = logging.getLogger("querying") 22 | logger.setLevel(logging_level) 23 | logging.getLogger("whyhow_rbr").setLevel(logging_level) 24 | 25 | 26 | client = Client( 27 | OpenAI(), # Set OPENAI_API_KEY environment variable 28 | QdrantClient(url="http://localhost:6333"), 29 | ) 30 | 31 | rules = [ 32 | Rule( 33 | # Replace with your filename 34 | filename="name/of/pdf_1.pdf", 35 | page_numbers=[2], 36 | keywords=["keyword1", "keyword2"], 37 | ), 38 | Rule( 39 | # Replace with your filename 40 | filename="name/of/pdf_1.pdf", 41 | page_numbers=[1], 42 | keywords=[], 43 | ), 44 | ] 45 | 46 | result = client.query( 47 | question=question, 48 | collection_name=collection_name, 49 | rules=rules, 50 | top_k=top_k, 51 | process_rules_separately=False, 52 | keyword_trigger=False, 53 | ) 54 | answer = result["answer"] 55 | 56 | 57 | logger.info(f"Answer: {answer}") 58 | -------------------------------------------------------------------------------- /examples/querying.py: -------------------------------------------------------------------------------- 1 | """Example demonostating how to perform RAG.""" 2 | 3 | import logging 4 | 5 | from whyhow_rbr import Client, Rule 6 | 7 | # Parameters 8 | index_name = "" # Replace with your index name 9 | namespace = "" # Replace with your namespace name 10 | question = "" # Replace with your question 11 | logging_level = logging.INFO # Set to logging.DEBUG for more verbosity 12 | top_k = 5 13 | 14 | # Logging 15 | logging.basicConfig( 16 | level=logging.WARNING, 17 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 18 | ) 19 | logger = logging.getLogger("querying") 20 | logger.setLevel(logging_level) 21 | logging.getLogger("whyhow_rbr").setLevel(logging_level) 22 | 23 | 24 | # Define OPENAI_API_KEY and PINECONE_API_KEY as environment variables 25 | client = Client() 26 | 27 | index = client.get_index(index_name) 28 | logger.info(f"Index {index_name} exists") 29 | 30 | rules = [ 31 | Rule( 32 | # Replace with your filename 33 | filename="doc1.pdf", 34 | page_numbers=[26], 35 | keywords=["word", "test"], 36 | ), 37 | Rule( 38 | # Replace with your filename 39 | filename="doc2.pdf", 40 | page_numbers=[2], 41 | keywords=[], 42 | ), 43 | ] 44 | 45 | result = 
client.query( 46 | question=question, 47 | index=index, 48 | namespace=namespace, 49 | rules=rules, 50 | top_k=top_k, 51 | process_rules_separately=False, 52 | keyword_trigger=False, 53 | ) 54 | answer = result["answer"] 55 | 56 | 57 | logger.info(f"Answer: {answer}") 58 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: WhyHow 2 | nav: 3 | - Home: index.md 4 | - Installation: installation.md 5 | - Tutorial: tutorial.md 6 | - API Documentation: api.md 7 | 8 | theme: 9 | name: material 10 | palette: 11 | scheme: slate 12 | features: 13 | - content.code.copy 14 | - search.suggest 15 | - search.highlight 16 | - toc.follow 17 | 18 | plugins: 19 | - search 20 | - mkdocstrings: 21 | handlers: 22 | python: 23 | options: 24 | docstring_style: numpy 25 | show_root_heading: true 26 | 27 | markdown_extensions: 28 | - toc: 29 | permalink: true 30 | toc_depth: 3 31 | - admonition 32 | - tables 33 | - pymdownx.details 34 | - pymdownx.highlight: 35 | anchor_linenums: true 36 | line_spans: __span 37 | pygments_lang_class: true 38 | - pymdownx.inlinehilite 39 | - pymdownx.snippets 40 | - pymdownx.superfences 41 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.2"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "rule-based-retrieval" 7 | authors = [{name = "Tom Smoker"}] 8 | description = "Python package for Rule-based Retrieval using RAG" 9 | keywords = ["retrieval", "RAG", "Pinecone", "openai", "LLM"] 10 | classifiers = ["Programming Language :: Python :: 3"] 11 | requires-python = ">=3.10" 12 | readme = "README.md" 13 | dependencies = [ 14 | "langchain_core", 15 | "langchain_community", 16 | "langchain_openai", 17 | "langchain_text_splitters", 18 | "openai>=1", 19 | "pinecone-client", 20 | "pydantic>1", 21 | "pypdf", 22 | "tiktoken", 23 | "qdrant-client" 24 | ] 25 | dynamic = ["version"] 26 | 27 | [project.urls] 28 | Homepage = "https://whyhow.ai" 29 | Documentation = "https://whyhow-ai.github.io/rule-based-retrieval/" 30 | "Issue Tracker" = "https://github.com/whyhow-ai/rule-based-retrieval/issues" 31 | 32 | 33 | [project.optional-dependencies] 34 | dev = [ 35 | "bandit[toml]", 36 | "black", 37 | "flake8", 38 | "flake8-docstrings", 39 | "fpdf", 40 | "isort", 41 | "mypy", 42 | "pydocstyle[toml]", 43 | "pytest-cov", 44 | "pytest", 45 | ] 46 | docs = [ 47 | "mkdocs", 48 | "mkdocstrings[python]", 49 | "mkdocs-material", 50 | "pymdown-extensions", 51 | ] 52 | 53 | [project.scripts] 54 | 55 | [tool.setuptools] 56 | zip-safe = false 57 | include-package-data = true 58 | package-dir = {"" = "src"} 59 | 60 | [tool.setuptools.packages.find] 61 | where = ["src"] 62 | namespaces = false 63 | 64 | [tool.setuptools.package-data] 65 | "*" = ["*.txt", "*.rst", "*.typed"] 66 | 67 | [tool.setuptools.dynamic] 68 | version = {attr = "whyhow_rbr.__version__"} 69 | 70 | [tool.pydocstyle] 71 | convention = "numpy" 72 | add-ignore = "D301" 73 | 74 | [tool.bandit] 75 | 76 | [tool.black] 77 | line-length = 79 78 | preview = true 79 | 80 | [tool.isort] 81 | profile = "black" 82 | line_length = 79 83 | 84 | [tool.mypy] 85 | plugins = [ 86 | "pydantic.mypy" 87 | ] 88 | python_version = "3.10" 89 | ignore_missing_imports = true 90 | no_implicit_optional = true 91 | check_untyped_defs = true 
92 | strict_equality = true 93 | warn_redundant_casts = true 94 | warn_unused_ignores = true 95 | show_error_codes = true 96 | disallow_any_generics = true 97 | disallow_incomplete_defs = true 98 | disallow_untyped_defs = true 99 | 100 | [tool.pydantic-mypy] 101 | init_forbid_extra = true 102 | init_typed = true 103 | warn_required_dynamic_aliases = true 104 | 105 | [tool.pytest.ini_options] 106 | filterwarnings = [ 107 | "error" 108 | ] 109 | testpaths = [ 110 | "tests", 111 | ] 112 | addopts = "--cov=src/ -v --cov-report=term-missing --durations=20" 113 | log_cli = false 114 | -------------------------------------------------------------------------------- /src/DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whyhow-ai/rule-based-retrieval/91701f45822823d6c54cac3b526e43cdb409e4e3/src/DS_Store -------------------------------------------------------------------------------- /src/whyhow_rbr/__init__.py: -------------------------------------------------------------------------------- 1 | """SDK.""" 2 | 3 | from whyhow_rbr.exceptions import ( 4 | IndexAlreadyExistsException, 5 | IndexNotFoundException, 6 | ) 7 | from whyhow_rbr.rag import Client, Rule 8 | 9 | __version__ = "v0.1.4" 10 | __all__ = [ 11 | "Client", 12 | "IndexAlreadyExistsException", 13 | "IndexNotFoundException", 14 | "Rule", 15 | ] 16 | -------------------------------------------------------------------------------- /src/whyhow_rbr/embedding.py: -------------------------------------------------------------------------------- 1 | """Collection of utilities for working with embeddings.""" 2 | 3 | from langchain_openai import OpenAIEmbeddings 4 | 5 | 6 | def generate_embeddings( 7 | openai_api_key: str, 8 | chunks: list[str], 9 | model: str = "text-embedding-3-small", 10 | ) -> list[list[float]]: 11 | """Generate embeddings for a list of chunks. 12 | 13 | Parameters 14 | ---------- 15 | openai_api_key : str 16 | OpenAI API key. 17 | 18 | chunks : list[str] 19 | List of chunks to generate embeddings for. 20 | 21 | model : str 22 | OpenAI model to use for generating embeddings. 23 | 24 | Returns 25 | ------- 26 | list[list[float]] 27 | List of embeddings for each chunk. 
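    Examples
    --------
    A minimal usage sketch; the key shown is a placeholder and the call
    performs a real OpenAI request:

    >>> vectors = generate_embeddings(
    ...     openai_api_key="sk-...",
    ...     chunks=["first chunk of text", "second chunk of text"],
    ... )
    >>> len(vectors)  # one embedding per chunk
    2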
28 | 29 | """ 30 | embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key, model=model) # type: ignore 31 | embeddings_array = embeddings.embed_documents(chunks) 32 | 33 | return embeddings_array 34 | -------------------------------------------------------------------------------- /src/whyhow_rbr/exceptions.py: -------------------------------------------------------------------------------- 1 | """Collection of all custom exceptions for the package.""" 2 | 3 | 4 | class IndexAlreadyExistsException(Exception): 5 | """Raised when the index already exists.""" 6 | 7 | pass 8 | 9 | 10 | class IndexNotFoundException(Exception): 11 | """Raised when the index is not found.""" 12 | 13 | pass 14 | 15 | 16 | class OpenAIException(Exception): 17 | """Raised when the OpenAI API returns an error.""" 18 | 19 | pass 20 | 21 | 22 | class CollectionNotFoundException(Exception): 23 | """Raised when the Collection is not found.""" 24 | 25 | pass 26 | 27 | 28 | class CollectionAlreadyExistsException(Exception): 29 | """Raised when the collection already exists.""" 30 | 31 | pass 32 | 33 | 34 | class SchemaCreateFailureException(Exception): 35 | """Raised when fail to create a new schema.""" 36 | 37 | pass 38 | 39 | 40 | class CollectionCreateFailureException(Exception): 41 | """Raised when fail to create a new collection.""" 42 | 43 | pass 44 | 45 | 46 | class AddSchemaFieldFailureException(Exception): 47 | """Raised when fail to add a field to schema.""" 48 | 49 | pass 50 | 51 | 52 | class PartitionCreateFailureException(Exception): 53 | """Raised when fail to create a partition.""" 54 | 55 | pass 56 | 57 | 58 | class PartitionDropFailureException(Exception): 59 | """Raised when fail to drop a partition.""" 60 | 61 | pass 62 | 63 | 64 | class PartitionListFailureException(Exception): 65 | """Raised when fail to list all partitions.""" 66 | 67 | pass 68 | -------------------------------------------------------------------------------- /src/whyhow_rbr/processing.py: -------------------------------------------------------------------------------- 1 | """Collection of utilities for extracting and processing text.""" 2 | 3 | import copy 4 | import pathlib 5 | import re 6 | 7 | from langchain_community.document_loaders import PyPDFLoader 8 | from langchain_core.documents import Document 9 | from langchain_text_splitters import RecursiveCharacterTextSplitter 10 | 11 | 12 | def parse_and_split( 13 | path: str | pathlib.Path, 14 | chunk_size: int = 512, 15 | chunk_overlap: int = 100, 16 | ) -> list[Document]: 17 | """Parse a PDF and split it into chunks. 18 | 19 | Parameters 20 | ---------- 21 | path : str or pathlib.Path 22 | Path to the document to process. 23 | 24 | chunk_size : int 25 | Size of the chunks. 26 | 27 | chunk_overlap : int 28 | Overlap between chunks. 29 | 30 | Returns 31 | ------- 32 | list[Document] 33 | The chunks of the pdf. 34 | """ 35 | loader = PyPDFLoader(str(path)) 36 | docs = loader.load() 37 | splitter = RecursiveCharacterTextSplitter( 38 | chunk_size=chunk_size, 39 | chunk_overlap=chunk_overlap, 40 | ) 41 | chunks = splitter.split_documents(docs) 42 | 43 | # Assign the change number (within a page) to each chunk 44 | i_page = 0 45 | i_chunk = 0 46 | 47 | for chunk in chunks: 48 | if chunk.metadata["page"] != i_page: 49 | i_page = chunk.metadata["page"] 50 | i_chunk = 0 51 | 52 | chunk.metadata["chunk"] = i_chunk 53 | i_chunk += 1 54 | 55 | return chunks 56 | 57 | 58 | def clean_chunks( 59 | chunks: list[Document], 60 | ) -> list[Document]: 61 | """Clean the chunks of a pdf. 
62 | 63 | No modifications in-place. 64 | 65 | Parameters 66 | ---------- 67 | chunks : list[Document] 68 | The chunks of the pdf. 69 | 70 | Returns 71 | ------- 72 | list[Document] 73 | The cleaned chunks. 74 | """ 75 | pattern = re.compile(r"(\r\n|\n|\r)") 76 | clean_chunks: list[Document] = [] 77 | 78 | for chunk in chunks: 79 | text = re.sub(pattern, "", chunk.page_content) 80 | new_chunk = Document( 81 | page_content=text, 82 | metadata=copy.deepcopy(chunk.metadata), 83 | ) 84 | 85 | clean_chunks.append(new_chunk) 86 | 87 | return clean_chunks 88 | -------------------------------------------------------------------------------- /src/whyhow_rbr/rag.py: -------------------------------------------------------------------------------- 1 | """Retrieval augmented generation logic.""" 2 | 3 | import logging 4 | import os 5 | import pathlib 6 | import re 7 | import uuid 8 | from typing import Any, Literal, TypedDict, cast 9 | 10 | from langchain_core.documents import Document 11 | from openai import OpenAI 12 | from pinecone import ( 13 | Index, 14 | NotFoundException, 15 | Pinecone, 16 | PodSpec, 17 | ServerlessSpec, 18 | ) 19 | from pydantic import ( 20 | BaseModel, 21 | Field, 22 | ValidationError, 23 | field_validator, 24 | model_validator, 25 | ) 26 | 27 | from whyhow_rbr.embedding import generate_embeddings 28 | from whyhow_rbr.exceptions import ( 29 | IndexAlreadyExistsException, 30 | IndexNotFoundException, 31 | OpenAIException, 32 | ) 33 | from whyhow_rbr.processing import clean_chunks, parse_and_split 34 | 35 | logger = logging.getLogger(__name__) 36 | 37 | 38 | # Defaults 39 | DEFAULT_SPEC = ServerlessSpec(cloud="aws", region="us-west-2") 40 | 41 | 42 | # Custom classes 43 | class PineconeMetadata(BaseModel, extra="forbid"): 44 | """The metadata to be stored in Pinecone. 45 | 46 | Attributes 47 | ---------- 48 | text : str 49 | The text of the document. 50 | 51 | page_number : int 52 | The page number of the document. 53 | 54 | chunk_number : int 55 | The chunk number of the document. 56 | 57 | filename : str 58 | The filename of the document. 59 | 60 | uuid : str 61 | The UUID of the document. Note that this is not required to be 62 | provided when creating the metadata. It is generated automatically 63 | when creating the PineconeDocument. 64 | """ 65 | 66 | text: str 67 | page_number: int 68 | chunk_number: int 69 | filename: str 70 | uuid: str = Field(default_factory=lambda: str(uuid.uuid4())) 71 | 72 | 73 | class PineconeDocument(BaseModel, extra="forbid"): 74 | """The actual document to be stored in Pinecone. 75 | 76 | Attributes 77 | ---------- 78 | metadata : PineconeMetadata 79 | The metadata of the document. 80 | 81 | values : list[float] | None 82 | The embedding of the document. The None is used when querying 83 | the index since the values are not needed. At upsert time, the 84 | values are required. 85 | 86 | id : str | None 87 | The human-readable identifier of the document. This is generated 88 | automatically when creating the PineconeDocument unless it is 89 | provided. 
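    Examples
    --------
    A small sketch of the automatic id generation; the values are
    illustrative, not a real embedding:

    >>> doc = PineconeDocument(
    ...     metadata=PineconeMetadata(
    ...         text="Some chunk text",
    ...         page_number=3,
    ...         chunk_number=0,
    ...         filename="report.pdf",
    ...     ),
    ...     values=[0.1, 0.2, 0.3],
    ... )
    >>> doc.id
    'report.pdf-3-0'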
90 | 91 | """ 92 | 93 | metadata: PineconeMetadata 94 | values: list[float] | None = None 95 | id: str | None = None 96 | 97 | @model_validator(mode="after") 98 | def generate_human_readable_id(self) -> "PineconeDocument": 99 | """Generate a human-readable identifier for the document.""" 100 | if self.id is None: 101 | meta = self.metadata 102 | hr_id = f"{meta.filename}-{meta.page_number}-{meta.chunk_number}" 103 | self.id = hr_id 104 | 105 | return self 106 | 107 | 108 | class PineconeMatch(BaseModel, extra="ignore"): 109 | """The match returned from Pinecone. 110 | 111 | Attributes 112 | ---------- 113 | id : str 114 | The ID of the document. 115 | 116 | score : float 117 | The score of the match. Its meaning depends on the metric used for 118 | the index. 119 | 120 | metadata : PineconeMetadata 121 | The metadata of the document. 122 | 123 | """ 124 | 125 | id: str 126 | score: float 127 | metadata: PineconeMetadata 128 | 129 | 130 | class Rule(BaseModel): 131 | """Retrieval rule. 132 | 133 | The rule is used to filter the documents in the index. 134 | 135 | Attributes 136 | ---------- 137 | filename : str | None 138 | The filename of the document. 139 | 140 | uuid : str | None 141 | The UUID of the document. 142 | 143 | page_numbers : list[int] | None 144 | The page numbers of the document. 145 | 146 | keywords : list[str] | None 147 | The keywords to trigger a rule. 148 | """ 149 | 150 | filename: str | None = None 151 | uuid: str | None = None 152 | page_numbers: list[int] | None = None 153 | keywords: list[str] | None = None 154 | 155 | @field_validator("page_numbers", mode="before") 156 | @classmethod 157 | def convert_empty_to_none(cls, v: list[int] | None) -> list[int] | None: 158 | """Convert empty list to None.""" 159 | if v is not None and not v: 160 | return None 161 | return v 162 | 163 | def convert_empty_str_to_none( 164 | cls, s: list[str] | None 165 | ) -> list[str] | None: 166 | """Convert empty string list to None.""" 167 | if s is not None and not s: 168 | return None 169 | return s 170 | 171 | def to_filter(self) -> dict[str, list[dict[str, Any]]] | None: 172 | """Convert rule to Pinecone filter format.""" 173 | if not any([self.filename, self.uuid, self.page_numbers]): 174 | return None 175 | 176 | conditions: list[dict[str, Any]] = [] 177 | if self.filename is not None: 178 | conditions.append({"filename": {"$eq": self.filename}}) 179 | if self.uuid is not None: 180 | conditions.append({"uuid": {"$eq": self.uuid}}) 181 | if self.page_numbers is not None: 182 | conditions.append({"page_number": {"$in": self.page_numbers}}) 183 | 184 | filter_ = {"$and": conditions} 185 | return filter_ 186 | 187 | 188 | class Input(BaseModel): 189 | """Example input for the prompt. 190 | 191 | Attributes 192 | ---------- 193 | question : str 194 | The question to ask. 195 | 196 | contexts : list[str] 197 | The contexts to use for answering the question. 198 | """ 199 | 200 | question: str 201 | contexts: list[str] 202 | 203 | 204 | class Output(BaseModel): 205 | """Example output for the prompt. 206 | 207 | Attributes 208 | ---------- 209 | answer : str 210 | The answer to the question. 211 | 212 | contexts : list[int] 213 | The indices of the contexts that were used to answer the question. 
214 | """ 215 | 216 | answer: str 217 | contexts: list[int] 218 | 219 | 220 | input_example_1 = Input( 221 | question="What is the capital of France?", 222 | contexts=[ 223 | "The capital of France is Paris.", 224 | "The capital of France is not London.", 225 | "Paris is beautiful and it is also the capital of France.", 226 | ], 227 | ) 228 | output_example_1 = Output(answer="Paris", contexts=[0, 2]) 229 | 230 | input_example_2 = Input( 231 | question="What are the impacts of climate change on global agriculture?", 232 | contexts=[ 233 | "Climate change can lead to more extreme weather patterns, affecting crop yields.", 234 | "Rising sea levels due to climate change can inundate agricultural lands in coastal areas, reducing arable land.", 235 | "Changes in temperature and precipitation patterns can shift agricultural zones, impacting food security.", 236 | ], 237 | ) 238 | 239 | output_example_2 = Output( 240 | answer="Variable impacts including altered weather patterns, reduced arable land, shifting agricultural zones, increased pests and diseases, with potential mitigation through technology and sustainable practices", 241 | contexts=[0, 1, 2], 242 | ) 243 | 244 | input_example_3 = Input( 245 | question="How has the concept of privacy evolved with the advent of digital technology?", 246 | contexts=[ 247 | "Digital technology has made it easier to collect, store, and analyze personal data, raising privacy concerns.", 248 | "Social media platforms and smartphones often track user activity and preferences, leading to debates over consent and data ownership.", 249 | "Encryption and secure communication technologies have evolved as means to protect privacy in the digital age.", 250 | "Legislation like the GDPR in the EU has been developed to address privacy concerns and regulate data handling by companies.", 251 | "The concept of privacy is increasingly being viewed through the lens of digital rights and cybersecurity.", 252 | ], 253 | ) 254 | 255 | output_example_3 = Output( 256 | answer="Evolving with challenges due to data collection and analysis, changes in legislation, and advancements in encryption and security, amidst ongoing debates over consent and data ownership", 257 | contexts=[0, 1, 2, 3, 4], 258 | ) 259 | 260 | # Custom types 261 | Metric = Literal["cosine", "euclidean", "dotproduct"] 262 | 263 | 264 | class QueryReturnType(TypedDict): 265 | """The return type of the query method. 266 | 267 | Attributes 268 | ---------- 269 | answer : str 270 | The answer to the question. 271 | 272 | matches : list[dict[str, Any]] 273 | The retrieved documents from the index. 274 | 275 | used_contexts : list[int] 276 | The indices of the matches that were actually used to answer the question. 277 | """ 278 | 279 | answer: str 280 | matches: list[dict[str, Any]] 281 | used_contexts: list[int] 282 | 283 | 284 | PROMPT_START = f"""\ 285 | You are a helpful assistant. I will give you a question and provide multiple 286 | context documents. You will need to answer the question based on the contexts 287 | and also specify in which context(s) you found the answer. 288 | If you don't find the answer in the context, you can use your own knowledge, however, 289 | in that case, the contexts array should be empty. 290 | 291 | Both the input and the output are JSON objects. 
292 | 293 | # EXAMPLE INPUT 294 | # ```json 295 | # {input_example_1.model_dump_json()} 296 | # ``` 297 | 298 | # EXAMPLE OUTPUT 299 | # ```json 300 | # {output_example_1.model_dump_json()} 301 | 302 | # EXAMPLE INPUT 303 | # ```json 304 | # {input_example_2.model_dump_json()} 305 | # ``` 306 | 307 | # EXAMPLE OUTPUT 308 | # ```json 309 | # {output_example_2.model_dump_json()} 310 | 311 | # EXAMPLE INPUT 312 | # ```json 313 | # {input_example_3.model_dump_json()} 314 | # ``` 315 | 316 | # EXAMPLE OUTPUT 317 | # ```json 318 | # {output_example_3.model_dump_json()} 319 | 320 | """ 321 | 322 | 323 | class Client: 324 | """Synchronous client.""" 325 | 326 | def __init__( 327 | self, 328 | openai_api_key: str | None = None, 329 | pinecone_api_key: str | None = None, 330 | ): 331 | if openai_api_key is None: 332 | openai_api_key = os.environ.get("OPENAI_API_KEY") 333 | if openai_api_key is None: 334 | raise ValueError( 335 | "No OPENAI_API_KEY provided must be provided." 336 | ) 337 | 338 | if pinecone_api_key is None: 339 | pinecone_api_key = os.environ.get("PINECONE_API_KEY") 340 | if pinecone_api_key is None: 341 | raise ValueError("No PINECONE_API_KEY provided") 342 | 343 | self.openai_client = OpenAI(api_key=openai_api_key) 344 | self.pinecone_client = Pinecone(api_key=pinecone_api_key) 345 | 346 | def get_index(self, name: str) -> Index: 347 | """Get an existing index. 348 | 349 | Parameters 350 | ---------- 351 | name : str 352 | The name of the index. 353 | 354 | 355 | Returns 356 | ------- 357 | Index 358 | The index. 359 | 360 | Raises 361 | ------ 362 | IndexNotFoundException 363 | If the index does not exist. 364 | 365 | """ 366 | try: 367 | index = self.pinecone_client.Index(name) 368 | except NotFoundException as e: 369 | raise IndexNotFoundException(f"Index {name} does not exist") from e 370 | 371 | return index 372 | 373 | def create_index( 374 | self, 375 | name: str, 376 | dimension: int = 1536, 377 | metric: Metric = "cosine", 378 | spec: ServerlessSpec | PodSpec | None = None, 379 | ) -> Index: 380 | """Create a new index. 381 | 382 | If the index does not exist, it creates a new index with the specified. 383 | 384 | Parameters 385 | ---------- 386 | name : str 387 | The name of the index. 388 | 389 | dimension : int 390 | The dimension of the index. 391 | 392 | metric : Metric 393 | The metric of the index. 394 | 395 | spec : ServerlessSpec | PodSpec | None 396 | The spec of the index. If None, it uses the default spec. 397 | 398 | Raises 399 | ------ 400 | IndexAlreadyExistsException 401 | If the index already exists. 402 | 403 | """ 404 | try: 405 | self.get_index(name) 406 | except IndexNotFoundException: 407 | pass 408 | else: 409 | raise IndexAlreadyExistsException(f"Index {name} already exists") 410 | 411 | if spec is None: 412 | spec = DEFAULT_SPEC 413 | logger.info(f"Using default spec {spec}") 414 | 415 | self.pinecone_client.create_index( 416 | name=name, dimension=dimension, metric=metric, spec=spec 417 | ) 418 | index = self.pinecone_client.Index(name) 419 | 420 | return index 421 | 422 | def upload_documents( 423 | self, 424 | index: Index, 425 | documents: list[str | pathlib.Path], 426 | namespace: str, 427 | embedding_model: str = "text-embedding-3-small", 428 | batch_size: int = 100, 429 | ) -> None: 430 | """Upload documents to the index. 431 | 432 | Parameters 433 | ---------- 434 | index : Index 435 | The index. 436 | 437 | documents : list[str | pathlib.Path] 438 | The documents to upload. 
439 | 440 | namespace : str 441 | The namespace within the index to use. 442 | 443 | batch_size : int 444 | The number of documents to upload at a time. 445 | 446 | embedding_model : str 447 | The OpenAI embedding model to use. 448 | 449 | """ 450 | # don't allow for duplicate documents 451 | documents = list(set(documents)) 452 | if not documents: 453 | logger.info("No documents to upload") 454 | return 455 | 456 | logger.info(f"Parsing {len(documents)} documents") 457 | all_chunks: list[Document] = [] 458 | for document in documents: 459 | chunks_ = parse_and_split(document) 460 | chunks = clean_chunks(chunks_) 461 | all_chunks.extend(chunks) 462 | 463 | logger.info(f"Embedding {len(all_chunks)} chunks") 464 | embeddings = generate_embeddings( 465 | openai_api_key=self.openai_client.api_key, 466 | chunks=[c.page_content for c in all_chunks], 467 | model=embedding_model, 468 | ) 469 | 470 | if len(embeddings) != len(all_chunks): 471 | raise ValueError( 472 | "Number of embeddings does not match number of chunks" 473 | ) 474 | 475 | # create PineconeDocuments 476 | pinecone_documents = [] 477 | for i, (chunk, embedding) in enumerate(zip(all_chunks, embeddings)): 478 | metadata = PineconeMetadata( 479 | text=chunk.page_content, 480 | page_number=chunk.metadata["page"], 481 | chunk_number=chunk.metadata["chunk"], 482 | filename=chunk.metadata["source"], 483 | ) 484 | pinecone_document = PineconeDocument( 485 | values=embedding, 486 | metadata=metadata, 487 | ) 488 | pinecone_documents.append(pinecone_document) 489 | 490 | upsert_documents = [d.model_dump() for d in pinecone_documents] 491 | 492 | response = index.upsert( 493 | upsert_documents, namespace=namespace, batch_size=batch_size 494 | ) 495 | n_upserted = response["upserted_count"] 496 | logger.info(f"Upserted {n_upserted} documents") 497 | 498 | def clean_text(self, text: str) -> str: 499 | """Return a lower case version of text with punctuation removed. 500 | 501 | Parameters 502 | ---------- 503 | text : str 504 | The raw text to be cleaned. 505 | 506 | Returns 507 | ------- 508 | str: The cleaned text string. 509 | """ 510 | text_processed = re.sub("[^0-9a-zA-Z ]+", "", text.lower()) 511 | text_processed_further = re.sub(" +", " ", text_processed) 512 | return text_processed_further 513 | 514 | def query( 515 | self, 516 | question: str, 517 | index: Index, 518 | namespace: str, 519 | rules: list[Rule] | None = None, 520 | top_k: int = 5, 521 | chat_model: str = "gpt-4-1106-preview", 522 | chat_temperature: float = 0.0, 523 | chat_max_tokens: int = 1000, 524 | chat_seed: int = 2, 525 | embedding_model: str = "text-embedding-3-small", 526 | process_rules_separately: bool = False, 527 | keyword_trigger: bool = False, 528 | ) -> QueryReturnType: 529 | """Query the index. 530 | 531 | Parameters 532 | ---------- 533 | question : str 534 | The question to ask. 535 | 536 | index : Index 537 | The index to query. 538 | 539 | namespace : str 540 | The namespace within the index to use. 541 | 542 | rules : list[Rule] | None 543 | The rules to use for filtering the documents. 544 | 545 | top_k : int 546 | The number of matches to return per rule. 547 | 548 | chat_model : str 549 | The OpenAI chat model to use. 550 | 551 | chat_temperature : float 552 | The temperature for the chat model. 553 | 554 | chat_max_tokens : int 555 | The maximum number of tokens for the chat model. 556 | 557 | chat_seed : int 558 | The seed for the chat model. 559 | 560 | embedding_model : str 561 | The OpenAI embedding model to use. 
562 | 563 | process_rules_separately : bool, optional 564 | Whether to process each rule individually and combine the results at the end. 565 | When set to True, each rule will be run independently, ensuring that every rule 566 | returns results. When set to False (default), all rules will be run as one joined 567 | query, potentially allowing one rule to dominate the others. 568 | Default is False. 569 | 570 | keyword_trigger : bool, optional 571 | Whether to trigger rules based on keyword matches in the question. 572 | Default is False. 573 | 574 | Returns 575 | ------- 576 | QueryReturnType 577 | Dictionary with keys "answer", "matches", and "used_contexts". 578 | The "answer" is the answer to the question. 579 | The "matches" are the "top_k" matches from the index. 580 | The "used_contexts" are the indices of the matches 581 | that were actually used to answer the question. 582 | 583 | Raises 584 | ------ 585 | OpenAIException 586 | If there is an error with the OpenAI API. Some possible reasons 587 | include the chat model not finishing or the response not being 588 | valid JSON. 589 | """ 590 | logger.info(f"Raw rules: {rules}") 591 | 592 | if rules is None: 593 | rules = [] 594 | 595 | if keyword_trigger: 596 | triggered_rules = [] 597 | clean_question = self.clean_text(question).split(" ") 598 | 599 | for rule in rules: 600 | if rule.keywords: 601 | clean_keywords = [ 602 | self.clean_text(keyword) for keyword in rule.keywords 603 | ] 604 | 605 | if bool(set(clean_keywords) & set(clean_question)): 606 | triggered_rules.append(rule) 607 | 608 | rules = triggered_rules 609 | 610 | rule_filters = [rule.to_filter() for rule in rules if rule is not None] 611 | 612 | question_embedding = generate_embeddings( 613 | openai_api_key=self.openai_client.api_key, 614 | chunks=[question], 615 | model=embedding_model, 616 | )[0] 617 | 618 | matches = ( 619 | [] 620 | ) # Initialize matches outside the loop to collect matches from all queries 621 | match_texts = [] 622 | 623 | # Check if there are any rule filters, and if not, proceed with a default query 624 | if not rule_filters: 625 | # Perform a default query 626 | query_response = index.query( 627 | namespace=namespace, 628 | top_k=top_k, 629 | vector=question_embedding, 630 | filter=None, # No specific filter, or you can define a default filter as per your application's logic 631 | include_metadata=True, 632 | ) 633 | matches = [ 634 | PineconeMatch(**m.to_dict()) for m in query_response["matches"] 635 | ] 636 | match_texts = [m.metadata.text for m in matches] 637 | 638 | else: 639 | 640 | if process_rules_separately: 641 | for rule_filter in rule_filters: 642 | if rule_filter: 643 | query_response = index.query( 644 | namespace=namespace, 645 | top_k=top_k, 646 | vector=question_embedding, 647 | filter=rule_filter, 648 | include_metadata=True, 649 | ) 650 | matches.extend( 651 | [ 652 | PineconeMatch(**m.to_dict()) 653 | for m in query_response["matches"] 654 | ] 655 | ) 656 | match_texts += [m.metadata.text for m in matches] 657 | match_texts = list( 658 | set(match_texts) 659 | ) # Ensure unique match texts 660 | else: 661 | if rule_filters: 662 | combined_filters = [] 663 | for rule_filter in rule_filters: 664 | if rule_filter: 665 | combined_filters.append(rule_filter) 666 | 667 | rule_filter = ( 668 | {"$or": combined_filters} if combined_filters else None 669 | ) 670 | else: 671 | rule_filter = None # Fallback to a default query when no rules are provided or valid 672 | 673 | if rule_filter is not None: 674 | query_response = 
index.query( 675 | namespace=namespace, 676 | top_k=top_k, 677 | vector=question_embedding, 678 | filter=rule_filter, 679 | include_metadata=True, 680 | ) 681 | matches = [ 682 | PineconeMatch(**m.to_dict()) 683 | for m in query_response["matches"] 684 | ] 685 | match_texts = [m.metadata.text for m in matches] 686 | 687 | # Proceed to create prompt, send it to OpenAI, and handle the response 688 | prompt = self.create_prompt(question, match_texts) 689 | response = self.openai_client.chat.completions.create( 690 | model=chat_model, 691 | seed=chat_seed, 692 | temperature=chat_temperature, 693 | messages=[{"role": "user", "content": prompt}], 694 | max_tokens=chat_max_tokens, 695 | ) 696 | 697 | output = self.process_response(response) 698 | 699 | return_dict: QueryReturnType = { 700 | "answer": output.answer, 701 | "matches": [m.model_dump() for m in matches], 702 | "used_contexts": output.contexts, 703 | } 704 | 705 | return return_dict 706 | 707 | def create_prompt(self, question: str, match_texts: list[str]) -> str: 708 | """Create the prompt for the OpenAI chat completion. 709 | 710 | Parameters 711 | ---------- 712 | question : str 713 | The question to ask. 714 | 715 | match_texts : list[str] 716 | The list of context strings to include in the prompt. 717 | 718 | Returns 719 | ------- 720 | str 721 | The generated prompt. 722 | """ 723 | input_actual = Input(question=question, contexts=match_texts) 724 | prompt_end = f""" 725 | ACTUAL INPUT 726 | ```json 727 | {input_actual.model_dump_json()} 728 | ``` 729 | 730 | ACTUAL OUTPUT 731 | """ 732 | return f"{PROMPT_START}\n{prompt_end}" 733 | 734 | def process_response(self, response: Any) -> Output: 735 | """Process the OpenAI chat completion response. 736 | 737 | Parameters 738 | ---------- 739 | response : Any 740 | The OpenAI chat completion response. 741 | 742 | Returns 743 | ------- 744 | Output 745 | The processed output. 746 | 747 | Raises 748 | ------ 749 | OpenAIException 750 | If the chat model did not finish or the response is not valid JSON. 751 | """ 752 | choice = response.choices[0] 753 | if choice.finish_reason != "stop": 754 | raise OpenAIException( 755 | f"Chat did not finish. 
Reason: {choice.finish_reason}" 756 | ) 757 | 758 | response_raw = cast(str, response.choices[0].message.content) 759 | 760 | if response_raw.startswith("```json"): 761 | start_i = response_raw.index("{") 762 | end_i = response_raw.rindex("}") 763 | response_raw = response_raw[start_i : end_i + 1] 764 | 765 | try: 766 | output = Output.model_validate_json(response_raw) 767 | except ValidationError as e: 768 | raise OpenAIException( 769 | f"OpenAI did not return a valid JSON: {response_raw}" 770 | ) from e 771 | 772 | return output 773 | -------------------------------------------------------------------------------- /src/whyhow_rbr/rag_milvus.py: -------------------------------------------------------------------------------- 1 | """Retrieval augmented generation logic.""" 2 | 3 | import logging 4 | import os 5 | import pathlib 6 | import re 7 | from typing import Any, Dict, List, Optional, TypedDict, cast 8 | 9 | from langchain_core.documents import Document 10 | from openai import OpenAI 11 | from pydantic import BaseModel, ValidationError 12 | from pymilvus import CollectionSchema, DataType, MilvusClient, MilvusException 13 | from pymilvus.milvus_client import IndexParams 14 | 15 | from whyhow_rbr.embedding import generate_embeddings 16 | from whyhow_rbr.exceptions import ( 17 | AddSchemaFieldFailureException, 18 | CollectionAlreadyExistsException, 19 | CollectionCreateFailureException, 20 | CollectionNotFoundException, 21 | OpenAIException, 22 | PartitionCreateFailureException, 23 | PartitionDropFailureException, 24 | PartitionListFailureException, 25 | SchemaCreateFailureException, 26 | ) 27 | from whyhow_rbr.processing import clean_chunks, parse_and_split 28 | 29 | logger = logging.getLogger(__name__) 30 | 31 | 32 | class MilvusMetadata(BaseModel): 33 | """The metadata to be stored in Milvus. 34 | 35 | Attributes 36 | ---------- 37 | text : str 38 | The text of the document. 39 | 40 | page_number : int 41 | The page number of the document. 42 | 43 | chunk_number : int 44 | The chunk number of the document. 45 | 46 | filename : str 47 | The filename of the document. 48 | """ 49 | 50 | text: str 51 | page_number: int 52 | chunk_number: int 53 | filename: str 54 | vector: List[float] 55 | 56 | 57 | """Custom classes for constructing prompt, output and query result with examples""" 58 | 59 | 60 | class Input(BaseModel): 61 | """Example input for the prompt. 62 | 63 | Attributes 64 | ---------- 65 | question : str 66 | The question to ask. 67 | 68 | contexts : list[str] 69 | The contexts to use for answering the question. 70 | """ 71 | 72 | question: str 73 | contexts: list[str] 74 | 75 | 76 | class Output(BaseModel): 77 | """Example output for the prompt. 78 | 79 | Attributes 80 | ---------- 81 | answer : str 82 | The answer to the question. 83 | 84 | contexts : list[int] 85 | The indices of the contexts that were used to answer the question. 
86 | """ 87 | 88 | answer: str 89 | contexts: list[int] 90 | 91 | 92 | input_example_1 = Input( 93 | question="What is the capital of France?", 94 | contexts=[ 95 | "The capital of France is Paris.", 96 | "The capital of France is not London.", 97 | "Paris is beautiful and it is also the capital of France.", 98 | ], 99 | ) 100 | output_example_1 = Output(answer="Paris", contexts=[0, 2]) 101 | 102 | input_example_2 = Input( 103 | question="What are the impacts of climate change on global agriculture?", 104 | contexts=[ 105 | "Climate change can lead to more extreme weather patterns, affecting crop yields.", 106 | "Rising sea levels due to climate change can inundate agricultural lands in coastal areas, reducing arable land.", 107 | "Changes in temperature and precipitation patterns can shift agricultural zones, impacting food security.", 108 | ], 109 | ) 110 | 111 | output_example_2 = Output( 112 | answer="Variable impacts including altered weather patterns, reduced arable land, shifting agricultural zones, increased pests and diseases, with potential mitigation through technology and sustainable practices", 113 | contexts=[0, 1, 2], 114 | ) 115 | 116 | input_example_3 = Input( 117 | question="How has the concept of privacy evolved with the advent of digital technology?", 118 | contexts=[ 119 | "Digital technology has made it easier to collect, store, and analyze personal data, raising privacy concerns.", 120 | "Social media platforms and smartphones often track user activity and preferences, leading to debates over consent and data ownership.", 121 | "Encryption and secure communication technologies have evolved as means to protect privacy in the digital age.", 122 | "Legislation like the GDPR in the EU has been developed to address privacy concerns and regulate data handling by companies.", 123 | "The concept of privacy is increasingly being viewed through the lens of digital rights and cybersecurity.", 124 | ], 125 | ) 126 | 127 | output_example_3 = Output( 128 | answer="Evolving with challenges due to data collection and analysis, changes in legislation, and advancements in encryption and security, amidst ongoing debates over consent and data ownership", 129 | contexts=[0, 1, 2, 3, 4], 130 | ) 131 | 132 | 133 | class QueryReturnType(TypedDict): 134 | """The return type of the query method. 135 | 136 | Attributes 137 | ---------- 138 | answer : str 139 | The answer to the question. 140 | 141 | matches : List[dict] 142 | The retrieved documents from the collection. 143 | 144 | used_contexts : list[int] 145 | The indices of the matches that were actually used to answer the question. 146 | """ 147 | 148 | answer: str 149 | matches: List[Dict[Any, Any]] 150 | used_contexts: list[int] 151 | 152 | 153 | PROMPT_START = f"""\ 154 | You are a helpful assistant. I will give you a question and provide multiple 155 | context documents. You will need to answer the question based on the contexts 156 | and also specify in which context(s) you found the answer. 157 | If you don't find the answer in the context, you can use your own knowledge, however, 158 | in that case, the contexts array should be empty. 159 | 160 | Both the input and the output are JSON objects. 
161 | 162 | # EXAMPLE INPUT 163 | # ```json 164 | # {input_example_1.model_dump_json()} 165 | # ``` 166 | 167 | # EXAMPLE OUTPUT 168 | # ```json 169 | # {output_example_1.model_dump_json()} 170 | 171 | # EXAMPLE INPUT 172 | # ```json 173 | # {input_example_2.model_dump_json()} 174 | # ``` 175 | 176 | # EXAMPLE OUTPUT 177 | # ```json 178 | # {output_example_2.model_dump_json()} 179 | 180 | # EXAMPLE INPUT 181 | # ```json 182 | # {input_example_3.model_dump_json()} 183 | # ``` 184 | 185 | # EXAMPLE OUTPUT 186 | # ```json 187 | # {output_example_3.model_dump_json()} 188 | 189 | """ 190 | 191 | 192 | """Implementing RAG by Milvus""" 193 | 194 | 195 | class ClientMilvus: 196 | """Synchronous client.""" 197 | 198 | def __init__( 199 | self, 200 | milvus_uri: str, 201 | milvus_token: str, 202 | milvus_db_name: Optional[str] = None, 203 | timeout: float | None = None, 204 | openai_api_key: str | None = None, 205 | ): 206 | if openai_api_key is None: 207 | openai_api_key = os.environ.get("OPENAI_API_KEY") 208 | if openai_api_key is None: 209 | raise ValueError( 210 | "No OPENAI_API_KEY provided must be provided." 211 | ) 212 | 213 | self.openai_client = OpenAI(api_key=openai_api_key) 214 | self.milvus_client = MilvusClient( 215 | uri=milvus_uri, 216 | token=milvus_token, 217 | db_name=milvus_db_name, 218 | timeout=timeout, 219 | ) 220 | 221 | def get_collection_stats( 222 | self, collection_name: str, timeout: Optional[float] = None 223 | ) -> Dict[str, Any]: 224 | """Get an existing collection. 225 | 226 | Parameters 227 | ---------- 228 | collection_name : str 229 | The name of the collection. 230 | 231 | timeout : Optional[float] 232 | The timeout duration for this operation. 233 | Setting this to None indicates that this operation timeouts when any response returns or error occurs. 234 | 235 | Returns 236 | ------- 237 | Dict 238 | A dictionary that contains detailed information about the specified collection. 239 | 240 | Raises 241 | ------ 242 | CollectionNotFoundException 243 | If the collection does not exist. 244 | """ 245 | try: 246 | collection_stats = self.milvus_client.describe_collection( 247 | collection_name, timeout 248 | ) 249 | except MilvusException as e: 250 | raise CollectionNotFoundException( 251 | f"Collection {collection_name} does not exist" 252 | ) from e 253 | 254 | return collection_stats 255 | 256 | def create_schema( 257 | self, 258 | auto_id: bool = False, 259 | enable_dynamic_field: bool = True, 260 | **kwargs: Any, 261 | ) -> CollectionSchema: 262 | """Create a schema to add in collection. 263 | 264 | Parameters 265 | ---------- 266 | auto_id : bool 267 | Whether allows the primary field to automatically increment. 268 | 269 | enable_dynamic_field : bool 270 | Whether allows Milvus saves the values of undefined fields in a dynamic field 271 | if the data being inserted into the target collection includes fields that are not defined in the collection's schema. 272 | 273 | Returns 274 | ------- 275 | CollectionSchema 276 | A Schema instance represents the schema of a collection. 277 | 278 | Raises 279 | ------ 280 | SchemaCreateFailureException 281 | If schema create failure. 
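        Examples
        --------
        A minimal sketch; ``milvus_client`` is assumed to be an already
        initialised ``ClientMilvus`` instance:

        >>> schema = milvus_client.create_schema(auto_id=True)
        >>> schema = milvus_client.add_field(
        ...     schema=schema,
        ...     field_name="id",
        ...     datatype=DataType.INT64,
        ...     is_primary=True,
        ... )
        >>> schema = milvus_client.add_field(
        ...     schema=schema,
        ...     field_name="embedding",
        ...     datatype=DataType.FLOAT_VECTOR,
        ...     dim=1536,
        ... )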
282 | """ 283 | try: 284 | schema = MilvusClient.create_schema( 285 | auto_id=auto_id, 286 | enable_dynamic_field=enable_dynamic_field, 287 | **kwargs, 288 | ) 289 | except MilvusException as e: 290 | raise SchemaCreateFailureException("Schema create failure.") from e 291 | 292 | return schema 293 | 294 | def add_field( 295 | self, 296 | schema: CollectionSchema, 297 | field_name: str, 298 | datatype: DataType, 299 | is_primary: bool = False, 300 | **kwargs: Any, 301 | ) -> CollectionSchema: 302 | """Add Field to current schema. 303 | 304 | Parameters 305 | ---------- 306 | schema : CollectionSchema 307 | The exist schema object. 308 | 309 | field_name : str 310 | The name of the new field. 311 | 312 | datatype : DataType 313 | The data type of the field. 314 | You can choose from the following options when selecting a data type for different fields: 315 | 316 | Primary key field: Use DataType.INT64 or DataType.VARCHAR. 317 | 318 | Scalar fields: Choose from a variety of options, including: 319 | 320 | DataType.BOOL, DataType.INT8, DataType.INT16, DataType.INT32, DataType.INT64, 321 | DataType.FLOAT, DataType.DOUBLE, DataType.BINARY_VECTOR, DataType.FLOAT_VECTOR, 322 | DataType.FLOAT16_VECTOR, __DataType.BFLOAT16_VECTOR, DataType.VARCHAR, 323 | DataType.JSON, and DataType.ARRAY. 324 | 325 | Vector fields: Select DataType.BINARY_VECTOR or DataType.FLOAT_VECTOR. 326 | 327 | is_primary : bool 328 | Whether the current field is the primary field in a collection. 329 | **Each collection has only one primary field. 330 | 331 | **kwargs : Any 332 | 333 | max_length (int) - 334 | The maximum length of the field value. 335 | This is mandatory for a DataType.VARCHAR field. 336 | 337 | element_type (str) - 338 | The data type of the elements in the field value. 339 | This is mandatory for a DataType.Array field. 340 | 341 | max_capacity (int) - 342 | The number of elements in an Array field value. 343 | This is mandatory for a DataType.Array field. 344 | 345 | dim (int) - 346 | The dimension of the vector embeddings. 347 | This is mandatory for a DataType.FLOAT_VECTOR field or a DataType.BINARY_VECTOR field. 348 | 349 | Returns 350 | ------- 351 | CollectionSchema 352 | A Schema instance represents the schema of a collection. 353 | 354 | Raises 355 | ------ 356 | AddSchemaFieldFailureException 357 | If schema create failure. 358 | """ 359 | try: 360 | schema.add_field( 361 | field_name=field_name, 362 | datatype=datatype, 363 | is_primary=is_primary, 364 | **kwargs, 365 | ) 366 | except MilvusException as e: 367 | raise AddSchemaFieldFailureException( 368 | f"Fail to add {field_name} to current schema." 369 | ) from e 370 | 371 | return schema 372 | 373 | def prepare_index_params(self) -> IndexParams: 374 | """Prepare an index object.""" 375 | index_params = self.milvus_client.prepare_index_params() 376 | 377 | return index_params 378 | 379 | def add_index( 380 | self, 381 | index_params: IndexParams, 382 | field_name: str, 383 | index_type: str = "AUTOINDEX", 384 | index_name: Optional[str] = None, 385 | metric_type: str = "COSINE", 386 | params: Optional[Dict[str, Any]] = None, 387 | ) -> IndexParams: 388 | """Add an index to IndexParams Object. 389 | 390 | Parameters 391 | ---------- 392 | index_params : IndexParams 393 | index object 394 | 395 | field_name : str 396 | The name of the target file to apply this object applies. 397 | 398 | index_name : str 399 | The name of the index file generated after this object has been applied. 
400 | 401 |         index_type : str 402 |             The name of the algorithm used to arrange data in the specific field. 403 | 404 |         metric_type : str 405 |             The algorithm that is used to measure similarity between vectors. Possible values are IP, L2, and COSINE. 406 | 407 |         params : dict 408 |             The fine-tuning parameters for the specified index type. For details on possible keys and value ranges, refer to In-memory Index. 409 |         """ 410 |         index_params.add_index( 411 |             field_name=field_name, 412 |             index_type=index_type, 413 |             index_name=index_name, 414 |             metric_type=metric_type, 415 |             params=params, 416 |         ) 417 | 418 |         return index_params 419 | 420 |     def create_index( 421 |         self, 422 |         collection_name: str, 423 |         index_params: IndexParams, 424 |         timeout: Optional[float] = None, 425 |         **kwargs: Dict[str, Any], 426 |     ) -> None: 427 |         """Create an index. 428 | 429 |         Parameters 430 |         ---------- 431 |         index_params : IndexParams 432 |             The prepared IndexParams object. 433 | 434 |         collection_name : str 435 |             The name of the collection. 436 | 437 |         timeout : Optional[float] 438 |             The maximum duration to wait for the operation to complete before timing out. 439 |         """ 440 |         self.milvus_client.create_index( 441 |             collection_name=collection_name, 442 |             index_params=index_params, 443 |             timeout=timeout, 444 |             **kwargs, 445 |         ) 446 | 447 |     def create_collection( 448 |         self, 449 |         collection_name: str, 450 |         dimension: Optional[int] = None, 451 |         metric_type: str = "COSINE", 452 |         timeout: Optional[float] = None, 453 |         schema: Optional[CollectionSchema] = None, 454 |         index_params: Optional[IndexParams] = None, 455 |         enable_dynamic_field: bool = True, 456 |         **kwargs: Any, 457 |     ) -> None: 458 |         """Create a new collection. 459 | 460 |         If the collection does not exist, a new collection with the specified configuration is created. 461 | 462 |         Parameters 463 |         ---------- 464 |         collection_name : str 465 |             [REQUIRED] 466 |             The name of the collection to create. 467 | 468 |         dimension : int 469 |             The dimension of the vector field in the collection. 470 |             For the "text-embedding-3-small" model used by this package, set 471 |             this to 1536, the dimensionality of the embeddings it generates. 472 | 473 |         metric_type : str 474 |             The metric used to measure similarities between vector embeddings in the collection. 475 | 476 |         timeout : Optional[float] 477 |             The maximum duration to wait for the operation to complete before timing out. 478 | 479 |         schema : Optional[CollectionSchema] 480 |             Defines the structure of the collection. 481 | 482 |         enable_dynamic_field : bool 483 |             If True, data can be inserted without first defining a schema. 484 | 485 |         Raises 486 |         ------ 487 |         CollectionAlreadyExistsException 488 |             If the collection already exists. 
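        Examples
        --------
        A minimal sketch; ``milvus_client``, ``schema`` and ``index_params``
        are assumed to have been prepared beforehand with ``create_schema``,
        ``add_field`` and ``prepare_index_params``, and the collection name
        is illustrative:

        >>> milvus_client.create_collection(
        ...     collection_name="my_collection",
        ...     schema=schema,
        ...     index_params=index_params,
        ... )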
489 | """ 490 | try: 491 | # Detect whether the collection exist or not 492 | self.get_collection_stats(collection_name, timeout) 493 | except CollectionNotFoundException: 494 | pass 495 | else: 496 | raise CollectionAlreadyExistsException( 497 | f"Collection {collection_name} already exists" 498 | ) 499 | 500 | try: 501 | self.milvus_client.create_collection( 502 | collection_name=collection_name, 503 | dimension=dimension, 504 | metric_type=metric_type, 505 | schema=schema, 506 | index_params=index_params, 507 | timeout=timeout, 508 | enable_dynamic_field=enable_dynamic_field, 509 | **kwargs, 510 | ) 511 | except MilvusException as e: 512 | raise CollectionCreateFailureException( 513 | f"Collection {collection_name} fail to create" 514 | ) from e 515 | 516 | def crate_partition( 517 | self, 518 | collection_name: str, 519 | partition_name: str, 520 | timeout: Optional[float] = None, 521 | ) -> None: 522 | """Create a partition in collection. 523 | 524 | Parameters 525 | ---------- 526 | collection_name : str 527 | [REQUIRED] 528 | The name of the collection to add partition. 529 | 530 | partition_name : str 531 | [REQUIRED] 532 | The name of the partition to create. 533 | 534 | timeout : Optional[float] 535 | The timeout duration for this operation. 536 | Setting this to None indicates that this operation timeouts when any response arrives or any error occurs. 537 | 538 | Raises 539 | ------ 540 | PartitionCreateFailureException 541 | If partition create failure. 542 | """ 543 | try: 544 | self.milvus_client.create_partition( 545 | collection_name=collection_name, 546 | partition_name=partition_name, 547 | timeout=timeout, 548 | ) 549 | except MilvusException as e: 550 | raise PartitionCreateFailureException( 551 | f"Partition {partition_name} fail to create" 552 | ) from e 553 | 554 | def drop_partition( 555 | self, 556 | collection_name: str, 557 | partition_name: str, 558 | timeout: Optional[float] = None, 559 | ) -> None: 560 | """Drop a partition in collection. 561 | 562 | Parameters 563 | ---------- 564 | collection_name : str 565 | [REQUIRED] 566 | The name of the collection to drop partition. 567 | 568 | partition_name : str 569 | [REQUIRED] 570 | The name of the partition to drop. 571 | 572 | timeout : Optional[float] 573 | The timeout duration for this operation. 574 | Setting this to None indicates that this operation timeouts when any response arrives or any error occurs. 575 | 576 | Raises 577 | ------ 578 | PartitionDropFailureException 579 | If partition drop failure. 580 | """ 581 | try: 582 | self.milvus_client.drop_partition( 583 | collection_name=collection_name, 584 | partition_name=partition_name, 585 | timeout=timeout, 586 | ) 587 | except MilvusException as e: 588 | raise PartitionDropFailureException( 589 | f"Partition {partition_name} fail to drop" 590 | ) from e 591 | 592 | def list_partition( 593 | self, collection_name: str, timeout: Optional[float] = None 594 | ) -> List[str]: 595 | """List all partitions in the specific collection. 596 | 597 | Parameters 598 | ---------- 599 | collection_name : str 600 | [REQUIRED] 601 | The name of the collection to add partition. 602 | 603 | timeout : Optional[float] 604 | The timeout duration for this operation. 605 | Setting this to None indicates that this operation timeouts when any response arrives or any error occurs. 606 | 607 | Returns 608 | ------- 609 | partitions : list[str] 610 | All the partitions in that specific collection. 
611 | 
612 | Raises
613 | ------
614 | PartitionListFailureException
615 | If listing the partitions fails.
616 | """
617 | try:
618 | partitions = self.milvus_client.list_partitions(
619 | collection_name=collection_name, timeout=timeout
620 | )
621 | except MilvusException as e:
622 | raise PartitionListFailureException(
623 | f"Failed to list partitions from {collection_name}"
624 | ) from e
625 | 
626 | return partitions
627 | 
628 | def drop_collection(self, collection_name: str) -> None:
629 | """Delete an existing collection.
630 | 
631 | Parameters
632 | ----------
633 | collection_name : str
634 | The name of the collection.
635 | 
636 | Raises
637 | ------
638 | CollectionNotFoundException
639 | If the collection does not exist.
640 | """
641 | try:
642 | self.milvus_client.drop_collection(collection_name=collection_name)
643 | except MilvusException as e:
644 | raise CollectionNotFoundException(
645 | f"Collection {collection_name} not found"
646 | ) from e
647 | 
648 | def upload_documents(
649 | self,
650 | collection_name: str,
651 | documents: List[str | pathlib.Path],
652 | partition_name: Optional[str] = None,
653 | embedding_model: str = "text-embedding-3-small",
654 | ) -> None:
655 | """Upload documents to the collection.
656 | 
657 | Parameters
658 | ----------
659 | collection_name : str
660 | The name of the collection.
661 | 
662 | documents : list[str | pathlib.Path]
663 | The documents to upload.
664 | 
665 | partition_name : str | None
666 | The name of the partition in that collection to insert the data into.
667 | 
668 | embedding_model : str
669 | The OpenAI embedding model to use.
670 | """
671 | # don't allow for duplicate documents
672 | documents = list(set(documents))
673 | if not documents:
674 | logger.info("No documents to upload")
675 | return
676 | 
677 | logger.info(f"Parsing {len(documents)} documents")
678 | all_chunks: list[Document] = []
679 | for document in documents:
680 | chunks_ = parse_and_split(document)
681 | chunks = clean_chunks(chunks_)
682 | all_chunks.extend(chunks)
683 | 
684 | logger.info(f"Embedding {len(all_chunks)} chunks")
685 | embeddings = generate_embeddings(
686 | openai_api_key=self.openai_client.api_key,
687 | chunks=[c.page_content for c in all_chunks],
688 | model=embedding_model,
689 | )
690 | 
691 | if len(embeddings) != len(all_chunks):
692 | raise ValueError(
693 | "Number of embeddings does not match number of chunks"
694 | )
695 | 
696 | data = []
697 | for chunk, embedding in zip(all_chunks, embeddings):
698 | rawdata = MilvusMetadata(
699 | text=chunk.page_content,
700 | page_number=chunk.metadata["page"],
701 | chunk_number=chunk.metadata["chunk"],
702 | filename=chunk.metadata["source"],
703 | vector=embedding,
704 | )
705 | metadata = {
706 | "text": rawdata.text,
707 | "page_number": str(rawdata.page_number),
708 | "chunk_number": str(rawdata.chunk_number),
709 | "filename": rawdata.filename,
710 | "embedding": list(rawdata.vector),
711 | }
712 | 
713 | data.append(metadata)
714 | 
715 | response = self.milvus_client.insert(
716 | collection_name=collection_name,
717 | partition_name=partition_name,
718 | data=data,
719 | )
720 | 
721 | insert_count = response["insert_count"]
722 | logger.info(f"Inserted {insert_count} documents")
723 | 
724 | def clean_text(self, text: str) -> str:
725 | """Return a lower case version of text with punctuation removed.
726 | 
727 | Parameters
728 | ----------
729 | text : str
730 | The raw text to be cleaned.
731 | 
732 | Returns
733 | -------
734 | str: The cleaned text string.
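
Examples
--------
Illustrative only; assumes ``client`` is an instance of this class:

>>> client.clean_text("Hello,   World!")
'hello world'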
735 | """ 736 | text_processed = re.sub("[^0-9a-zA-Z ]+", "", text.lower()) 737 | text_processed_further = re.sub(" +", " ", text_processed) 738 | return text_processed_further 739 | 740 | def create_search_params( 741 | self, 742 | metric_type: str = "COSINE", 743 | params: Optional[Dict[str, Any]] = None, 744 | ) -> Dict[str, Any]: 745 | """Create search parameters for the Milvus search.""" 746 | if params is None: 747 | params = {} 748 | 749 | search_params = {"metric_type": metric_type, "params": params} 750 | 751 | return search_params 752 | 753 | def search( 754 | self, 755 | question: str, 756 | collection_name: str, 757 | anns_field: Optional[str] = None, 758 | partition_names: Optional[List[str]] = None, 759 | filter: str = "", 760 | limit: int = 5, 761 | output_fields: Optional[List[str]] = None, 762 | search_params: Optional[Dict[str, Any]] = None, 763 | chat_model: str = "gpt-4-1106-preview", 764 | chat_temperature: float = 0.0, 765 | chat_max_tokens: int = 1000, 766 | chat_seed: int = 2, 767 | embedding_model: str = "text-embedding-3-small", 768 | **kwargs: Dict[str, Any], 769 | ) -> QueryReturnType: 770 | """Query the index. 771 | 772 | Parameters 773 | ---------- 774 | collection_name : str 775 | Name of the collection. 776 | 777 | anns_field : str 778 | Specific Field to search on. 779 | 780 | question : str 781 | The question to ask. 782 | 783 | limit : int 784 | The maximum number of answers to return. 785 | 786 | output_fields : str 787 | The field that should return. 788 | 789 | chat_model : str 790 | The OpenAI chat model to use. 791 | 792 | chat_temperature : float 793 | The temperature for the chat model. 794 | 795 | chat_max_tokens : int 796 | The maximum number of tokens for the chat model. 797 | 798 | chat_seed : int 799 | The seed for the chat model. 800 | 801 | embedding_model : str 802 | The OpenAI embedding model to use. 803 | 804 | 805 | Returns 806 | ------- 807 | QueryReturnType 808 | Dictionary with keys "answer", "matches", and "used_contexts". 809 | The "answer" is the answer to the question. 810 | The "matches" are the "top_k" matches from the index. 811 | The "used_contexts" are the indices of the matches 812 | that were actually used to answer the question. 813 | 814 | Raises 815 | ------ 816 | OpenAIException 817 | If there is an error with the OpenAI API. Some possible reasons 818 | include the chat model not finishing or the response not being 819 | valid JSON. 
820 | """ 821 | if output_fields is None: 822 | output_fields = ["text", "filename", "page_number"] 823 | 824 | if search_params is None: 825 | search_params = {} 826 | 827 | logger.info(f"Filter: {filter} and Search params: {search_params}") 828 | 829 | # size of 1024 830 | question_embedding = generate_embeddings( 831 | openai_api_key=self.openai_client.api_key, 832 | chunks=[question], 833 | model=embedding_model, 834 | )[0] 835 | 836 | match_texts: List[str] = [] 837 | 838 | results: Optional[List[Any]] = [] 839 | i = 0 840 | while results is not None and i < 5: 841 | results = self.milvus_client.search( 842 | collection_name=collection_name, 843 | anns_field=anns_field, 844 | partition_names=partition_names, 845 | filter=filter, 846 | data=[question_embedding], 847 | output_fields=[output_fields], 848 | limit=limit, 849 | search_params=search_params, 850 | **kwargs, 851 | ) 852 | i += 1 853 | 854 | if results is not None: 855 | for result in results: 856 | text = result[0]["entity"]["text"] 857 | match_texts.append(text) 858 | 859 | # Proceed to create prompt, send it to OpenAI, and handle the response 860 | prompt = self.create_prompt(question, match_texts) 861 | response = self.openai_client.chat.completions.create( 862 | model=chat_model, 863 | seed=chat_seed, 864 | temperature=chat_temperature, 865 | messages=[{"role": "user", "content": prompt}], 866 | max_tokens=chat_max_tokens, 867 | ) 868 | 869 | output = self.process_response(response) 870 | 871 | return_dict: QueryReturnType = { 872 | "answer": output.answer, 873 | "matches": [], 874 | "used_contexts": output.contexts, 875 | } 876 | 877 | if results is not None and len(results) > 0: 878 | return_dict["matches"] = results[0] 879 | 880 | return return_dict 881 | 882 | def create_prompt(self, question: str, match_texts: list[str]) -> str: 883 | """Create the prompt for the OpenAI chat completion. 884 | 885 | Parameters 886 | ---------- 887 | question : str 888 | The question to ask. 889 | 890 | match_texts : list[str] 891 | The list of context strings to include in the prompt. 892 | 893 | Returns 894 | ------- 895 | str 896 | The generated prompt. 897 | """ 898 | input_actual = Input(question=question, contexts=match_texts) 899 | prompt_end = f""" 900 | ACTUAL INPUT 901 | ```json 902 | {input_actual.model_dump_json()} 903 | ``` 904 | 905 | ACTUAL OUTPUT 906 | """ 907 | return f"{PROMPT_START}\n{prompt_end}" 908 | 909 | def process_response(self, response: Any) -> Output: 910 | """Process the OpenAI chat completion response. 911 | 912 | Parameters 913 | ---------- 914 | response : Any 915 | The OpenAI chat completion response. 916 | 917 | Returns 918 | ------- 919 | Output 920 | The processed output. 921 | 922 | Raises 923 | ------ 924 | OpenAIException 925 | If the chat model did not finish or the response is not valid JSON. 926 | """ 927 | choice = response.choices[0] 928 | if choice.finish_reason != "stop": 929 | raise OpenAIException( 930 | f"Chat did not finish. 
Reason: {choice.finish_reason}" 931 | ) 932 | 933 | response_raw = cast(str, response.choices[0].message.content) 934 | 935 | if response_raw.startswith("```json"): 936 | start_i = response_raw.index("{") 937 | end_i = response_raw.rindex("}") 938 | response_raw = response_raw[start_i : end_i + 1] 939 | 940 | try: 941 | output = Output.model_validate_json(response_raw) 942 | except ValidationError as e: 943 | raise OpenAIException( 944 | f"OpenAI did not return a valid JSON: {response_raw}" 945 | ) from e 946 | 947 | return output 948 | -------------------------------------------------------------------------------- /src/whyhow_rbr/rag_qdrant.py: -------------------------------------------------------------------------------- 1 | """Rule based RAG with Qdrant.""" 2 | 3 | import logging 4 | import pathlib 5 | import re 6 | import uuid 7 | from typing import Any, TypedDict, cast 8 | 9 | from langchain_core.documents import Document 10 | from openai import OpenAI 11 | from pydantic import BaseModel, Field, ValidationError, field_validator 12 | from qdrant_client import QdrantClient, models 13 | 14 | from whyhow_rbr.embedding import generate_embeddings 15 | from whyhow_rbr.exceptions import ( 16 | CollectionAlreadyExistsException, 17 | CollectionNotFoundException, 18 | OpenAIException, 19 | ) 20 | from whyhow_rbr.processing import clean_chunks, parse_and_split 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | class Metadata(BaseModel, extra="forbid"): 26 | """The metadata to be stored in Qdrant. 27 | 28 | Attributes 29 | ---------- 30 | text : str 31 | The text of the document. 32 | 33 | page_number : int 34 | The page number of the document. 35 | 36 | chunk_number : int 37 | The chunk number of the document. 38 | 39 | filename : str 40 | The filename of the document. 41 | """ 42 | 43 | text: str 44 | page_number: int 45 | chunk_number: int 46 | filename: str 47 | 48 | 49 | class QdrantDocument(BaseModel, extra="forbid"): 50 | """The actual document to be stored in Qdrant. 51 | 52 | Attributes 53 | ---------- 54 | metadata : Metadata 55 | The metadata of the document. 56 | 57 | vector : list[float] | None 58 | The vector embedding representing the document. 59 | 60 | id : str | None 61 | UUID of the document. 62 | """ 63 | 64 | metadata: Metadata 65 | vector: list[float] 66 | id: str = Field(default_factory=lambda: str(uuid.uuid4())) 67 | 68 | 69 | class QdrantMatch(BaseModel, extra="ignore"): 70 | """The match returned from Qdrant. 71 | 72 | Attributes 73 | ---------- 74 | id : str 75 | The ID of the document. 76 | 77 | score : float 78 | The score of the match. Its meaning depends on the distance used for 79 | the collection. 80 | 81 | metadata : Metadata 82 | The metadata of the document. 83 | 84 | """ 85 | 86 | id: str | int 87 | score: float 88 | metadata: Metadata 89 | 90 | 91 | class Rule(BaseModel): 92 | """Retrieval rule. 93 | 94 | The rule is used to filter the documents in the collection. 95 | 96 | Attributes 97 | ---------- 98 | filename : str | None 99 | The filename of the document. 100 | 101 | uuid : str | None 102 | The UUID of the document. 103 | 104 | page_numbers : list[int] | None 105 | The page numbers of the document. 106 | 107 | keywords : list[str] | None 108 | The keywords to trigger a rule. 
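
Examples
--------
Illustrative only; the filename and page numbers are placeholders:

>>> rule = Rule(filename="report.pdf", page_numbers=[1, 2])
>>> rule.to_filter() is not None
True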
109 | """ 110 | 111 | filename: str | None = None 112 | uuid: str | None = None 113 | page_numbers: list[int] | None = None 114 | keywords: list[str] | None = None 115 | 116 | @field_validator("page_numbers", mode="before") 117 | @classmethod 118 | def convert_empty_to_none(cls, v: list[int] | None) -> list[int] | None: 119 | """Convert empty list to None.""" 120 | if v is not None and not v: 121 | return None 122 | return v 123 | 124 | @field_validator("keywords", mode="before") 125 | @classmethod 126 | def convert_empty_str_to_none( 127 | cls, s: list[str] | None 128 | ) -> list[str] | None: 129 | """Convert empty string list to None.""" 130 | if s is not None and not s: 131 | return None 132 | return s 133 | 134 | def to_filter(self) -> models.Filter | None: 135 | """Convert rule to Qdrant filter format.""" 136 | if not any([self.filename, self.uuid, self.page_numbers]): 137 | return None 138 | 139 | conditions: list[models.Condition] = [] 140 | if self.filename is not None: 141 | conditions.append( 142 | models.FieldCondition( 143 | key="filename", 144 | match=models.MatchValue(value=self.filename), 145 | ) 146 | ) 147 | if self.uuid is not None: 148 | conditions.append( 149 | models.HasIdCondition(has_id=[self.uuid]), 150 | ) 151 | if self.page_numbers is not None: 152 | conditions.append( 153 | models.FieldCondition( 154 | key="page_number", 155 | match=models.MatchAny(any=self.page_numbers), 156 | ) 157 | ) 158 | 159 | filter_ = models.Filter(must=conditions) 160 | return filter_ 161 | 162 | 163 | class Input(BaseModel): 164 | """Example input for the prompt. 165 | 166 | Attributes 167 | ---------- 168 | question : str 169 | The question to ask. 170 | 171 | contexts : list[str] 172 | The contexts to use for answering the question. 173 | """ 174 | 175 | question: str 176 | contexts: list[str] 177 | 178 | 179 | class Output(BaseModel): 180 | """Example output for the prompt. 181 | 182 | Attributes 183 | ---------- 184 | answer : str 185 | The answer to the question. 186 | 187 | contexts : list[int] 188 | The indices of the contexts that were used to answer the question. 
189 | """ 190 | 191 | answer: str 192 | contexts: list[int] 193 | 194 | 195 | input_example_1 = Input( 196 | question="What is the capital of France?", 197 | contexts=[ 198 | "The capital of France is Paris.", 199 | "The capital of France is not London.", 200 | "Paris is beautiful and it is also the capital of France.", 201 | ], 202 | ) 203 | output_example_1 = Output(answer="Paris", contexts=[0, 2]) 204 | 205 | input_example_2 = Input( 206 | question="What are the impacts of climate change on global agriculture?", 207 | contexts=[ 208 | "Climate change can lead to more extreme weather patterns, affecting crop yields.", 209 | "Rising sea levels due to climate change can inundate agricultural lands in coastal areas, reducing arable land.", 210 | "Changes in temperature and precipitation patterns can shift agricultural zones, impacting food security.", 211 | ], 212 | ) 213 | 214 | output_example_2 = Output( 215 | answer="Variable impacts including altered weather patterns, reduced arable land, shifting agricultural zones, increased pests and diseases, with potential mitigation through technology and sustainable practices", 216 | contexts=[0, 1, 2], 217 | ) 218 | 219 | input_example_3 = Input( 220 | question="How has the concept of privacy evolved with the advent of digital technology?", 221 | contexts=[ 222 | "Digital technology has made it easier to collect, store, and analyze personal data, raising privacy concerns.", 223 | "Social media platforms and smartphones often track user activity and preferences, leading to debates over consent and data ownership.", 224 | "Encryption and secure communication technologies have evolved as means to protect privacy in the digital age.", 225 | "Legislation like the GDPR in the EU has been developed to address privacy concerns and regulate data handling by companies.", 226 | "The concept of privacy is increasingly being viewed through the lens of digital rights and cybersecurity.", 227 | ], 228 | ) 229 | 230 | output_example_3 = Output( 231 | answer="Evolving with challenges due to data collection and analysis, changes in legislation, and advancements in encryption and security, amidst ongoing debates over consent and data ownership", 232 | contexts=[0, 1, 2, 3, 4], 233 | ) 234 | 235 | 236 | class QueryReturnType(TypedDict): 237 | """The return type of the query method. 238 | 239 | Attributes 240 | ---------- 241 | answer : str 242 | The answer to the question. 243 | 244 | matches : list[dict[str, Any]] 245 | The retrieved documents from the collection. 246 | 247 | used_contexts : list[int] 248 | The indices of the matches that were actually used to answer the question. 249 | """ 250 | 251 | answer: str 252 | matches: list[dict[str, Any]] 253 | used_contexts: list[int] 254 | 255 | 256 | PROMPT_START = f"""\ 257 | You are a helpful assistant. I will give you a question and provide multiple 258 | context documents. You will need to answer the question based on the contexts 259 | and also specify in which context(s) you found the answer. 260 | If you don't find the answer in the context, you can use your own knowledge, however, 261 | in that case, the contexts array should be empty. 262 | 263 | Both the input and the output are JSON objects. 
264 | 265 | # EXAMPLE INPUT 266 | # ```json 267 | # {input_example_1.model_dump_json()} 268 | # ``` 269 | 270 | # EXAMPLE OUTPUT 271 | # ```json 272 | # {output_example_1.model_dump_json()} 273 | 274 | # EXAMPLE INPUT 275 | # ```json 276 | # {input_example_2.model_dump_json()} 277 | # ``` 278 | 279 | # EXAMPLE OUTPUT 280 | # ```json 281 | # {output_example_2.model_dump_json()} 282 | 283 | # EXAMPLE INPUT 284 | # ```json 285 | # {input_example_3.model_dump_json()} 286 | # ``` 287 | 288 | # EXAMPLE OUTPUT 289 | # ```json 290 | # {output_example_3.model_dump_json()} 291 | 292 | """ 293 | 294 | 295 | class Client: 296 | """RBR client for Qdrant.""" 297 | 298 | def __init__( 299 | self, 300 | oclient: OpenAI, 301 | qclient: QdrantClient, 302 | ): 303 | self.openai_client = oclient 304 | self.qdrant_client = qclient 305 | 306 | def create_collection( 307 | self, 308 | collection_name: str, 309 | size: int = 1536, 310 | distance: models.Distance = models.Distance.COSINE, 311 | **collection_kwargs: Any, 312 | ) -> None: 313 | """Create a new collection. 314 | 315 | Parameters 316 | ---------- 317 | collection_name : str 318 | The name of the collection. 319 | 320 | size : int 321 | The dimension of the vectors. 322 | 323 | distance : Distance 324 | The distance metric to use for the collection. 325 | 326 | collection_kwargs : Any 327 | Additional arguments to pass to QdrantClient#create_collection. 328 | 329 | Raises 330 | ------ 331 | CollectionAlreadyExistsException 332 | If the collection already exists. 333 | 334 | """ 335 | if self.qdrant_client.collection_exists(collection_name): 336 | raise CollectionAlreadyExistsException() 337 | 338 | collection_opts = { 339 | "collection_name": collection_name, 340 | "vectors_config": models.VectorParams( 341 | size=size, distance=distance 342 | ), 343 | **collection_kwargs, 344 | } 345 | 346 | self.qdrant_client.create_collection(**collection_opts) 347 | 348 | def upload_documents( 349 | self, 350 | collection_name: str, 351 | documents: list[str | pathlib.Path], 352 | embedding_model: str = "text-embedding-3-small", 353 | batch_size: int = 64, 354 | ) -> None: 355 | """Upload documents to the collection. 356 | 357 | Parameters 358 | ---------- 359 | collection_name : str 360 | The name of the collection. 361 | 362 | documents : list[str | pathlib.Path] 363 | The documents to upload. 364 | 365 | embedding_model : str 366 | The OpenAI embedding model to use. 367 | 368 | batch_size : int 369 | The number of documents to upload at a time. 
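
Examples
--------
A minimal sketch, not a runnable doctest; ``client`` is assumed to be an
instance of this class, the collection is assumed to already exist, and the
document paths are placeholders:

>>> client.upload_documents(
...     collection_name="my_collection",
...     documents=["report.pdf", "notes.pdf"],
... )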
370 | """ 371 | if not self.qdrant_client.collection_exists(collection_name): 372 | raise CollectionNotFoundException() 373 | 374 | documents = list(set(documents)) 375 | if not documents: 376 | logger.info("No documents to upload") 377 | return 378 | 379 | logger.info(f"Parsing {len(documents)} documents") 380 | all_chunks: list[Document] = [] 381 | for document in documents: 382 | chunks_ = parse_and_split(document) 383 | chunks = clean_chunks(chunks_) 384 | all_chunks.extend(chunks) 385 | 386 | logger.info(f"Embedding {len(all_chunks)} chunks") 387 | embeddings = generate_embeddings( 388 | openai_api_key=self.openai_client.api_key, 389 | chunks=[c.page_content for c in all_chunks], 390 | model=embedding_model, 391 | ) 392 | 393 | if len(embeddings) != len(all_chunks): 394 | raise ValueError( 395 | "Number of embeddings does not match number of chunks" 396 | ) 397 | 398 | qdrant_documents: list[QdrantDocument] = [] 399 | for i, (chunk, embedding) in enumerate(zip(all_chunks, embeddings)): 400 | metadata = Metadata( 401 | text=chunk.page_content, 402 | page_number=chunk.metadata["page"], 403 | chunk_number=chunk.metadata["chunk"], 404 | filename=chunk.metadata["source"], 405 | ) 406 | qdrant_document = QdrantDocument( 407 | vector=embedding, metadata=metadata 408 | ) 409 | qdrant_documents.append(qdrant_document) 410 | 411 | points = [ 412 | models.PointStruct( 413 | id=d.id, vector=d.vector, payload=d.metadata.model_dump() 414 | ) 415 | for d in qdrant_documents 416 | ] 417 | 418 | self.qdrant_client.upload_points( 419 | collection_name, points, batch_size=batch_size 420 | ) 421 | 422 | logger.info(f"Upserted {len(points)} documents") 423 | 424 | def clean_text(self, text: str) -> str: 425 | """Return a lower case version of text with punctuation removed. 426 | 427 | Parameters 428 | ---------- 429 | text : str 430 | The raw text to be cleaned. 431 | 432 | Returns 433 | ------- 434 | str: The cleaned text string. 435 | """ 436 | text_processed = re.sub("[^0-9a-zA-Z ]+", "", text.lower()) 437 | text_processed_further = re.sub(" +", " ", text_processed) 438 | return text_processed_further 439 | 440 | def query( 441 | self, 442 | question: str, 443 | collection_name: str, 444 | rules: list[Rule] | None = None, 445 | top_k: int = 5, 446 | chat_model: str = "gpt-4o", 447 | chat_temperature: float = 0.0, 448 | chat_max_tokens: int = 1000, 449 | chat_seed: int = 2, 450 | embedding_model: str = "text-embedding-3-small", 451 | process_rules_separately: bool = False, 452 | keyword_trigger: bool = False, 453 | ) -> QueryReturnType: 454 | """Query the collection. 455 | 456 | Parameters 457 | ---------- 458 | question : str 459 | The question to ask. 460 | 461 | collection_name : str 462 | The name of the collection. 463 | 464 | rules : list[Rule] | None 465 | The rules to use for filtering the documents. 466 | 467 | top_k : int 468 | The number of matches to return per rule. 469 | 470 | chat_model : str 471 | The OpenAI chat model to use. 472 | 473 | chat_temperature : float 474 | The temperature for the chat model. 475 | 476 | chat_max_tokens : int 477 | The maximum number of tokens for the chat model. 478 | 479 | chat_seed : int 480 | The seed for the chat model. 481 | 482 | embedding_model : str 483 | The OpenAI embedding model to use. 484 | 485 | process_rules_separately : bool, optional 486 | Whether to process each rule individually and combine the results at the end. 487 | When set to True, each rule will be run independently, ensuring that every rule 488 | returns results. 
When set to False (default), all rules will be run as one joined 489 | query, potentially allowing one rule to dominate the others. 490 | Default is False. 491 | 492 | keyword_trigger : bool, optional 493 | Whether to trigger rules based on keyword matches in the question. 494 | Default is False. 495 | 496 | Returns 497 | ------- 498 | QueryReturnType 499 | Dictionary with keys "answer", "matches", and "used_contexts". 500 | The "answer" is the answer to the question. 501 | The "matches" are the "top_k" matches from the collection. 502 | The "used_contexts" are the indices of the matches 503 | that were actually used to answer the question. 504 | 505 | Raises 506 | ------ 507 | OpenAIException 508 | If there is an error with the OpenAI API. Some possible reasons 509 | include the chat model not finishing or the response not being 510 | valid JSON. 511 | 512 | CollectionNotFoundException 513 | If the collection does not exist in Qdrant. 514 | """ 515 | if not self.qdrant_client.collection_exists(collection_name): 516 | raise CollectionNotFoundException() 517 | 518 | logger.info(f"Raw rules: {rules}") 519 | 520 | if rules is None: 521 | rules = [] 522 | 523 | if keyword_trigger: 524 | clean_question = set(self.clean_text(question).split(" ")) 525 | rules = [ 526 | rule 527 | for rule in rules 528 | if rule.keywords 529 | and set(map(self.clean_text, rule.keywords)) & clean_question 530 | ] 531 | 532 | rule_filters = [rule.to_filter() for rule in rules if rule.to_filter()] 533 | 534 | question_embedding = generate_embeddings( 535 | openai_api_key=self.openai_client.api_key, 536 | chunks=[question], 537 | model=embedding_model, 538 | )[0] 539 | 540 | matches, match_texts = [], [] 541 | 542 | if not rule_filters: 543 | query_response = self.qdrant_client.query_points( 544 | collection_name=collection_name, 545 | limit=top_k, 546 | query=question_embedding, 547 | with_payload=True, 548 | ).points 549 | matches = [ 550 | QdrantMatch( 551 | id=p.id, 552 | score=p.score, 553 | metadata=Metadata(**p.payload), # type: ignore 554 | ) 555 | for p in query_response 556 | ] 557 | match_texts = [m.metadata.text for m in matches] 558 | else: 559 | if process_rules_separately: 560 | for rule_filter in rule_filters: 561 | query_response = self.qdrant_client.query_points( 562 | collection_name=collection_name, 563 | limit=top_k, 564 | query=question_embedding, 565 | query_filter=rule_filter, 566 | with_payload=True, 567 | ).points 568 | matches.extend( 569 | QdrantMatch( 570 | id=p.id, 571 | score=p.score, 572 | metadata=Metadata(**p.payload), # type: ignore 573 | ) 574 | for p in query_response 575 | ) 576 | match_texts.extend(m.metadata.text for m in matches) 577 | match_texts = list( 578 | set(match_texts) 579 | ) # Ensure unique match texts 580 | else: 581 | combined_filter = models.Filter(must=rule_filters) # type: ignore 582 | query_response = self.qdrant_client.query_points( 583 | collection_name=collection_name, 584 | limit=top_k, 585 | query=question_embedding, 586 | query_filter=combined_filter, 587 | with_payload=True, 588 | ).points 589 | matches = [ 590 | QdrantMatch( 591 | id=p.id, 592 | score=p.score, 593 | metadata=Metadata(**p.payload), # type: ignore 594 | ) 595 | for p in query_response 596 | ] 597 | match_texts = [m.metadata.text for m in matches] 598 | 599 | prompt = self.create_prompt(question, match_texts) 600 | response = self.openai_client.chat.completions.create( 601 | model=chat_model, 602 | seed=chat_seed, 603 | temperature=chat_temperature, 604 | messages=[{"role": "user", 
"content": prompt}], 605 | max_tokens=chat_max_tokens, 606 | ) 607 | 608 | output = self.process_response(response) 609 | 610 | return { 611 | "answer": output.answer, 612 | "matches": [m.model_dump() for m in matches], 613 | "used_contexts": output.contexts, 614 | } 615 | 616 | def create_prompt(self, question: str, match_texts: list[str]) -> str: 617 | """Create the prompt for the OpenAI chat completion. 618 | 619 | Parameters 620 | ---------- 621 | question : str 622 | The question to ask. 623 | 624 | match_texts : list[str] 625 | The list of context strings to include in the prompt. 626 | 627 | Returns 628 | ------- 629 | str 630 | The generated prompt. 631 | """ 632 | input_actual = Input(question=question, contexts=match_texts) 633 | prompt_end = f""" 634 | ACTUAL INPUT 635 | ```json 636 | {input_actual.model_dump_json()} 637 | ``` 638 | 639 | ACTUAL OUTPUT 640 | """ 641 | return f"{PROMPT_START}\n{prompt_end}" 642 | 643 | def process_response(self, response: Any) -> Output: 644 | """Process the OpenAI chat completion response. 645 | 646 | Parameters 647 | ---------- 648 | response : Any 649 | The OpenAI chat completion response. 650 | 651 | Returns 652 | ------- 653 | Output 654 | The processed output. 655 | 656 | Raises 657 | ------ 658 | OpenAIException 659 | If the chat model did not finish or the response is not valid JSON. 660 | """ 661 | choice = response.choices[0] 662 | if choice.finish_reason != "stop": 663 | raise OpenAIException( 664 | f"Chat did not finish. Reason: {choice.finish_reason}" 665 | ) 666 | 667 | response_raw = cast(str, response.choices[0].message.content) 668 | 669 | if response_raw.startswith("```json"): 670 | start_i = response_raw.index("{") 671 | end_i = response_raw.rindex("}") 672 | response_raw = response_raw[start_i : end_i + 1] 673 | 674 | try: 675 | output = Output.model_validate_json(response_raw) 676 | except ValidationError as e: 677 | raise OpenAIException( 678 | f"OpenAI did not return a valid JSON: {response_raw}" 679 | ) from e 680 | 681 | return output 682 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | 3 | import pytest 4 | 5 | 6 | @pytest.fixture 7 | def test_path(): 8 | return pathlib.Path(__file__).parent 9 | 10 | 11 | @pytest.fixture(autouse=True) 12 | def delete_env_vars(monkeypatch): 13 | """Delete environment variables. 
14 | 15 | This fixture is used to delete the environment variables that are used 16 | 17 | """ 18 | monkeypatch.delenv("OPENAI_API_KEY", raising=False) 19 | monkeypatch.delenv("PINECONE_API_KEY", raising=False) 20 | -------------------------------------------------------------------------------- /tests/test_dummy.py: -------------------------------------------------------------------------------- 1 | def test(): 2 | assert True 3 | -------------------------------------------------------------------------------- /tests/test_embedding.py: -------------------------------------------------------------------------------- 1 | """Tests for the embedding module.""" 2 | 3 | from unittest.mock import Mock 4 | 5 | import pytest 6 | from langchain_openai import OpenAIEmbeddings 7 | 8 | from whyhow_rbr.embedding import generate_embeddings 9 | 10 | 11 | @pytest.mark.parametrize("model", ["whatever", "else"]) 12 | def test_generate_embeddings(monkeypatch, model): 13 | chunks = ["hello there", "today is a great day"] 14 | 15 | fake_inst = Mock(spec=OpenAIEmbeddings) 16 | fake_inst.embed_documents.side_effect = lambda x: [[2.2, 5.5] for _ in x] 17 | fake_class = Mock(return_value=fake_inst) 18 | 19 | monkeypatch.setattr("whyhow_rbr.embedding.OpenAIEmbeddings", fake_class) 20 | embeddings = generate_embeddings( 21 | chunks=chunks, openai_api_key="test", model=model 22 | ) 23 | 24 | assert fake_class.call_count == 1 25 | assert fake_class.call_args.kwargs["openai_api_key"] == "test" 26 | assert fake_class.call_args.kwargs["model"] == model 27 | 28 | assert fake_inst.embed_documents.call_count == 1 29 | 30 | assert len(embeddings) == 2 31 | assert embeddings[0] == [2.2, 5.5] 32 | assert embeddings[1] == [2.2, 5.5] 33 | -------------------------------------------------------------------------------- /tests/test_processing.py: -------------------------------------------------------------------------------- 1 | """Collection of tests for the processing module.""" 2 | 3 | import pathlib 4 | 5 | import pytest 6 | from fpdf import FPDF 7 | from langchain_core.documents import Document 8 | 9 | from whyhow_rbr.processing import clean_chunks, parse_and_split 10 | 11 | 12 | @pytest.fixture 13 | def dummy_pdf(tmp_path) -> pathlib.Path: 14 | """Create a dummy PDF file.""" 15 | output_path = tmp_path / "dummy.pdf" 16 | 17 | pdf = FPDF() 18 | pdf.set_font("Arial", size=12) 19 | 20 | pdf.add_page() 21 | pdf.cell(200, 10, txt="Welcome to the dummy PDF", ln=True, align="C") 22 | 23 | pdf.add_page() 24 | pdf.cell(200, 10, txt="This is the second page", ln=True, align="C") 25 | 26 | pdf.output(str(output_path)) 27 | 28 | return output_path 29 | 30 | 31 | def test_parse_and_split(dummy_pdf): 32 | """Test the dummy PDF.""" 33 | result = parse_and_split(dummy_pdf, chunk_size=512) 34 | 35 | assert len(result) == 2 36 | assert result[0].page_content == "Welcome to the dummy PDF" 37 | assert result[0].metadata["page"] == 0 38 | assert result[0].metadata["chunk"] == 0 39 | 40 | assert result[1].page_content == "This is the second page" 41 | assert result[1].metadata["page"] == 1 42 | assert result[0].metadata["chunk"] == 0 43 | 44 | 45 | def test_parse_and_split_small_chunks(dummy_pdf): 46 | """Test the dummy PDF.""" 47 | result = parse_and_split(dummy_pdf, chunk_size=7, chunk_overlap=0) 48 | 49 | assert len(result) == 8 50 | assert result[0].page_content == "Welcome" 51 | assert result[0].metadata["page"] == 0 52 | assert result[0].metadata["chunk"] == 0 53 | 54 | assert result[1].page_content == "to the" 55 | assert 
result[1].metadata["page"] == 0 56 | assert result[1].metadata["chunk"] == 1 57 | 58 | assert result[2].page_content == "dummy" 59 | assert result[2].metadata["page"] == 0 60 | assert result[2].metadata["chunk"] == 2 61 | 62 | assert result[3].page_content == "PDF" 63 | assert result[3].metadata["page"] == 0 64 | assert result[3].metadata["chunk"] == 3 65 | 66 | assert result[4].page_content == "This is" 67 | assert result[4].metadata["page"] == 1 68 | assert result[4].metadata["chunk"] == 0 69 | 70 | assert result[5].page_content == "the" 71 | assert result[5].metadata["page"] == 1 72 | assert result[5].metadata["chunk"] == 1 73 | 74 | assert result[6].page_content == "second" 75 | assert result[6].metadata["page"] == 1 76 | assert result[6].metadata["chunk"] == 2 77 | 78 | assert result[7].page_content == "page" 79 | assert result[7].metadata["page"] == 1 80 | assert result[7].metadata["chunk"] == 3 81 | 82 | 83 | def test_clean_chunks(): 84 | """Test the clean_chunks function.""" 85 | 86 | chunks = [ 87 | Document("This is a \n\ntest", metadata={"page": 0}), 88 | Document("Nothing changes here", metadata={"page": 1}), 89 | Document("Hor\nrible", metadata={"page": 1}), 90 | ] 91 | result = clean_chunks(chunks) 92 | 93 | assert len(result) == 3 94 | assert result[0].page_content == "This is a test" 95 | assert result[1].page_content == "Nothing changes here" 96 | assert result[2].page_content == "Horrible" 97 | 98 | assert result[0].metadata["page"] == 0 99 | assert result[1].metadata["page"] == 1 100 | assert result[2].metadata["page"] == 1 101 | 102 | # make sure the original chunks are not modified 103 | assert chunks[0].page_content == "This is a \n\ntest" 104 | -------------------------------------------------------------------------------- /tests/test_qdrant_rag.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from unittest.mock import Mock 3 | from uuid import uuid4 4 | 5 | import pytest 6 | from langchain_core.documents import Document 7 | from openai import OpenAI 8 | from openai.types.chat.chat_completion import ChatCompletion 9 | from qdrant_client import QdrantClient 10 | from qdrant_client.http import models 11 | 12 | from whyhow_rbr.exceptions import ( 13 | CollectionAlreadyExistsException, 14 | CollectionNotFoundException, 15 | ) 16 | from whyhow_rbr.rag_qdrant import ( 17 | Client, 18 | Metadata, 19 | Output, 20 | QdrantDocument, 21 | Rule, 22 | ) 23 | 24 | 25 | class TestRule: 26 | def test_default(self): 27 | rule = Rule() 28 | 29 | assert rule.filename is None 30 | assert rule.page_numbers is None 31 | assert rule.uuid is None 32 | 33 | def test_empty_page_numbers(self): 34 | rule = Rule(page_numbers=[]) 35 | 36 | assert rule.filename is None 37 | assert rule.page_numbers is None 38 | assert rule.uuid is None 39 | 40 | def test_to_filter(self): 41 | # no conditions 42 | rule = Rule() 43 | assert rule.to_filter() is None 44 | 45 | # only filename 46 | rule = Rule(filename="hello.pdf") 47 | assert rule.to_filter() == models.Filter( 48 | must=[ 49 | models.FieldCondition( 50 | key="filename", match=models.MatchValue(value="hello.pdf") 51 | ) 52 | ] 53 | ) 54 | 55 | # only page_numbers 56 | rule = Rule(page_numbers=[1, 2]) 57 | assert rule.to_filter() == models.Filter( 58 | must=[ 59 | models.FieldCondition( 60 | key="page_number", match=models.MatchAny(any=[1, 2]) 61 | ) 62 | ] 63 | ) 64 | 65 | 66 | class TestQdrantDocument: 67 | def test_auto_id(self): 68 | metadata = Metadata( 69 | text="hello world", 70 | 
page_number=1, 71 | chunk_number=0, 72 | filename="hello.pdf", 73 | ) 74 | doc = QdrantDocument( 75 | vector=[0.2, 0.3], 76 | metadata=metadata, 77 | ) 78 | 79 | assert doc.id is not None 80 | 81 | def test_provide_id(self): 82 | _id = str(uuid4()) 83 | metadata = Metadata( 84 | text="hello world", 85 | page_number=1, 86 | chunk_number=0, 87 | filename="hello.pdf", 88 | ) 89 | doc = QdrantDocument( 90 | vector=[0.2, 0.3], 91 | metadata=metadata, 92 | id=_id, 93 | ) 94 | 95 | assert doc.id == _id 96 | 97 | 98 | @pytest.fixture(name="client") 99 | def patched_client(monkeypatch): 100 | monkeypatch.setenv("OPENAI_API_KEY", "secret_openai") 101 | 102 | fake_qdrant_instance = Mock(spec=QdrantClient) 103 | fake_qdrant_class = Mock(return_value=fake_qdrant_instance) 104 | 105 | fake_openai_instance = Mock(spec=OpenAI) 106 | fake_openai_class = Mock(return_value=fake_openai_instance) 107 | 108 | monkeypatch.setattr( 109 | "whyhow_rbr.rag_qdrant.QdrantClient", fake_qdrant_class 110 | ) 111 | monkeypatch.setattr("whyhow_rbr.rag.OpenAI", fake_openai_class) 112 | 113 | client = Client(fake_openai_instance, fake_qdrant_instance) 114 | 115 | assert isinstance(client.openai_client, Mock) 116 | assert isinstance(client.qdrant_client, Mock) 117 | 118 | return client 119 | 120 | 121 | class TestClient: 122 | def test_collection(self, client): 123 | def side_effect(*args, **kwargs): 124 | raise CollectionNotFoundException() 125 | 126 | client.qdrant_client.collection_exists.side_effect = side_effect 127 | 128 | with pytest.raises(CollectionNotFoundException): 129 | client.query("some question", "some collection") 130 | 131 | def test_create_index(self, client, monkeypatch): 132 | client.qdrant_client.collection_exists.return_value = False 133 | client.create_collection("some name") 134 | assert client.qdrant_client.collection_exists.call_count == 1 135 | assert client.qdrant_client.create_collection.call_count == 1 136 | 137 | client.qdrant_client.collection_exists.return_value = True 138 | with pytest.raises(CollectionAlreadyExistsException): 139 | client.create_collection("some name") 140 | 141 | def test_upload_documents_nothing(self, client, caplog): 142 | caplog.set_level(logging.INFO) 143 | client.upload_documents( 144 | "some collection", 145 | documents=[], 146 | ) 147 | 148 | captured = caplog.records[0] 149 | 150 | assert captured.levelname == "INFO" 151 | assert "No documents to upload" in captured.message 152 | 153 | def test_upload_document(self, client, caplog, monkeypatch): 154 | caplog.set_level(logging.INFO) 155 | documents = ["doc1.pdf", "doc2.pdf"] 156 | 157 | client.openai_client.api_key = "fake" 158 | client.qdrant_client.upload_points = Mock(return_value=None) 159 | 160 | parsed_docs = [ 161 | Document( 162 | page_content="hello there", 163 | metadata={ 164 | "page": 0, 165 | "chunk": 0, 166 | "source": "something", 167 | }, 168 | ), 169 | Document( 170 | page_content="again", 171 | metadata={ 172 | "page": 0, 173 | "chunk": 1, 174 | "source": "something", 175 | }, 176 | ), 177 | Document( 178 | page_content="it is cold", 179 | metadata={ 180 | "page": 1, 181 | "chunk": 0, 182 | "source": "something", 183 | }, 184 | ), 185 | ] 186 | fake_parse_and_split = Mock(return_value=parsed_docs) 187 | fake_clean_chunks = Mock(return_value=parsed_docs) 188 | fake_generate_embeddings = Mock(return_value=6 * [[2.2, 0.6]]) 189 | 190 | monkeypatch.setattr( 191 | "whyhow_rbr.rag_qdrant.parse_and_split", fake_parse_and_split 192 | ) 193 | monkeypatch.setattr( 194 | "whyhow_rbr.rag_qdrant.clean_chunks", 
fake_clean_chunks 195 | ) 196 | monkeypatch.setattr( 197 | "whyhow_rbr.rag_qdrant.generate_embeddings", 198 | fake_generate_embeddings, 199 | ) 200 | 201 | client.upload_documents( 202 | "some collection", 203 | documents=documents, 204 | ) 205 | 206 | assert fake_parse_and_split.call_count == 2 207 | assert fake_clean_chunks.call_count == 2 208 | assert fake_generate_embeddings.call_count == 1 209 | 210 | assert "Parsing 2 documents" == caplog.records[0].message 211 | assert "Embedding 6 chunks" == caplog.records[1].message 212 | assert "Upserted 6 documents" == caplog.records[2].message 213 | 214 | def test_upload_document_inconsistent(self, client, caplog, monkeypatch): 215 | documents = ["doc1.pdf"] 216 | caplog.set_level(logging.INFO) 217 | client.openai_client.api_key = "fake" 218 | 219 | parsed_docs = [ 220 | Document( 221 | page_content="hello there", 222 | metadata={ 223 | "page": 0, 224 | "chunk": 0, 225 | "source": "something", 226 | }, 227 | ), 228 | Document( 229 | page_content="again", 230 | metadata={ 231 | "page": 0, 232 | "chunk": 1, 233 | "source": "something", 234 | }, 235 | ), 236 | Document( 237 | page_content="it is cold", 238 | metadata={ 239 | "page": 1, 240 | "chunk": 0, 241 | "source": "something", 242 | }, 243 | ), 244 | ] 245 | fake_parse_and_split = Mock(return_value=parsed_docs) 246 | fake_clean_chunks = Mock(return_value=parsed_docs) 247 | fake_generate_embeddings = Mock(return_value=5 * [[2.2, 0.6]]) 248 | 249 | monkeypatch.setattr( 250 | "whyhow_rbr.rag_qdrant.parse_and_split", fake_parse_and_split 251 | ) 252 | monkeypatch.setattr( 253 | "whyhow_rbr.rag_qdrant.clean_chunks", fake_clean_chunks 254 | ) 255 | monkeypatch.setattr( 256 | "whyhow_rbr.rag_qdrant.generate_embeddings", 257 | fake_generate_embeddings, 258 | ) 259 | 260 | with pytest.raises( 261 | ValueError, match="Number of embeddings does not match" 262 | ): 263 | client.upload_documents( 264 | "some collection", 265 | documents=documents, 266 | ) 267 | 268 | assert fake_parse_and_split.call_count == 1 269 | assert fake_clean_chunks.call_count == 1 270 | assert fake_generate_embeddings.call_count == 1 271 | 272 | assert "Parsing 1 documents" == caplog.records[0].message 273 | assert "Embedding 3 chunks" == caplog.records[1].message 274 | 275 | def test_query_documents(self, client, monkeypatch): 276 | client.openai_client.api_key = "fake" 277 | client.qdrant_client.collection_exists.return_value = False 278 | with pytest.raises(CollectionNotFoundException): 279 | client.query("some question", "some collection") 280 | 281 | assert client.qdrant_client.collection_exists.call_count == 1 282 | assert client.qdrant_client.query_points.call_count == 0 283 | 284 | client.qdrant_client.collection_exists.return_value = True 285 | client.qdrant_client.query_points.return_value = models.QueryResponse( 286 | points=[] 287 | ) 288 | fake_generate_embeddings = Mock( 289 | return_value=10 * [[0.525, 0.532, 0.5321]] 290 | ) 291 | monkeypatch.setattr( 292 | "whyhow_rbr.rag_qdrant.generate_embeddings", 293 | fake_generate_embeddings, 294 | ) 295 | content = Output( 296 | answer="Hello world", 297 | contexts=[0, 1], 298 | ) 299 | fake_openai_response_rv = ChatCompletion( 300 | id="whatever", 301 | choices=[ 302 | dict( 303 | finish_reason="stop", 304 | index=0, 305 | logprobs=None, 306 | message=dict( 307 | content="```json\n" 308 | + content.model_dump_json() 309 | + "\n```", 310 | role="assistant", 311 | function_call=None, 312 | tool_calls=None, 313 | ), 314 | ) 315 | ], 316 | created=1710065537, 317 | 
model="gpt-4o", 318 | object="chat.completion", 319 | system_fingerprint="whatever", 320 | usage=dict( 321 | completion_tokens=20, prompt_tokens=679, total_tokens=699 322 | ), 323 | ) 324 | 325 | client.openai_client = Mock() 326 | client.openai_client.api_key = "fake" 327 | client.openai_client.chat.completions.create.return_value = ( 328 | fake_openai_response_rv 329 | ) 330 | client.query("some question", "some collection", top_k=4) 331 | 332 | assert client.qdrant_client.collection_exists.call_count == 2 333 | assert client.qdrant_client.query_points.call_count == 1 334 | client.qdrant_client.query_points.assert_called_with( 335 | collection_name="some collection", 336 | limit=4, 337 | query=[0.525, 0.532, 0.5321], 338 | with_payload=True, 339 | ) 340 | -------------------------------------------------------------------------------- /tests/test_rag.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from unittest.mock import Mock 3 | 4 | import pytest 5 | from langchain_core.documents import Document 6 | from openai import OpenAI 7 | from openai.types.chat.chat_completion import ChatCompletion 8 | from pinecone import Index, NotFoundException, Pinecone 9 | 10 | from whyhow_rbr.exceptions import ( 11 | IndexAlreadyExistsException, 12 | IndexNotFoundException, 13 | OpenAIException, 14 | ) 15 | from whyhow_rbr.rag import ( 16 | Client, 17 | Output, 18 | PineconeDocument, 19 | PineconeMetadata, 20 | Rule, 21 | ) 22 | 23 | 24 | class TestRule: 25 | def test_default(self): 26 | rule = Rule() 27 | 28 | assert rule.filename is None 29 | assert rule.page_numbers is None 30 | assert rule.uuid is None 31 | 32 | def test_empty_page_numbers(self): 33 | rule = Rule(page_numbers=[]) 34 | 35 | assert rule.filename is None 36 | assert rule.page_numbers is None 37 | assert rule.uuid is None 38 | 39 | def test_to_filter(self): 40 | # no conditions 41 | rule = Rule() 42 | assert rule.to_filter() is None 43 | 44 | # only filename 45 | rule = Rule(filename="hello.pdf") 46 | assert rule.to_filter() == { 47 | "$and": [ 48 | {"filename": {"$eq": "hello.pdf"}}, 49 | ] 50 | } 51 | 52 | # only page_numbers 53 | rule = Rule(page_numbers=[1, 2]) 54 | assert rule.to_filter() == { 55 | "$and": [ 56 | {"page_number": {"$in": [1, 2]}}, 57 | ] 58 | } 59 | 60 | # everything 61 | rule = Rule(filename="hello.pdf", page_numbers=[1, 2], uuid="123") 62 | 63 | assert rule.to_filter() == { 64 | "$and": [ 65 | {"filename": {"$eq": "hello.pdf"}}, 66 | {"uuid": {"$eq": "123"}}, 67 | {"page_number": {"$in": [1, 2]}}, 68 | ] 69 | } 70 | 71 | 72 | class TestPineconeDocument: 73 | def test_generate_id(self): 74 | metadata = PineconeMetadata( 75 | text="hello world", 76 | page_number=1, 77 | chunk_number=0, 78 | filename="hello.pdf", 79 | ) 80 | doc = PineconeDocument( 81 | values=[0.2, 0.3], 82 | metadata=metadata, 83 | ) 84 | 85 | assert doc.id == "hello.pdf-1-0" 86 | 87 | def test_provide_id(self): 88 | metadata = PineconeMetadata( 89 | text="hello world", 90 | page_number=1, 91 | chunk_number=0, 92 | filename="hello.pdf", 93 | ) 94 | doc = PineconeDocument( 95 | values=[0.2, 0.3], 96 | metadata=metadata, 97 | id="custom_id", 98 | ) 99 | 100 | assert doc.id == "custom_id" 101 | 102 | 103 | @pytest.fixture(name="client") 104 | def patched_client(monkeypatch): 105 | """Generate a client instance with patched OpenAI and Pinecone clients.""" 106 | monkeypatch.setenv("OPENAI_API_KEY", "secret_openai") 107 | monkeypatch.setenv("PINECONE_API_KEY", "secret_pinecone") 108 | 109 | 
fake_pinecone_instance = Mock(spec=Pinecone) 110 | fake_pinecone_class = Mock(return_value=fake_pinecone_instance) 111 | 112 | fake_openai_instance = Mock(spec=OpenAI) 113 | fake_openai_class = Mock(return_value=fake_openai_instance) 114 | 115 | monkeypatch.setattr("whyhow_rbr.rag.Pinecone", fake_pinecone_class) 116 | monkeypatch.setattr("whyhow_rbr.rag.OpenAI", fake_openai_class) 117 | 118 | client = Client() 119 | 120 | assert isinstance(client.openai_client, Mock) 121 | assert isinstance(client.pinecone_client, Mock) 122 | 123 | return client 124 | 125 | 126 | class TestClient: 127 | def test_no_openai_key(self, monkeypatch): 128 | monkeypatch.delenv("OPENAI_API_KEY", raising=False) 129 | monkeypatch.setenv("PINECONE_API_KEY", "whatever") 130 | 131 | with pytest.raises(ValueError, match="No OPENAI_API_KEY"): 132 | Client() 133 | 134 | def test_no_pinecone_key(self, monkeypatch): 135 | monkeypatch.setenv("OPENAI_API_KEY", "whatever") 136 | monkeypatch.delenv("PINECONE_API_KEY", raising=False) 137 | 138 | with pytest.raises(ValueError, match="No PINECONE_API_KEY"): 139 | Client() 140 | 141 | def test_correct_instantiation(self, monkeypatch): 142 | monkeypatch.setenv("OPENAI_API_KEY", "secret_openai") 143 | monkeypatch.setenv("PINECONE_API_KEY", "secret_pinecone") 144 | 145 | fake_pinecone_instance = Mock(spec=Pinecone) 146 | fake_pinecone_class = Mock(return_value=fake_pinecone_instance) 147 | 148 | fake_openai_instance = Mock(spec=OpenAI) 149 | fake_openai_class = Mock(return_value=fake_openai_instance) 150 | 151 | monkeypatch.setattr("whyhow_rbr.rag.Pinecone", fake_pinecone_class) 152 | monkeypatch.setattr("whyhow_rbr.rag.OpenAI", fake_openai_class) 153 | 154 | client = Client() 155 | 156 | assert client.openai_client == fake_openai_instance 157 | assert client.pinecone_client == fake_pinecone_instance 158 | 159 | assert fake_openai_class.call_count == 1 160 | args, kwargs = fake_openai_class.call_args 161 | assert args == () 162 | assert kwargs == {"api_key": "secret_openai"} 163 | 164 | assert fake_pinecone_class.call_count == 1 165 | args, kwargs = fake_pinecone_class.call_args 166 | assert args == () 167 | assert kwargs == {"api_key": "secret_pinecone"} 168 | 169 | def test_get_index(self, client): 170 | client.pinecone_client.Index.return_value = Index("foo", "bar") 171 | 172 | index = client.get_index("something") 173 | assert isinstance(index, Index) 174 | 175 | def side_effect(*args, **kwargs): 176 | raise NotFoundException("Index not found") 177 | 178 | client.pinecone_client.Index.side_effect = side_effect 179 | 180 | with pytest.raises(IndexNotFoundException, match="Index something"): 181 | client.get_index("something") 182 | 183 | def test_create_index(self, client, monkeypatch): 184 | # index does not exist 185 | def side_effect(*args, **kwargs): 186 | raise IndexNotFoundException("Index not found") 187 | 188 | monkeypatch.setattr(client, "get_index", Mock(side_effect=side_effect)) 189 | client.create_index("new_index") 190 | 191 | assert client.pinecone_client.create_index.call_count == 1 192 | assert client.pinecone_client.Index.call_count == 1 193 | 194 | # index exists already 195 | monkeypatch.setattr(client, "get_index", Mock()) 196 | 197 | with pytest.raises( 198 | IndexAlreadyExistsException, match="Index new_index" 199 | ): 200 | client.create_index("new_index") 201 | 202 | assert client.pinecone_client.create_index.call_count == 1 203 | assert client.pinecone_client.Index.call_count == 1 204 | 205 | def test_upload_documents_nothing(self, client, caplog): 206 | 
caplog.set_level(logging.INFO) 207 | client.upload_documents( 208 | index="index", 209 | namespace="namespace", 210 | documents=[], 211 | ) 212 | 213 | captured = caplog.records[0] 214 | 215 | assert captured.levelname == "INFO" 216 | assert "No documents to upload" in captured.message 217 | 218 | def test_upload_document(self, client, caplog, monkeypatch): 219 | caplog.set_level(logging.INFO) 220 | documents = ["doc1.pdf", "doc2.pdf"] 221 | 222 | # mocking 223 | client.openai_client.api_key = "fake" 224 | fake_index = Mock() 225 | fake_index.upsert = Mock(return_value={"upserted_count": 6}) 226 | 227 | parsed_docs = [ 228 | Document( 229 | page_content="hello there", 230 | metadata={ 231 | "page": 0, 232 | "chunk": 0, 233 | "source": "something", 234 | }, 235 | ), 236 | Document( 237 | page_content="again", 238 | metadata={ 239 | "page": 0, 240 | "chunk": 1, 241 | "source": "something", 242 | }, 243 | ), 244 | Document( 245 | page_content="it is cold", 246 | metadata={ 247 | "page": 1, 248 | "chunk": 0, 249 | "source": "something", 250 | }, 251 | ), 252 | ] 253 | fake_parse_and_split = Mock(return_value=parsed_docs) 254 | fake_clean_chunks = Mock(return_value=parsed_docs) 255 | fake_generate_embeddings = Mock(return_value=6 * [[2.2, 0.6]]) 256 | 257 | monkeypatch.setattr( 258 | "whyhow_rbr.rag.parse_and_split", fake_parse_and_split 259 | ) 260 | monkeypatch.setattr("whyhow_rbr.rag.clean_chunks", fake_clean_chunks) 261 | monkeypatch.setattr( 262 | "whyhow_rbr.rag.generate_embeddings", fake_generate_embeddings 263 | ) 264 | 265 | client.upload_documents( 266 | index=fake_index, 267 | namespace="great_namespace", 268 | documents=documents, 269 | ) 270 | 271 | # assertions mocks 272 | assert fake_parse_and_split.call_count == 2 273 | assert fake_clean_chunks.call_count == 2 274 | assert fake_generate_embeddings.call_count == 1 275 | assert fake_index.upsert.call_count == 1 276 | 277 | # assertions logging 278 | assert "Parsing 2 documents" == caplog.records[0].message 279 | assert "Embedding 6 chunks" == caplog.records[1].message 280 | assert "Upserted 6 documents" == caplog.records[2].message 281 | 282 | def test_upload_document_inconsistent(self, client, caplog, monkeypatch): 283 | documents = ["doc1.pdf"] 284 | caplog.set_level(logging.INFO) 285 | 286 | # mocking 287 | client.openai_client.api_key = "fake" 288 | fake_index = Mock() 289 | 290 | parsed_docs = [ 291 | Document( 292 | page_content="hello there", 293 | metadata={ 294 | "page": 0, 295 | "chunk": 0, 296 | "source": "something", 297 | }, 298 | ), 299 | Document( 300 | page_content="again", 301 | metadata={ 302 | "page": 0, 303 | "chunk": 1, 304 | "source": "something", 305 | }, 306 | ), 307 | Document( 308 | page_content="it is cold", 309 | metadata={ 310 | "page": 1, 311 | "chunk": 0, 312 | "source": "something", 313 | }, 314 | ), 315 | ] 316 | fake_parse_and_split = Mock(return_value=parsed_docs) 317 | fake_clean_chunks = Mock(return_value=parsed_docs) 318 | fake_generate_embeddings = Mock(return_value=5 * [[2.2, 0.6]]) 319 | 320 | monkeypatch.setattr( 321 | "whyhow_rbr.rag.parse_and_split", fake_parse_and_split 322 | ) 323 | monkeypatch.setattr("whyhow_rbr.rag.clean_chunks", fake_clean_chunks) 324 | monkeypatch.setattr( 325 | "whyhow_rbr.rag.generate_embeddings", fake_generate_embeddings 326 | ) 327 | 328 | with pytest.raises( 329 | ValueError, match="Number of embeddings does not match" 330 | ): 331 | client.upload_documents( 332 | index=fake_index, 333 | namespace="great_namespace", 334 | documents=documents, 335 | ) 336 | 
337 | # assertions mocks 338 | assert fake_parse_and_split.call_count == 1 339 | assert fake_clean_chunks.call_count == 1 340 | assert fake_generate_embeddings.call_count == 1 341 | assert fake_index.upsert.call_count == 0 342 | 343 | # assertions logging 344 | assert "Parsing 1 documents" == caplog.records[0].message 345 | assert "Embedding 3 chunks" == caplog.records[1].message 346 | 347 | def test_query_no_rules_json_header(self, client, monkeypatch): 348 | # mocking embedding 349 | fake_generate_embeddings = Mock(return_value=[[0.2, 0.3]]) 350 | 351 | # mocking pinecone related stuff 352 | fake_index = Mock() 353 | fake_match = Mock() 354 | fake_match.to_dict.return_value = { 355 | "id": "doc1", 356 | "score": 0.8, 357 | "metadata": { 358 | "filename": "hello.pdf", 359 | "page_number": 1, 360 | "chunk_number": 0, 361 | "text": "hello world", 362 | "uuid": "123", 363 | }, 364 | } 365 | fake_query_response = { 366 | "matches": [fake_match, fake_match, fake_match], 367 | } 368 | fake_index.query = Mock(return_value=fake_query_response) 369 | 370 | # mocking openai related stuff 371 | content = Output( 372 | answer="Hello world", 373 | contexts=[0, 1], 374 | ) 375 | fake_openai_response_rv = ChatCompletion( 376 | id="whatever", 377 | choices=[ 378 | dict( 379 | finish_reason="stop", 380 | index=0, 381 | logprobs=None, 382 | message=dict( 383 | content="```json\n" 384 | + content.model_dump_json() 385 | + "\n```", 386 | role="assistant", 387 | function_call=None, 388 | tool_calls=None, 389 | ), 390 | ) 391 | ], 392 | created=1710065537, 393 | model="gpt-3.5-turbo-0125", 394 | object="chat.completion", 395 | system_fingerprint="whatever", 396 | usage=dict( 397 | completion_tokens=20, prompt_tokens=679, total_tokens=699 398 | ), 399 | ) 400 | 401 | client.openai_client = ( 402 | Mock() 403 | ) # for some reason spec is not working correctly 404 | client.openai_client.api_key = "whatever" 405 | client.openai_client.chat.completions.create.return_value = ( 406 | fake_openai_response_rv 407 | ) 408 | 409 | monkeypatch.setattr(client, "get_index", Mock(return_value=fake_index)) 410 | monkeypatch.setattr( 411 | "whyhow_rbr.rag.generate_embeddings", fake_generate_embeddings 412 | ) 413 | 414 | final_result = client.query( 415 | question="How are you?", 416 | index=fake_index, 417 | namespace="great_namespace", 418 | ) 419 | 420 | assert fake_index.query.call_count == 1 421 | 422 | expected_final_result = { 423 | "answer": "Hello world", 424 | "matches": [ 425 | { 426 | "id": "doc1", 427 | "metadata": { 428 | "chunk_number": 0, 429 | "filename": "hello.pdf", 430 | "page_number": 1, 431 | "text": "hello world", 432 | "uuid": "123", 433 | }, 434 | "score": 0.8, 435 | }, 436 | { 437 | "id": "doc1", 438 | "metadata": { 439 | "chunk_number": 0, 440 | "filename": "hello.pdf", 441 | "page_number": 1, 442 | "text": "hello world", 443 | "uuid": "123", 444 | }, 445 | "score": 0.8, 446 | }, 447 | { 448 | "id": "doc1", 449 | "metadata": { 450 | "chunk_number": 0, 451 | "filename": "hello.pdf", 452 | "page_number": 1, 453 | "text": "hello world", 454 | "uuid": "123", 455 | }, 456 | "score": 0.8, 457 | }, 458 | ], 459 | "used_contexts": [0, 1], 460 | } 461 | 462 | assert final_result == expected_final_result 463 | 464 | def test_query_no_rules_no_json_header(self, client, monkeypatch): 465 | # mocking embedding 466 | fake_generate_embeddings = Mock(return_value=[[0.2, 0.3]]) 467 | 468 | # mocking pinecone related stuff 469 | fake_index = Mock() 470 | fake_match = Mock() 471 | fake_match.to_dict.return_value = { 
472 | "id": "doc1", 473 | "score": 0.8, 474 | "metadata": { 475 | "filename": "hello.pdf", 476 | "page_number": 1, 477 | "chunk_number": 0, 478 | "text": "hello world", 479 | "uuid": "123", 480 | }, 481 | } 482 | fake_query_response = { 483 | "matches": [fake_match, fake_match, fake_match], 484 | } 485 | fake_index.query = Mock(return_value=fake_query_response) 486 | 487 | # mocking openai related stuff 488 | content = Output( 489 | answer="The answer is 42", 490 | contexts=[0, 2], 491 | ) 492 | fake_openai_response_rv = ChatCompletion( 493 | id="whatever", 494 | choices=[ 495 | dict( 496 | finish_reason="stop", 497 | index=0, 498 | logprobs=None, 499 | message=dict( 500 | content=content.model_dump_json(), 501 | role="assistant", 502 | function_call=None, 503 | tool_calls=None, 504 | ), 505 | ) 506 | ], 507 | created=1710065537, 508 | model="gpt-3.5-turbo-0125", 509 | object="chat.completion", 510 | system_fingerprint="whatever", 511 | usage=dict( 512 | completion_tokens=20, prompt_tokens=679, total_tokens=699 513 | ), 514 | ) 515 | 516 | client.openai_client = ( 517 | Mock() 518 | ) # for some reason spec is not working correctly 519 | client.openai_client.api_key = "whatever" 520 | client.openai_client.chat.completions.create.return_value = ( 521 | fake_openai_response_rv 522 | ) 523 | 524 | monkeypatch.setattr(client, "get_index", Mock(return_value=fake_index)) 525 | monkeypatch.setattr( 526 | "whyhow_rbr.rag.generate_embeddings", fake_generate_embeddings 527 | ) 528 | 529 | final_result = client.query( 530 | question="How are you?", 531 | index=fake_index, 532 | namespace="great_namespace", 533 | ) 534 | 535 | assert fake_index.query.call_count == 1 536 | 537 | expected_final_result = { 538 | "answer": "The answer is 42", 539 | "matches": [ 540 | { 541 | "id": "doc1", 542 | "metadata": { 543 | "chunk_number": 0, 544 | "filename": "hello.pdf", 545 | "page_number": 1, 546 | "text": "hello world", 547 | "uuid": "123", 548 | }, 549 | "score": 0.8, 550 | }, 551 | { 552 | "id": "doc1", 553 | "metadata": { 554 | "chunk_number": 0, 555 | "filename": "hello.pdf", 556 | "page_number": 1, 557 | "text": "hello world", 558 | "uuid": "123", 559 | }, 560 | "score": 0.8, 561 | }, 562 | { 563 | "id": "doc1", 564 | "metadata": { 565 | "chunk_number": 0, 566 | "filename": "hello.pdf", 567 | "page_number": 1, 568 | "text": "hello world", 569 | "uuid": "123", 570 | }, 571 | "score": 0.8, 572 | }, 573 | ], 574 | "used_contexts": [0, 2], 575 | } 576 | 577 | assert final_result == expected_final_result 578 | prompt = client.openai_client.chat.completions.create.call_args.kwargs[ 579 | "messages" 580 | ][0]["content"] 581 | 582 | assert prompt.count("hello world") == 3 583 | 584 | def test_query_with_rules_no_json_header(self, client, monkeypatch): 585 | # mocking embedding 586 | fake_generate_embeddings = Mock(return_value=[[0.2, 0.3]]) 587 | 588 | # mocking pinecone related stuff 589 | fake_index = Mock() 590 | fake_match = Mock() 591 | fake_match.to_dict.return_value = { 592 | "id": "doc1", 593 | "score": 0.8, 594 | "metadata": { 595 | "filename": "hello.pdf", 596 | "page_number": 1, 597 | "chunk_number": 0, 598 | "text": "hello world", 599 | "uuid": "123", 600 | }, 601 | } 602 | fake_query_response = { 603 | "matches": [fake_match, fake_match, fake_match], 604 | } 605 | fake_index.query = Mock(return_value=fake_query_response) 606 | 607 | # mocking openai related stuff 608 | content = Output( 609 | answer="The answer is 42", 610 | contexts=[0, 2], 611 | ) 612 | fake_openai_response_rv = ChatCompletion( 
613 | id="whatever", 614 | choices=[ 615 | dict( 616 | finish_reason="stop", 617 | index=0, 618 | logprobs=None, 619 | message=dict( 620 | content=content.model_dump_json(), 621 | role="assistant", 622 | function_call=None, 623 | tool_calls=None, 624 | ), 625 | ) 626 | ], 627 | created=1710065537, 628 | model="gpt-3.5-turbo-0125", 629 | object="chat.completion", 630 | system_fingerprint="whatever", 631 | usage=dict( 632 | completion_tokens=20, prompt_tokens=679, total_tokens=699 633 | ), 634 | ) 635 | 636 | client.openai_client = ( 637 | Mock() 638 | ) # for some reason spec is not working correctly 639 | client.openai_client.api_key = "whatever" 640 | client.openai_client.chat.completions.create.return_value = ( 641 | fake_openai_response_rv 642 | ) 643 | 644 | monkeypatch.setattr(client, "get_index", Mock(return_value=fake_index)) 645 | monkeypatch.setattr( 646 | "whyhow_rbr.rag.generate_embeddings", fake_generate_embeddings 647 | ) 648 | 649 | _ = client.query( 650 | question="How are you?", 651 | index=fake_index, 652 | namespace="great_namespace", 653 | rules=[ 654 | Rule( 655 | filename="hello.pdf", 656 | page_numbers=[1], 657 | ), 658 | Rule( 659 | page_numbers=[0], 660 | ), 661 | ], 662 | ) 663 | 664 | assert fake_index.query.call_count == 1 665 | kwargs = fake_index.query.call_args.kwargs 666 | 667 | assert kwargs["filter"] == { 668 | "$or": [ 669 | { 670 | "$and": [ 671 | {"filename": {"$eq": "hello.pdf"}}, 672 | {"page_number": {"$in": [1]}}, 673 | ], 674 | }, 675 | {"$and": [{"page_number": {"$in": [0]}}]}, 676 | ] 677 | } 678 | 679 | def test_query_impossible_to_decode(self, client, monkeypatch): 680 | # mocking embedding 681 | fake_generate_embeddings = Mock(return_value=[[0.2, 0.3]]) 682 | 683 | # mocking pinecone related stuff 684 | fake_index = Mock() 685 | fake_match = Mock() 686 | fake_match.to_dict.return_value = { 687 | "id": "doc1", 688 | "score": 0.8, 689 | "metadata": { 690 | "filename": "hello.pdf", 691 | "page_number": 1, 692 | "chunk_number": 0, 693 | "text": "hello world", 694 | "uuid": "123", 695 | }, 696 | } 697 | fake_query_response = { 698 | "matches": [fake_match, fake_match, fake_match], 699 | } 700 | fake_index.query = Mock(return_value=fake_query_response) 701 | 702 | # mocking openai related stuff 703 | fake_openai_response_rv = ChatCompletion( 704 | id="whatever", 705 | choices=[ 706 | dict( 707 | finish_reason="stop", 708 | index=0, 709 | logprobs=None, 710 | message=dict( 711 | content="This is not a JSON", 712 | role="assistant", 713 | function_call=None, 714 | tool_calls=None, 715 | ), 716 | ) 717 | ], 718 | created=1710065537, 719 | model="gpt-3.5-turbo-0125", 720 | object="chat.completion", 721 | system_fingerprint="whatever", 722 | usage=dict( 723 | completion_tokens=20, prompt_tokens=679, total_tokens=699 724 | ), 725 | ) 726 | 727 | client.openai_client = ( 728 | Mock() 729 | ) # for some reason spec is not working correctly 730 | client.openai_client.api_key = "whatever" 731 | client.openai_client.chat.completions.create.return_value = ( 732 | fake_openai_response_rv 733 | ) 734 | 735 | monkeypatch.setattr(client, "get_index", Mock(return_value=fake_index)) 736 | monkeypatch.setattr( 737 | "whyhow_rbr.rag.generate_embeddings", fake_generate_embeddings 738 | ) 739 | 740 | with pytest.raises(OpenAIException, match="OpenAI did not return"): 741 | client.query( 742 | question="How are you?", 743 | index=fake_index, 744 | namespace="great_namespace", 745 | ) 746 | 747 | def test_query_wrong_reason(self, client, monkeypatch): 748 | # mocking 
embedding 749 | fake_generate_embeddings = Mock(return_value=[[0.2, 0.3]]) 750 | 751 | # mocking pinecone related stuff 752 | fake_index = Mock() 753 | fake_match = Mock() 754 | fake_match.to_dict.return_value = { 755 | "id": "doc1", 756 | "score": 0.8, 757 | "metadata": { 758 | "filename": "hello.pdf", 759 | "page_number": 1, 760 | "chunk_number": 0, 761 | "text": "hello world", 762 | "uuid": "123", 763 | }, 764 | } 765 | fake_query_response = { 766 | "matches": [fake_match, fake_match, fake_match], 767 | } 768 | fake_index.query = Mock(return_value=fake_query_response) 769 | 770 | # mocking openai related stuff 771 | content = Output( 772 | answer="The answer is 42", 773 | contexts=[0, 2], 774 | ) 775 | fake_openai_response_rv = ChatCompletion( 776 | id="whatever", 777 | choices=[ 778 | dict( 779 | finish_reason="length", 780 | index=0, 781 | logprobs=None, 782 | message=dict( 783 | content=content.model_dump_json(), 784 | role="assistant", 785 | function_call=None, 786 | tool_calls=None, 787 | ), 788 | ) 789 | ], 790 | created=1710065537, 791 | model="gpt-3.5-turbo-0125", 792 | object="chat.completion", 793 | system_fingerprint="whatever", 794 | usage=dict( 795 | completion_tokens=20, prompt_tokens=679, total_tokens=699 796 | ), 797 | ) 798 | 799 | client.openai_client = ( 800 | Mock() 801 | ) # for some reason spec is not working correctly 802 | client.openai_client.api_key = "whatever" 803 | client.openai_client.chat.completions.create.return_value = ( 804 | fake_openai_response_rv 805 | ) 806 | 807 | monkeypatch.setattr(client, "get_index", Mock(return_value=fake_index)) 808 | monkeypatch.setattr( 809 | "whyhow_rbr.rag.generate_embeddings", fake_generate_embeddings 810 | ) 811 | 812 | with pytest.raises(OpenAIException, match="Chat did not finish"): 813 | client.query( 814 | question="How are you?", 815 | index=fake_index, 816 | namespace="great_namespace", 817 | ) 818 | --------------------------------------------------------------------------------
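
For orientation, a minimal usage sketch consistent with the calls exercised in tests/test_rag.py above (this is not a file from the repository; the import path is inferred from the monkeypatched targets `whyhow_rbr.rag.Pinecone` / `whyhow_rbr.rag.OpenAI`, the index, namespace, and key values are placeholders, and any create_index parameters beyond the index name are omitted because they are not shown in these tests):

    import os

    # Import path inferred from the monkeypatch targets in the tests;
    # Client and Rule are the objects the test suite instantiates directly.
    from whyhow_rbr.rag import Client, Rule

    # Both keys must be present, otherwise Client() raises ValueError
    # (see test_no_openai_key / test_no_pinecone_key). Placeholder values here.
    os.environ["OPENAI_API_KEY"] = "sk-placeholder"
    os.environ["PINECONE_API_KEY"] = "pc-placeholder"

    client = Client()

    # create_index raises IndexAlreadyExistsException if the index already
    # exists, and get_index raises IndexNotFoundException if it does not
    # (see TestClient.test_create_index / test_get_index).
    client.create_index("new_index")
    index = client.get_index("new_index")

    # PDFs are parsed into chunks, cleaned, embedded, and upserted
    # (see test_upload_document and its logging assertions).
    client.upload_documents(
        index=index,
        namespace="great_namespace",
        documents=["doc1.pdf", "doc2.pdf"],
    )

    # Rules restrict retrieval by document metadata and are translated into a
    # Pinecone "$or"/"$and" filter (see test_query_with_rules_no_json_header).
    # The result is a dict with "answer", "matches", and "used_contexts" keys.
    result = client.query(
        question="How are you?",
        index=index,
        namespace="great_namespace",
        rules=[Rule(filename="hello.pdf", page_numbers=[1])],
    )
    print(result["answer"])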