├── .github └── workflows │ ├── deploy-docs.yml │ ├── manual-workflow.yml │ └── run-tests.yml ├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.md ├── docs ├── api │ ├── chat.md │ └── knowledge_base.md ├── community │ ├── contributing.md │ ├── professional-services.md │ └── support.md ├── concepts │ ├── chat.md │ ├── citations.md │ ├── components.md │ ├── config.md │ ├── knowledge-bases.md │ ├── logging.md │ ├── overview.md │ └── vlm.md ├── getting-started │ ├── installation.md │ └── quickstart.md ├── index.md └── logo.png ├── dsrag ├── __init__.py ├── add_document.py ├── auto_context.py ├── auto_query.py ├── chat │ ├── __init__.py │ ├── auto_query.py │ ├── chat.py │ ├── chat_types.py │ ├── citations.py │ ├── instructor_get_response.py │ └── model_names.py ├── create_kb.py ├── custom_term_mapping.py ├── database │ ├── __init__.py │ ├── chat_thread │ │ ├── basic_db.py │ │ ├── db.py │ │ └── sqlite_db.py │ ├── chunk │ │ ├── __init__.py │ │ ├── basic_db.py │ │ ├── db.py │ │ ├── dynamo_db.py │ │ ├── postgres_db.py │ │ ├── sqlite_db.py │ │ └── types.py │ └── vector │ │ ├── __init__.py │ │ ├── basic_db.py │ │ ├── chroma_db.py │ │ ├── db.py │ │ ├── milvus_db.py │ │ ├── pinecone_db.py │ │ ├── postgres_db.py │ │ ├── qdrant_db.py │ │ ├── types.py │ │ └── weaviate_db.py ├── dsparse │ ├── MANIFEST.in │ ├── README.md │ ├── __init__.py │ ├── file_parsing │ │ ├── __init__.py │ │ ├── element_types.py │ │ ├── file_system.py │ │ ├── non_vlm_file_parsing.py │ │ ├── vlm.py │ │ └── vlm_file_parsing.py │ ├── main.py │ ├── models │ │ ├── __init__.py │ │ └── types.py │ ├── pyproject.toml │ ├── requirements.in │ ├── requirements.txt │ ├── sectioning_and_chunking │ │ ├── __init__.py │ │ ├── chunking.py │ │ └── semantic_sectioning.py │ ├── tests │ │ ├── data │ │ │ ├── chunks.json │ │ │ ├── document_lines.json │ │ │ ├── elements.json │ │ │ └── sections.json │ │ ├── integration │ │ │ ├── test_semantic_sectioning.py │ │ │ └── test_vlm_file_parsing.py │ │ └── unit │ │ │ ├── test_chunking.py │ │ │ ├── test_file_system.py │ │ │ └── test_semantic_sectioning.py │ └── utils │ │ ├── __init__.py │ │ └── imports.py ├── embedding.py ├── knowledge_base.py ├── llm.py ├── metadata.py ├── reranker.py ├── rse.py └── utils │ └── imports.py ├── eval ├── financebench │ ├── dsrag_financebench_create_kb.py │ └── dsrag_financebench_eval.py ├── rse_eval.py ├── streaming_chat_test.py └── vlm_eval.py ├── examples ├── dsRAG_motivation.ipynb ├── example_kb_data │ ├── chunk_storage │ │ ├── levels_of_agi.pkl │ │ └── nike_10k.pkl │ ├── metadata │ │ ├── levels_of_agi.json │ │ └── nike_10k.json │ └── vector_storage │ │ ├── levels_of_agi.pkl │ │ └── nike_10k.pkl └── logging_example.py ├── integrations └── langchain_retriever.py ├── mkdocs.yml ├── pyproject.toml ├── requirements.in ├── requirements.txt └── tests ├── data ├── chunks.json ├── document_lines.json ├── elements.json ├── financebench_sample_150.csv ├── les_miserables.txt ├── levels_of_agi.pdf ├── mck_energy_first_5_pages.pdf ├── nike_2023_annual_report.txt ├── page_7.png └── sections.json ├── integration ├── test_chat_w_citations.py ├── test_chunk_and_segment_headers.py ├── test_create_kb_and_query.py ├── test_create_kb_and_query_chromadb.py ├── test_create_kb_and_query_weaviate.py ├── test_custom_embedding_subclass.py ├── test_custom_langchain_retriever.py ├── test_save_and_load.py └── test_vlm_file_parsing.py └── unit ├── test_auto_context.py ├── test_chunk_db.py ├── test_chunking.py ├── test_embedding.py ├── test_file_system.py ├── test_llm_class.py ├── test_reranker.py └── 
test_vector_db.py /.github/workflows/deploy-docs.yml: -------------------------------------------------------------------------------- 1 | name: Deploy Documentation 2 | on: 3 | push: 4 | branches: 5 | - main 6 | workflow_dispatch: 7 | 8 | # Sets permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages 9 | permissions: 10 | contents: read 11 | pages: write 12 | id-token: write 13 | 14 | # Allow only one concurrent deployment, skipping runs queued between the run in-progress and latest queued. 15 | concurrency: 16 | group: "pages" 17 | cancel-in-progress: false 18 | 19 | jobs: 20 | build: 21 | runs-on: ubuntu-latest 22 | steps: 23 | - uses: actions/checkout@v4 24 | 25 | - name: Setup Pages 26 | uses: actions/configure-pages@v5 27 | 28 | - uses: actions/setup-python@v4 29 | with: 30 | python-version: '3.12' 31 | 32 | - name: Install dependencies 33 | run: | 34 | python -m pip install --upgrade pip 35 | pip install mkdocs-material 36 | pip install mkdocstrings[python] 37 | # Install the package in development mode 38 | pip install -e . 39 | 40 | - name: Build with MkDocs 41 | run: mkdocs build 42 | 43 | - name: Upload artifact 44 | uses: actions/upload-pages-artifact@v3 45 | with: 46 | path: ./site 47 | 48 | # Deployment job 49 | deploy: 50 | environment: 51 | name: github-pages 52 | url: ${{ steps.deployment.outputs.page_url }} 53 | runs-on: ubuntu-latest 54 | needs: build 55 | steps: 56 | - name: Deploy to GitHub Pages 57 | id: deployment 58 | uses: actions/deploy-pages@v4 -------------------------------------------------------------------------------- /.github/workflows/manual-workflow.yml: -------------------------------------------------------------------------------- 1 | name: Manual Test Run 2 | 3 | on: 4 | workflow_dispatch: 5 | inputs: 6 | repository: 7 | description: 'Repository to run tests on (e.g., user/repo)' 8 | required: true 9 | branch: 10 | description: 'Branch to run tests on' 11 | required: true 12 | 13 | jobs: 14 | run-tests: 15 | runs-on: ubuntu-latest 16 | 17 | env: 18 | ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} 19 | CO_API_KEY: ${{ secrets.CO_API_KEY }} 20 | OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} 21 | VOYAGE_API_KEY: ${{ secrets.VOYAGE_API_KEY }} 22 | 23 | steps: 24 | - name: Checkout code 25 | uses: actions/checkout@v2 26 | 27 | - name: Set up Python 28 | uses: actions/setup-python@v2 29 | with: 30 | python-version: '3.x' 31 | 32 | - name: Install dependencies 33 | run: | 34 | python -m pip install --upgrade pip 35 | pip install -e ".[all-dbs,all-models]" # Install package with optional dependencies 36 | pip install pytest 37 | 38 | - name: Install dsParse dependencies 39 | run: | 40 | cd dsrag/dsparse && pip install -r requirements.txt 41 | 42 | - name: Install additional test dependencies 43 | run: | 44 | pip install vertexai psycopg2-binary 45 | 46 | - name: Run unit tests 47 | run: | 48 | pytest tests/unit 49 | 50 | - name: Run integration tests 51 | run: | 52 | pytest tests/integration 53 | -------------------------------------------------------------------------------- /.github/workflows/run-tests.yml: -------------------------------------------------------------------------------- 1 | name: Run Tests 2 | 3 | on: 4 | pull_request_review: 5 | types: [submitted] 6 | 7 | jobs: 8 | run-tests: 9 | if: github.event.review.state == 'approved' 10 | runs-on: ubuntu-latest 11 | 12 | env: 13 | ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} 14 | CO_API_KEY: ${{ secrets.CO_API_KEY }} 15 | OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} 16 | 
VOYAGE_API_KEY: ${{ secrets.VOYAGE_API_KEY }} 17 | AWS_S3_BUCKET_NAME: ${{ secrets.AWS_S3_BUCKET_NAME }} 18 | AWS_S3_REGION: ${{ secrets.AWS_S3_REGION }} 19 | AWS_S3_ACCESS_KEY: ${{ secrets.AWS_S3_ACCESS_KEY }} 20 | AWS_S3_SECRET_KEY: ${{ secrets.AWS_S3_SECRET_KEY }} 21 | GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }} 22 | AWS_REGION: ${{ secrets.AWS_REGION }} 23 | AWS_DYNAMO_ACCESS_KEY: ${{ secrets.AWS_DYNAMO_ACCESS_KEY }} 24 | AWS_DYNAMO_SECRET_KEY: ${{ secrets.AWS_DYNAMO_SECRET_KEY }} 25 | PINECONE_API_KEY: ${{ secrets.PINECONE_API_KEY }} 26 | POSTGRES_DB: ${{ secrets.POSTGRES_DB }} 27 | POSTGRES_USER: ${{ secrets.POSTGRES_USER }} 28 | POSTGRES_PASSWORD: ${{ secrets.POSTGRES_PASSWORD }} 29 | POSTGRES_HOST: ${{ secrets.POSTGRES_HOST }} 30 | 31 | steps: 32 | - name: Checkout code 33 | uses: actions/checkout@v2 34 | 35 | - name: Set up Python 36 | uses: actions/setup-python@v2 37 | with: 38 | python-version: '3.12' 39 | 40 | - name: Install Poppler 41 | run: sudo apt-get install -y poppler-utils 42 | 43 | - name: Install dependencies 44 | run: | 45 | python -m pip install --upgrade pip 46 | pip install -e ".[all-dbs,all-models]" # Install package with optional dependencies 47 | pip install pytest 48 | 49 | - name: Install dsParse dependencies 50 | run: | 51 | cd dsrag/dsparse && pip install -r requirements.txt 52 | 53 | - name: Install additional test dependencies 54 | run: | 55 | pip install vertexai psycopg2-binary 56 | 57 | - name: Run unit tests 58 | run: | 59 | pytest tests/unit 60 | 61 | - name: Run integration tests 62 | run: | 63 | pytest tests/integration 64 | 65 | - name: Run dsParse unit tests 66 | run: | 67 | pytest dsrag/dsparse/tests/unit 68 | 69 | - name: Run dsParse integration tests 70 | run: | 71 | pytest dsrag/dsparse/tests/integration 72 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | .idea/ 161 | 162 | # Embedded Weaviate data 163 | weaviate 164 | 165 | # DS_Store files 166 | .DS_Store 167 | 168 | .python-version -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 D-Star 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | exclude dsrag/dsparse/requirements.txt 2 | exclude dsrag/dsparse/requirements.in 3 | exclude dsrag/dsparse/MANIFEST.in 4 | exclude dsrag/dsparse/README.md 5 | exclude dsrag/dsparse/pyproject.toml 6 | -------------------------------------------------------------------------------- /docs/api/chat.md: -------------------------------------------------------------------------------- 1 | # Chat 2 | 3 | The Chat module provides functionality for managing chat-based interactions with knowledge bases. It handles chat thread creation, message history, and generating responses using knowledge base search. 4 | 5 | ### Public Methods 6 | 7 | #### create_new_chat_thread 8 | 9 | ```python 10 | def create_new_chat_thread(chat_thread_params: ChatThreadParams, chat_thread_db: ChatThreadDB) -> str 11 | ``` 12 | 13 | Create a new chat thread in the database. 14 | 15 | **Arguments**: 16 | - `chat_thread_params`: Parameters for the chat thread. Example: 17 | ```python 18 | { 19 | # Knowledge base IDs to use 20 | "kb_ids": ["kb1", "kb2"], 21 | 22 | # LLM model to use 23 | "model": "gpt-4o-mini", 24 | 25 | # Temperature for LLM sampling 26 | "temperature": 0.2, 27 | 28 | # System message for LLM 29 | "system_message": "You are a helpful assistant", 30 | 31 | # Model for auto-query generation 32 | "auto_query_model": "gpt-4o-mini", 33 | 34 | # Guidance for auto-query generation 35 | "auto_query_guidance": "", 36 | 37 | # Target response length (short/medium/long) 38 | "target_output_length": "medium", 39 | 40 | # Maximum tokens in chat history 41 | "max_chat_history_tokens": 8000, 42 | 43 | # Optional supplementary ID 44 | "supp_id": "" 45 | } 46 | ``` 47 | - `chat_thread_db`: Database instance for storing chat threads. 48 | 49 | **Returns**: Unique identifier (str) for the created chat thread. 
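**Example**: The sketch below creates a thread backed by the SQLite chat thread database (`SQLiteChatThreadDB`). This is a minimal example, and it assumes that any `ChatThreadParams` keys omitted from the dictionary are filled in with default values by the library.

```python
from dsrag.chat.chat import create_new_chat_thread
from dsrag.database.chat_thread.sqlite_db import SQLiteChatThreadDB

# Store chat threads in a local SQLite database
chat_thread_db = SQLiteChatThreadDB()

# Create the thread and keep the returned ID for follow-up calls
thread_id = create_new_chat_thread(
    chat_thread_params={
        "kb_ids": ["my_knowledge_base"],
        "model": "gpt-4o-mini",
        "temperature": 0.2,
    },
    chat_thread_db=chat_thread_db,
)
```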
50 | 51 | #### get_chat_thread_response 52 | 53 | ```python 54 | def get_chat_thread_response(thread_id: str, get_response_input: ChatResponseInput, 55 | chat_thread_db: ChatThreadDB, knowledge_bases: dict, stream: bool = False) 56 | ``` 57 | 58 | Get a response for a chat thread using knowledge base search, with optional streaming support. 59 | 60 | **Arguments**: 61 | - `thread_id`: Unique identifier for the chat thread. 62 | - `get_response_input`: Input parameters containing: 63 | - `user_input`: User's message text 64 | - `chat_thread_params`: Optional parameter overrides 65 | - `metadata_filter`: Optional search filter 66 | - `chat_thread_db`: Database instance for chat threads. 67 | - `knowledge_bases`: Dictionary mapping knowledge base IDs to instances. 68 | - `stream`: Whether to stream the response (defaults to False). 69 | 70 | **Returns**: 71 | If `stream=False` (default): 72 | - Formatted interaction containing: 73 | - `user_input`: User message with content and timestamp 74 | - `model_response`: Model response with content and timestamp 75 | - `search_queries`: Generated search queries 76 | - `relevant_segments`: Retrieved relevant segments with file names and types 77 | - `message`: Error message if something went wrong (optional) 78 | 79 | If `stream=True`: 80 | - Generator that yields partial response objects with the same structure as above 81 | - The final response is automatically saved to the chat thread database 82 | 83 | ### Chat Types 84 | 85 | #### ChatThreadParams 86 | 87 | Type definition for chat thread parameters. 88 | 89 | ```python 90 | ChatThreadParams = TypedDict('ChatThreadParams', { 91 | 'kb_ids': List[str], # List of knowledge base IDs to use 92 | 'model': str, # LLM model name (e.g., "gpt-4o-mini") 93 | 'temperature': float, # Temperature for LLM sampling (0.0-1.0) 94 | 'system_message': str, # Custom system message for LLM 95 | 'auto_query_model': str, # Model for generating search queries 96 | 'auto_query_guidance': str, # Custom guidance for query generation 97 | 'target_output_length': str, # Response length ("short", "medium", "long") 98 | 'max_chat_history_tokens': int, # Maximum tokens in chat history 99 | 'thread_id': str, # Unique thread identifier (auto-generated) 100 | 'supp_id': str, # Optional supplementary identifier 101 | }) 102 | ``` 103 | 104 | #### ChatResponseInput 105 | 106 | Input parameters for getting a chat response. 107 | 108 | ```python 109 | ChatResponseInput = TypedDict('ChatResponseInput', { 110 | 'user_input': str, # User's message text 111 | 'chat_thread_params': Optional[ChatThreadParams], # Optional parameter overrides 112 | 'metadata_filter': Optional[MetadataFilter], # Optional search filter 113 | }) 114 | ``` 115 | 116 | #### MetadataFilter 117 | 118 | Filter criteria for knowledge base searches. 119 | 120 | ```python 121 | MetadataFilter = TypedDict('MetadataFilter', { 122 | 'field_name': str, # Name of the metadata field to filter on 123 | 'field_value': Any, # Value to match against 124 | 'comparison_type': str, # Type of comparison ("equals", "contains", etc.) 125 | }, total=False) 126 | ``` 127 | 128 | You can use these types to ensure type safety when working with the chat functions. 
For example: 129 | 130 | ```python 131 | # Create chat thread parameters 132 | params: ChatThreadParams = { 133 | "kb_ids": ["kb1"], 134 | "model": "gpt-4o-mini", 135 | "temperature": 0.2, 136 | "system_message": "You are a helpful assistant", 137 | "auto_query_model": "gpt-4o-mini", 138 | "auto_query_guidance": "", 139 | "target_output_length": "medium", 140 | "max_chat_history_tokens": 8000, 141 | "supp_id": "" 142 | } 143 | 144 | # Create chat response input 145 | response_input: ChatResponseInput = { 146 | "user_input": "What is the capital of France?", 147 | "chat_thread_params": None, 148 | "metadata_filter": None 149 | } 150 | ``` -------------------------------------------------------------------------------- /docs/api/knowledge_base.md: -------------------------------------------------------------------------------- 1 | # KnowledgeBase 2 | 3 | The KnowledgeBase class is the main interface for working with dsRAG. It handles document processing, storage, and retrieval. 4 | 5 | ### Public Methods 6 | 7 | The following methods are part of the public API: 8 | 9 | - `__init__`: Initialize a new KnowledgeBase instance 10 | - `add_document`: Add a single document to the knowledge base 11 | - `add_documents`: Add multiple documents in parallel 12 | - `delete`: Delete the entire knowledge base and all associated data 13 | - `delete_document`: Delete a specific document from the knowledge base 14 | - `query`: Search the knowledge base with one or more queries 15 | 16 | ::: dsrag.knowledge_base.KnowledgeBase 17 | options: 18 | show_root_heading: false 19 | show_root_full_path: false 20 | members: 21 | - __init__ 22 | - add_document 23 | - add_documents 24 | - delete 25 | - delete_document 26 | - query 27 | 28 | ## KB Components 29 | 30 | ### Vector Databases 31 | 32 | ::: dsrag.database.vector.VectorDB 33 | options: 34 | show_root_heading: true 35 | show_root_full_path: false 36 | 37 | ### Chunk Databases 38 | 39 | ::: dsrag.database.chunk.ChunkDB 40 | options: 41 | show_root_heading: true 42 | show_root_full_path: false 43 | 44 | ### Embedding Models 45 | 46 | ::: dsrag.embedding.Embedding 47 | options: 48 | show_root_heading: true 49 | show_root_full_path: false 50 | 51 | ### Rerankers 52 | 53 | ::: dsrag.reranker.Reranker 54 | options: 55 | show_root_heading: true 56 | show_root_full_path: false 57 | 58 | ### LLM Providers 59 | 60 | ::: dsrag.llm.LLM 61 | options: 62 | show_root_heading: true 63 | show_root_full_path: false 64 | 65 | ### File Systems 66 | 67 | ::: dsrag.dsparse.file_parsing.file_system.FileSystem 68 | options: 69 | show_root_heading: true 70 | show_root_full_path: false -------------------------------------------------------------------------------- /docs/community/contributing.md: -------------------------------------------------------------------------------- 1 | # Contributing to dsRAG 2 | 3 | We welcome contributions from the community! Whether it's fixing bugs, improving documentation, or proposing new features, your contributions make dsRAG better for everyone. 4 | 5 | ## Getting Started 6 | 7 | 1. Fork the repository on GitHub 8 | 2. Clone your fork locally: 9 | ```bash 10 | git clone https://github.com/your-username/dsRAG.git 11 | cd dsRAG 12 | ``` 13 | 3. Create a new branch for your changes: 14 | ```bash 15 | git checkout -b feature/your-feature-name 16 | ``` 17 | 18 | ## Development Setup 19 | 20 | 1. Create a virtual environment: 21 | ```bash 22 | python -m venv venv 23 | source venv/bin/activate # On Windows: venv\Scripts\activate 24 | ``` 25 | 26 | 2. 
Install development dependencies: 27 | ```bash 28 | pip install -e ".[dev]" 29 | ``` 30 | 31 | ## Making Changes 32 | 33 | 1. Make your changes in your feature branch 34 | 2. Write or update tests as needed 35 | 3. Update documentation to reflect your changes 36 | 4. Run the test suite to ensure everything works: 37 | ```bash 38 | python -m unittest discover 39 | ``` 40 | 41 | ## Code Style 42 | 43 | We follow standard Python coding conventions: 44 | - Use [PEP 8](https://peps.python.org/pep-0008/) style guide 45 | - Use meaningful variable and function names 46 | - Write docstrings for functions and classes 47 | - Keep functions focused and concise 48 | - Add comments for complex logic 49 | 50 | ## Documentation 51 | 52 | When adding or modifying features: 53 | - Update docstrings for any modified functions/classes 54 | - Update relevant documentation in the `docs/` directory 55 | - Add examples if appropriate 56 | - Ensure documentation builds without errors: 57 | ```bash 58 | mkdocs serve 59 | ``` 60 | 61 | ## Testing 62 | 63 | For new features or bug fixes: 64 | - Add appropriate unit tests in the `tests/` directory 65 | - Tests should subclass `unittest.TestCase` 66 | - Follow the existing test file naming pattern: `test_*.py` 67 | - Update existing tests if needed 68 | - Ensure all tests pass locally 69 | - Maintain or improve code coverage 70 | 71 | Example test structure: 72 | ```python 73 | import unittest 74 | from dsrag import YourModule 75 | 76 | class TestYourFeature(unittest.TestCase): 77 | def setUp(self): 78 | # Set up any test fixtures 79 | pass 80 | 81 | def test_your_feature(self): 82 | # Test your new functionality 83 | result = YourModule.your_function() 84 | self.assertEqual(result, expected_value) 85 | 86 | def tearDown(self): 87 | # Clean up after tests 88 | pass 89 | ``` 90 | 91 | ## Submitting Changes 92 | 93 | 1. Commit your changes: 94 | ```bash 95 | git add . 96 | git commit -m "Description of your changes" 97 | ``` 98 | 99 | 2. Push to your fork: 100 | ```bash 101 | git push origin feature/your-feature-name 102 | ``` 103 | 104 | 3. Open a Pull Request: 105 | - Go to the dsRAG repository on GitHub 106 | - Click "Pull Request" 107 | - Select your feature branch 108 | - Describe your changes and their purpose 109 | - Reference any related issues 110 | 111 | ## Pull Request Guidelines 112 | 113 | Your PR should: 114 | 115 | - Focus on a single feature or fix 116 | - Include appropriate tests 117 | - Update relevant documentation 118 | - Follow the code style guidelines 119 | - Include a clear description of the changes 120 | - Reference any related issues 121 | 122 | ## Getting Help 123 | 124 | If you need help with your contribution: 125 | 126 | - Join our [Discord](https://discord.gg/NTUVX9DmQ3) 127 | - Ask questions in the PR 128 | - Tag maintainers if you need specific guidance 129 | 130 | ## Code of Conduct 131 | 132 | - Be respectful and inclusive 133 | - Focus on constructive feedback 134 | - Help others learn and grow 135 | - Follow project maintainers' guidance 136 | 137 | Thank you for contributing to dsRAG! -------------------------------------------------------------------------------- /docs/community/professional-services.md: -------------------------------------------------------------------------------- 1 | # Professional Services 2 | 3 | ## About Our Team 4 | 5 | The creators of dsRAG, Zach and Nick McCormick, run a specialized applied AI consulting firm. 
As former startup founders and YC alums, we bring a unique business and product-centric perspective to highly technical engineering projects. 6 | 7 | ## Our Expertise 8 | 9 | We specialize in: 10 | 11 | - Building high-performance RAG-based applications 12 | - Optimizing existing RAG systems 13 | - Integrating advanced retrieval techniques 14 | - Designing scalable AI architectures 15 | 16 | ## Services Offered 17 | 18 | We provide: 19 | 20 | - Advisory services 21 | - Implementation work 22 | - Performance optimization 23 | - Custom feature development 24 | - Technical architecture design 25 | 26 | ## Getting Started 27 | 28 | If you need professional assistance with: 29 | 30 | - Implementing dsRAG in your production environment 31 | - Optimizing your RAG-based application 32 | - Custom feature development 33 | - Technical consultation 34 | 35 | Please fill out our [contact form](https://forms.gle/zbQwDJp7pBQKtqVT8) and we'll be in touch to discuss how we can help with your specific needs. -------------------------------------------------------------------------------- /docs/community/support.md: -------------------------------------------------------------------------------- 1 | # Community Support 2 | 3 | ## Discord Community 4 | 5 | Join our [Discord](https://discord.gg/NTUVX9DmQ3) to: 6 | 7 | - Ask questions about dsRAG 8 | - Make suggestions for improvements 9 | - Discuss potential contributions 10 | - Connect with other developers 11 | - Share your experiences 12 | 13 | ## Production Use Cases 14 | 15 | If you're using (or planning to use) dsRAG in production: 16 | 17 | 1. Please fill out our [use case form](https://forms.gle/RQ5qFVReonSHDcCu5) 18 | 19 | 2. This helps us: 20 | - Understand how dsRAG is being used 21 | - Prioritize new features 22 | - Improve documentation 23 | - Focus on the most important issues 24 | 25 | 3. In return, you'll receive: 26 | - Direct email contact with the creators 27 | - Priority email support 28 | - Early access to new features -------------------------------------------------------------------------------- /docs/concepts/citations.md: -------------------------------------------------------------------------------- 1 | # Citations 2 | 3 | dsRAG's citation system ensures that responses are grounded in your knowledge base content and provides transparency about information sources. Citations are included automatically in chat responses. 
4 | 5 | ## Overview 6 | 7 | Citations in dsRAG: 8 | 9 | - Track the source of information in responses 10 | - Include document IDs and page numbers 11 | - Provide exact quoted text from sources 12 | 13 | ## Citation Structure 14 | 15 | Each citation includes: 16 | 17 | - `doc_id`: Unique identifier of the source document 18 | - `page_number`: Page where the information was found (if available) 19 | - `cited_text`: Exact text containing the cited information 20 | - `kb_id`: ID of the knowledge base containing the document 21 | 22 | ## Working with Citations 23 | 24 | Citations are automatically included in chat responses: 25 | 26 | ```python 27 | from dsrag.chat.chat import get_chat_thread_response 28 | from dsrag.chat.chat_types import ChatResponseInput 29 | from dsrag.database.chat_thread.sqlite_db import SQLiteChatThreadDB 30 | 31 | # Create the knowledge base instances 32 | knowledge_bases = { 33 | "my_knowledge_base": KnowledgeBase(kb_id="my_knowledge_base") 34 | } 35 | 36 | # Initialize database and get response 37 | chat_thread_db = SQLiteChatThreadDB() 38 | response = get_chat_thread_response( 39 | thread_id=thread_id, 40 | get_response_input=ChatResponseInput( 41 | user_input="What is the system architecture?" 42 | ), 43 | chat_thread_db=chat_thread_db, 44 | knowledge_bases=knowledge_bases 45 | ) 46 | 47 | # Access citations 48 | citations = response["model_response"]["citations"] 49 | for citation in citations: 50 | print(f""" 51 | Source: {citation['doc_id']} 52 | Page: {citation['page_number']} 53 | Text: {citation['cited_text']} 54 | Knowledge Base: {citation['kb_id']} 55 | """) 56 | ``` 57 | 58 | ## Technical Details 59 | 60 | Citations are managed through the `ResponseWithCitations` model: 61 | 62 | ```python 63 | from dsrag.chat.citations import ResponseWithCitations, Citation 64 | 65 | response = ResponseWithCitations( 66 | response="The system architecture is...", 67 | citations=[ 68 | Citation( 69 | doc_id="arch-doc-v1", 70 | page_number=12, 71 | cited_text="The system uses a microservices architecture" 72 | ) 73 | ] 74 | ) 75 | ``` 76 | 77 | This structured format ensures consistent citation handling throughout the system. -------------------------------------------------------------------------------- /docs/concepts/components.md: -------------------------------------------------------------------------------- 1 | # Components 2 | 3 | There are six key components that define the configuration of a KnowledgeBase. Each component is customizable, with several built-in options available. 4 | 5 | ## VectorDB 6 | 7 | The VectorDB component stores embedding vectors and associated metadata. 8 | 9 | Available options: 10 | 11 | - `BasicVectorDB` 12 | - `WeaviateVectorDB` 13 | - `ChromaDB` 14 | - `QdrantVectorDB` 15 | - `MilvusDB` 16 | - `PineconeDB` 17 | 18 | ## ChunkDB 19 | 20 | The ChunkDB stores the content of text chunks in a nested dictionary format, keyed on `doc_id` and `chunk_index`. This is used by RSE to retrieve the full text associated with specific chunks. 21 | 22 | Available options: 23 | 24 | - `BasicChunkDB` 25 | - `SQLiteDB` 26 | 27 | ## Embedding 28 | 29 | The Embedding component defines the embedding model used for vectorizing text. 30 | 31 | Available options: 32 | 33 | - `OpenAIEmbedding` 34 | - `CohereEmbedding` 35 | - `VoyageAIEmbedding` 36 | - `OllamaEmbedding` 37 | 38 | ## Reranker 39 | 40 | The Reranker component provides more accurate ranking of chunks after vector database search and before RSE. This is optional but highly recommended. 
41 | 42 | Available options: 43 | 44 | - `CohereReranker` 45 | - `VoyageReranker` 46 | - `NoReranker` 47 | 48 | ## LLM 49 | 50 | The LLM component is used in AutoContext for: 51 | 52 | - Document title generation 53 | - Document summarization 54 | - Section summarization 55 | 56 | Available options: 57 | 58 | - `OpenAIChatAPI` 59 | - `AnthropicChatAPI` 60 | - `OllamaChatAPI` 61 | - `GeminiAPI` 62 | 63 | ## FileSystem 64 | 65 | The FileSystem component defines where to save PDF images and extracted elements for VLM file parsing. 66 | 67 | Available options: 68 | 69 | - `LocalFileSystem` 70 | - `S3FileSystem` 71 | 72 | #### LocalFileSystem Configuration 73 | Only requires a `base_path` parameter to define where files will be stored on the system. 74 | 75 | #### S3FileSystem Configuration 76 | Requires the following parameters: 77 | 78 | - `base_path`: Used when downloading files from S3 79 | - `bucket_name`: S3 bucket name 80 | - `region_name`: AWS region 81 | - `access_key`: AWS access key 82 | - `access_secret`: AWS secret key 83 | 84 | Note: Files must be temporarily stored locally for use in the retrieval system, even when using S3. -------------------------------------------------------------------------------- /docs/concepts/config.md: -------------------------------------------------------------------------------- 1 | # Configuration 2 | 3 | dsRAG uses several configuration dictionaries to organize its many parameters. These configs can be passed to different methods of the KnowledgeBase class. 4 | 5 | ## Document Addition Configs 6 | 7 | The following configuration dictionaries can be passed to the `add_document` method: 8 | 9 | ### AutoContext Config 10 | 11 | ```python 12 | auto_context_config = { 13 | "use_generated_title": bool, # whether to use an LLM-generated title if no title is provided (default: True) 14 | "document_title_guidance": str, # Additional guidance for generating the document title (default: "") 15 | "get_document_summary": bool, # whether to get a document summary (default: True) 16 | "document_summarization_guidance": str, # Additional guidance for summarizing the document (default: "") 17 | "get_section_summaries": bool, # whether to get section summaries (default: False) 18 | "section_summarization_guidance": str # Additional guidance for summarizing the sections (default: "") 19 | } 20 | ``` 21 | 22 | ### File Parsing Config 23 | 24 | ```python 25 | file_parsing_config = { 26 | "use_vlm": bool, # whether to use VLM for parsing the file (default: False) 27 | "vlm_config": { 28 | "provider": str, # the VLM provider to use (default: "gemini", "vertex_ai" is also supported) 29 | "model": str, # the VLM model to use (default: "gemini-2.0-flash") 30 | "project_id": str, # GCP project ID (only required for "vertex_ai") 31 | "location": str, # GCP location (only required for "vertex_ai") 32 | "save_path": str, # path to save intermediate files during VLM processing 33 | "exclude_elements": list, # element types to exclude (default: ["Header", "Footer"]) 34 | "images_already_exist": bool # whether images are pre-extracted (default: False) 35 | }, 36 | "always_save_page_images": bool # save page images even if VLM is not used (default: False) 37 | } 38 | ``` 39 | 40 | ### Semantic Sectioning Config 41 | 42 | ```python 43 | semantic_sectioning_config = { 44 | "llm_provider": str, # LLM provider (default: "openai", "anthropic" and "gemini" are also supported) 45 | "model": str, # LLM model to use (default: "gpt-4o-mini") 46 | "use_semantic_sectioning": bool # if False,
skip semantic sectioning (default: True) 47 | } 48 | ``` 49 | 50 | ### Chunking Config 51 | 52 | ```python 53 | chunking_config = { 54 | "chunk_size": int, # maximum characters per chunk (default: 800) 55 | "min_length_for_chunking": int # minimum text length to allow chunking (default: 2000) 56 | } 57 | ``` 58 | 59 | ## Query Config 60 | 61 | The following configuration dictionary can be passed to the `query` method: 62 | 63 | ### RSE Parameters 64 | 65 | ```python 66 | rse_params = { 67 | "max_length": int, # maximum segment length in chunks (default: 15) 68 | "overall_max_length": int, # maximum total length of all segments (default: 30) 69 | "minimum_value": float, # minimum relevance value for segments (default: 0.5) 70 | "irrelevant_chunk_penalty": float, # penalty for irrelevant chunks (0-1) (default: 0.18) 71 | "overall_max_length_extension": int, # length increase per additional query (default: 5) 72 | "decay_rate": float, # rate at which relevance decays (default: 30) 73 | "top_k_for_document_selection": int, # maximum number of documents to consider (default: 10) 74 | "chunk_length_adjustment": bool # whether to scale by chunk length (default: True) 75 | } 76 | ``` 77 | 78 | ## Metadata Query Filters 79 | 80 | Some vector databases (currently only ChromaDB) support metadata filtering during queries. This allows for more controlled document selection. 81 | 82 | Example metadata filter format: 83 | ```python 84 | metadata_filter = { 85 | "field": str, # The metadata field to filter by 86 | "operator": str, # One of: 'equals', 'not_equals', 'in', 'not_in', 87 | # 'greater_than', 'less_than', 'greater_than_equals', 'less_than_equals' 88 | "value": str | int | float | list # If list, all items must be same type 89 | } 90 | ``` 91 | 92 | Example usage: 93 | ```python 94 | # Filter with "equals" operator 95 | metadata_filter = { 96 | "field": "doc_id", 97 | "operator": "equals", 98 | "value": "test_id_1" 99 | } 100 | 101 | # Filter with "in" operator 102 | metadata_filter = { 103 | "field": "doc_id", 104 | "operator": "in", 105 | "value": ["test_id_1", "test_id_2"] 106 | } 107 | ``` -------------------------------------------------------------------------------- /docs/concepts/knowledge-bases.md: -------------------------------------------------------------------------------- 1 | # Knowledge Bases 2 | 3 | A knowledge base in dsRAG is a searchable collection of documents that can be queried to find relevant information. The `KnowledgeBase` class handles document processing, storage, and retrieval. 
4 | 5 | ## Creating a Knowledge Base 6 | 7 | To create a knowledge base: 8 | 9 | ```python 10 | from dsrag.knowledge_base import KnowledgeBase 11 | 12 | # Create a basic knowledge base 13 | kb = KnowledgeBase( 14 | kb_id="my_kb", 15 | title="Product Documentation", 16 | description="Technical documentation for XYZ product" 17 | ) 18 | 19 | # Or with custom configuration 20 | kb = KnowledgeBase( 21 | kb_id="my_kb", 22 | storage_directory="path/to/storage", # Where to store KB data 23 | embedding_model=custom_embedder, # Custom embedding model 24 | reranker=custom_reranker, # Custom reranking model 25 | vector_db=custom_vector_db, # Custom vector database 26 | chunk_db=custom_chunk_db # Custom chunk database 27 | ) 28 | ``` 29 | 30 | ## Adding Documents 31 | 32 | Documents can be added from text or files: 33 | 34 | ```python 35 | # Add from text 36 | kb.add_document( 37 | doc_id="intro-guide", 38 | text="This is the introduction guide...", 39 | document_title="Introduction Guide", 40 | metadata={"type": "guide", "version": "1.0"} 41 | ) 42 | 43 | # Add from file 44 | kb.add_document( 45 | doc_id="user-manual", 46 | file_path="path/to/manual.pdf", 47 | metadata={"type": "manual", "department": "engineering"} 48 | ) 49 | 50 | # Add with advanced configuration 51 | kb.add_document( 52 | doc_id="technical-spec", 53 | file_path="path/to/spec.pdf", 54 | file_parsing_config={ 55 | "use_vlm": True, # Use vision language model for PDFs 56 | "always_save_page_images": True # Save page images for visual content 57 | }, 58 | chunking_config={ 59 | "chunk_size": 800, # Characters per chunk 60 | "min_length_for_chunking": 2000 # Minimum length to chunk 61 | }, 62 | auto_context_config={ 63 | "use_generated_title": True, # Generate title if not provided 64 | "get_document_summary": True # Generate document summary 65 | } 66 | ) 67 | ``` 68 | 69 | ## Querying the Knowledge Base 70 | 71 | Search the knowledge base for relevant information: 72 | 73 | ```python 74 | # Simple query 75 | results = kb.query( 76 | search_queries=["How to configure the system?"] 77 | ) 78 | 79 | # Advanced query with filtering and parameters 80 | results = kb.query( 81 | search_queries=[ 82 | "System configuration steps", 83 | "Configuration prerequisites" 84 | ], 85 | metadata_filter={ 86 | "field": "doc_id", 87 | "operator": "equals", 88 | "value": "user_manual" 89 | }, 90 | rse_params="precise", # Use preset RSE parameters 91 | return_mode="text" # Return text content 92 | ) 93 | 94 | # Process results 95 | for segment in results: 96 | print(f""" 97 | Document: {segment['doc_id']} 98 | Pages: {segment['segment_page_start']} - {segment['segment_page_end']} 99 | Content: {segment['content']} 100 | Relevance: {segment['score']} 101 | """) 102 | ``` 103 | 104 | ## RSE Parameters 105 | 106 | The Relevant Segment Extraction (RSE) system can be tuned using different parameter presets: 107 | - `"balanced"`: Default preset balancing precision and comprehensiveness 108 | - `"precise"`: Favors shorter, more focused segments 109 | - `"comprehensive"`: Returns longer segments with more context 110 | 111 | Or configure custom RSE parameters: 112 | ```python 113 | results = kb.query( 114 | search_queries=["system requirements"], 115 | rse_params={ 116 | "max_length": 5, # Max segments length (in number of chunks) 117 | "overall_max_length": 20, # Total length limit across all segments (in number of chunks) 118 | "minimum_value": 0.5, # Minimum relevance score 119 | "irrelevant_chunk_penalty": 0.2 # Penalty for irrelevant chunks in a segment - 
higher penalty leads to shorter segments 120 | } 121 | ) 122 | ``` 123 | 124 | ## Metadata Query Filters 125 | 126 | Certain vector DBs support metadata filtering when running a query (currently only ChromaDB). This allows you to have more control over what document(s) get searched. A common use case would be asking questions about a single document in a knowledge base, in which case you would supply the `doc_id` as a metadata filter. 127 | 128 | The metadata filter should be a dictionary with the following structure: 129 | 130 | ```python 131 | metadata_filter = { 132 | "field": "doc_id", # The metadata field to filter on 133 | "operator": "equals", # The comparison operator 134 | "value": "doc123" # The value to compare against 135 | } 136 | ``` 137 | 138 | Supported operators: 139 | - `equals` 140 | - `not_equals` 141 | - `in` 142 | - `not_in` 143 | - `greater_than` 144 | - `less_than` 145 | - `greater_than_equals` 146 | - `less_than_equals` 147 | 148 | For operators that take multiple values (`in` and `not_in`), the value should be a list where all items are of the same type (string, integer, or float). 149 | 150 | Example usage: 151 | ```python 152 | # Query a specific document 153 | results = kb.query( 154 | search_queries=["system requirements"], 155 | metadata_filter={ 156 | "field": "doc_id", 157 | "operator": "equals", 158 | "value": "technical_spec_v1" 159 | } 160 | ) 161 | 162 | # Query documents from multiple departments 163 | results = kb.query( 164 | search_queries=["security protocols"], 165 | metadata_filter={ 166 | "field": "department", 167 | "operator": "in", 168 | "value": ["security", "compliance"] 169 | } 170 | ) 171 | ``` -------------------------------------------------------------------------------- /docs/concepts/overview.md: -------------------------------------------------------------------------------- 1 | # Architecture Overview 2 | 3 | dsRAG is built around three key methods that improve performance over vanilla RAG systems: 4 | 5 | 1. Semantic sectioning 6 | 2. AutoContext 7 | 3. Relevant Segment Extraction (RSE) 8 | 9 | ## Key Methods 10 | 11 | ### Semantic Sectioning 12 | 13 | Semantic sectioning uses an LLM to break a document into cohesive sections. The process works as follows: 14 | 15 | 1. The document is annotated with line numbers 16 | 2. An LLM identifies the starting and ending lines for each "semantically cohesive section" 17 | 3. Sections typically range from a few paragraphs to a few pages long 18 | 4. Sections are broken into smaller chunks if needed 19 | 5. The LLM generates descriptive titles for each section 20 | 6. Section titles are used in contextual chunk headers created by AutoContext 21 | 22 | This process provides additional context to the ranking models (embeddings and reranker), enabling better retrieval. 23 | 24 | ### AutoContext (Contextual Chunk Headers) 25 | 26 | AutoContext creates contextual chunk headers that contain: 27 | 28 | - Document-level context 29 | - Section-level context 30 | 31 | These headers are prepended to chunks before embedding. Benefits include: 32 | 33 | - More accurate and complete representation of text content and meaning 34 | - Dramatic improvement in retrieval quality 35 | - Reduced rate of irrelevant results 36 | - Reduced LLM misinterpretation in downstream applications 37 | 38 | ### Relevant Segment Extraction (RSE) 39 | 40 | RSE is a query-time post-processing step that: 41 | 42 | 1. Takes clusters of relevant chunks 43 | 2. Intelligently combines them into longer sections (segments) 44 | 3. 
Provides better context to the LLM than individual chunks 45 | 46 | RSE is particularly effective for: 47 | 48 | - Complex questions where answers span multiple chunks 49 | - Adapting the context length based on query type 50 | - Maintaining coherent context while avoiding irrelevant information 51 | 52 | ## Document Processing Flow 53 | 54 | 1. Documents → File parsing (VLM or non-VLM) 55 | 2. Semantic sectioning 56 | 3. Chunking 57 | 4. AutoContext 58 | 5. Embedding 59 | 6. Chunk and vector database upsert 60 | 61 | ## Query Processing Flow 62 | 63 | 1. Queries 64 | 2. Vector database search 65 | 3. Reranking 66 | 4. RSE 67 | 5. Results 68 | 69 | For more detailed information about specific components and configuration options, please refer to: 70 | - [Components Documentation](components.md) 71 | - [Configuration Options](config.md) 72 | - [Knowledge Base Details](knowledge-bases.md) -------------------------------------------------------------------------------- /docs/concepts/vlm.md: -------------------------------------------------------------------------------- 1 | # VLM File Parsing 2 | 3 | dsRAG supports Vision Language Model (VLM) integration for enhanced PDF parsing capabilities. This feature is particularly useful for documents with complex layouts, tables, diagrams, or other visual elements that traditional text extraction might miss. 4 | 5 | ## Overview 6 | 7 | When VLM parsing is enabled: 8 | 9 | 1. The PDF is converted to images (one per page) 10 | 2. Each page image is analyzed by the VLM to identify and extract text and visual elements 11 | 3. The extracted elements are converted to a structured format with page numbers preserved 12 | 13 | ## Supported Providers 14 | 15 | Currently supported VLM providers: 16 | 17 | - Google's Gemini (default) 18 | - Google Cloud Vertex AI 19 | 20 | ## Configuration 21 | 22 | To use VLM for file parsing, configure the `file_parsing_config` when adding documents: 23 | 24 | ```python 25 | file_parsing_config = { 26 | "use_vlm": True, # Enable VLM parsing 27 | "vlm_config": { 28 | # VLM Provider (required) 29 | "provider": "gemini", # or "vertex_ai" 30 | 31 | # Model name (optional - defaults based on provider) 32 | "model": "gemini-2.0-flash", # default for Gemini 33 | 34 | # For Vertex AI, additional required fields: 35 | "project_id": "your-gcp-project-id", 36 | "location": "us-central1", 37 | 38 | # Elements to exclude from parsing (optional) 39 | "exclude_elements": ["Header", "Footer"], 40 | 41 | # Whether images are pre-extracted (optional) 42 | "images_already_exist": False 43 | }, 44 | 45 | # Save page images even if VLM unused (optional) 46 | "always_save_page_images": False 47 | } 48 | 49 | # Add document with VLM parsing (must be a PDF file) 50 | kb.add_document( 51 | doc_id="my_document", 52 | file_path="path/to/document.pdf", # Only PDF files are supported with VLM 53 | file_parsing_config=file_parsing_config 54 | ) 55 | ``` 56 | 57 | ## Element Types 58 | 59 | The VLM is prompted to categorize page content into the following element types: 60 | 61 | Text Elements: 62 | 63 | - NarrativeText: Main text content including paragraphs, lists, and titles 64 | - Header: Page header content (typically at top of page) 65 | - Footnote: References or notes at bottom of content 66 | - Footer: Page footer content (at very bottom of page) 67 | 68 | Visual Elements: 69 | 70 | - Figure: Charts, graphs, and diagrams with associated titles and legends 71 | - Image: Photos, illustrations, and other visual content 72 | - Table: Tabular data arrangements with
titles and captions 73 | - Equation: Mathematical formulas and expressions 74 | 75 | By default, Header and Footer elements are excluded from parsing as they rarely contain valuable information and can break the flow between pages. You can modify which elements to exclude using the `exclude_elements` configuration option. 76 | 77 | ## Best Practices 78 | 79 | 1. Use VLM parsing for documents with: 80 | - Complex layouts 81 | - Important visual elements 82 | - Tables that need precise formatting 83 | - Mathematical formulas 84 | - Scanned documents that need OCR 85 | 86 | 2. Consider traditional parsing for: 87 | - Simple text documents 88 | - Documents where visual layout isn't critical 89 | - Large volumes of documents (VLM parsing is slower and more expensive) 90 | 91 | 3. Configure `exclude_elements` to ignore irrelevant elements like headers/footers -------------------------------------------------------------------------------- /docs/getting-started/installation.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | ## Installing dsRAG 4 | 5 | Install the package using pip: 6 | 7 | ```bash 8 | pip install dsrag 9 | ``` 10 | 11 | If you want to use VLM file parsing, you will need one non-Python dependency: poppler. This is used for converting PDFs to images. On MacOS, you can install it using Homebrew: 12 | 13 | ```bash 14 | brew install poppler 15 | ``` 16 | 17 | ## Third-party dependencies 18 | 19 | To get the most out of dsRAG, you'll need a few third-party API keys set as environment variables. Here are the most important ones: 20 | 21 | - `OPENAI_API_KEY`: Recommended for embeddings, AutoContext, and semantic sectioning 22 | - `CO_API_KEY`: Recommended for reranking 23 | - `GEMINI_API_KEY`: Recommended for VLM file parsing -------------------------------------------------------------------------------- /docs/getting-started/quickstart.md: -------------------------------------------------------------------------------- 1 | # Quick Start Guide 2 | 3 | This guide will help you get started with dsRAG quickly. We'll cover the basics of creating a knowledge base and querying it. 4 | 5 | ## Setting Up 6 | 7 | First, make sure you have the necessary API keys set as environment variables: 8 | - `OPENAI_API_KEY` for embeddings and AutoContext 9 | - `CO_API_KEY` for reranking with Cohere 10 | 11 | If you don't have both of these available, you can use the OpenAI-only configuration or local-only configuration below. 12 | 13 | ## Creating a Knowledge Base 14 | 15 | Create a new knowledge base and add documents to it: 16 | 17 | ```python 18 | from dsrag.knowledge_base import KnowledgeBase 19 | 20 | # Create a knowledge base 21 | kb = KnowledgeBase(kb_id="my_knowledge_base") 22 | 23 | # Add documents 24 | kb.add_document( 25 | doc_id="user_manual", # Use a meaningful ID if possible 26 | file_path="path/to/your/document.pdf", 27 | document_title="User Manual", # Optional but recommended 28 | metadata={"type": "manual"} # Optional metadata 29 | ) 30 | ``` 31 | 32 | The KnowledgeBase object persists to disk automatically, so you don't need to explicitly save it. 
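If you need to remove a document later, the `KnowledgeBase` public API also includes a `delete_document` method. The sketch below is an assumption about the call, reusing the same document ID that was passed to `add_document`:

```python
# Remove a previously added document by its ID
# (assumes delete_document takes the document ID as doc_id)
kb.delete_document(doc_id="user_manual")
```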
33 | 34 | ## Querying the Knowledge Base 35 | 36 | Once you have created a knowledge base, you can load it by its `kb_id` and query it: 37 | 38 | ```python 39 | from dsrag.knowledge_base import KnowledgeBase 40 | 41 | # Load the knowledge base 42 | kb = KnowledgeBase("my_knowledge_base") 43 | 44 | # You can query with multiple search queries 45 | search_queries = [ 46 | "What are the main topics covered?", 47 | "What are the key findings?" 48 | ] 49 | 50 | # Get results 51 | results = kb.query(search_queries) 52 | for segment in results: 53 | print(segment) 54 | ``` 55 | 56 | ## Using OpenAI-Only Configuration 57 | 58 | If you prefer to use only OpenAI services (without Cohere), you can customize the configuration: 59 | 60 | ```python 61 | from dsrag.llm import OpenAIChatAPI 62 | from dsrag.reranker import NoReranker 63 | 64 | # Configure components 65 | llm = OpenAIChatAPI(model='gpt-4o-mini') 66 | reranker = NoReranker() 67 | 68 | # Create knowledge base with custom configuration 69 | kb = KnowledgeBase( 70 | kb_id="my_knowledge_base", 71 | reranker=reranker, 72 | auto_context_model=llm 73 | ) 74 | 75 | # Add documents 76 | kb.add_document( 77 | doc_id="user_manual", # Use a meaningful ID 78 | file_path="path/to/your/document.pdf", 79 | document_title="User Manual", # Optional but recommended 80 | metadata={"type": "manual"} # Optional metadata 81 | ) 82 | ``` 83 | 84 | ## Local-only configuration 85 | 86 | If you don't want to use any third-party services, you can configure dsRAG to run fully locally using Ollama. There will be a couple limitations: 87 | - You will not be able to use VLM file parsing 88 | - You will not be able to use semantic sectioning 89 | - You will not be able to use a reranker, unless you implement your own custom Reranker class 90 | 91 | ```python 92 | from dsrag.llm import OllamaChatAPI 93 | from dsrag.reranker import NoReranker 94 | from dsrag.embedding import OllamaEmbedding 95 | 96 | llm = OllamaChatAPI(model="llama3.1:8b") 97 | reranker = NoReranker() 98 | embedding = OllamaEmbedding(model="nomic-embed-text") 99 | 100 | kb = KnowledgeBase( 101 | kb_id="my_knowledge_base", 102 | reranker=reranker, 103 | auto_context_model=llm, 104 | embedding_model=embedding 105 | ) 106 | 107 | # Disable semantic sectioning 108 | semantic_sectioning_config = { 109 | "use_semantic_sectioning": False, 110 | } 111 | 112 | # Add documents 113 | kb.add_document( 114 | doc_id="user_manual", 115 | file_path="path/to/your/document.pdf", 116 | document_title="User Manual", 117 | semantic_sectioning_config=semantic_sectioning_config 118 | ) 119 | ``` -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # Welcome to dsRAG 2 | 3 | dsRAG is a retrieval engine for unstructured data. It is especially good at handling challenging queries over dense text, like financial reports, legal documents, and academic papers. dsRAG achieves substantially higher accuracy than vanilla RAG baselines on complex open-book question answering tasks. On one especially challenging benchmark, [FinanceBench](https://arxiv.org/abs/2311.11944), dsRAG gets accurate answers 96.6% of the time, compared to the vanilla RAG baseline which only gets 32% of questions correct. 4 | 5 | ## Key Methods 6 | 7 | dsRAG uses three key methods to improve performance over vanilla RAG systems: 8 | 9 | ### 1. Semantic Sectioning 10 | Semantic sectioning uses an LLM to break a document into sections. 
It works by annotating the document with line numbers and then prompting an LLM to identify the starting and ending lines for each "semantically cohesive section." These sections should be anywhere from a few paragraphs to a few pages long. The sections then get broken into smaller chunks if needed. The LLM is also prompted to generate descriptive titles for each section. These section titles get used in the contextual chunk headers created by AutoContext, which provides additional context to the ranking models (embeddings and reranker), enabling better retrieval. 11 | 12 | ### 2. AutoContext 13 | AutoContext creates contextual chunk headers that contain document-level and section-level context, and prepends those chunk headers to the chunks prior to embedding them. This gives the embeddings a much more accurate and complete representation of the content and meaning of the text. In our testing, this feature leads to a dramatic improvement in retrieval quality. In addition to increasing the rate at which the correct information is retrieved, AutoContext also substantially reduces the rate at which irrelevant results show up in the search results. This reduces the rate at which the LLM misinterprets a piece of text in downstream chat and generation applications. 14 | 15 | ### 3. Relevant Segment Extraction (RSE) 16 | Relevant Segment Extraction (RSE) is a query-time post-processing step that takes clusters of relevant chunks and intelligently combines them into longer sections of text that we call segments. These segments provide better context to the LLM than any individual chunk can. For simple factual questions, the answer is usually contained in a single chunk; but for more complex questions, the answer usually spans a longer section of text. The goal of RSE is to intelligently identify the section(s) of text that provide the most relevant information, without being constrained to fixed length chunks. 17 | 18 | ## Quick Example 19 | 20 | ```python 21 | from dsrag.create_kb import create_kb_from_file 22 | 23 | # Create a knowledge base from a file 24 | file_path = "path/to/your/document.pdf" 25 | kb_id = "my_knowledge_base" 26 | kb = create_kb_from_file(kb_id, file_path) 27 | 28 | # Query the knowledge base 29 | search_queries = ["What are the main topics covered?"] 30 | results = kb.query(search_queries) 31 | for segment in results: 32 | print(segment) 33 | ``` 34 | 35 | ## Evaluation Results 36 | 37 | ### FinanceBench 38 | On [FinanceBench](https://arxiv.org/abs/2311.11944), which uses a corpus of hundreds of 10-Ks and 10-Qs with challenging queries that often require combining multiple pieces of information: 39 | 40 | - Baseline retrieval pipeline: 32% accuracy 41 | - dsRAG (with default parameters and Claude 3.5 Sonnet): 96.6% accuracy 42 | 43 | ### KITE Benchmark 44 | On the [KITE](https://github.com/D-Star-AI/KITE) benchmark, which includes diverse datasets (AI papers, company 10-Ks, company handbooks, and Supreme Court opinions), dsRAG shows significant improvements: 45 | 46 | | | Top-k | RSE | CCH+Top-k | CCH+RSE | 47 | |-------------------------|----------|--------|--------------|------------| 48 | | AI Papers | 4.5 | 7.9 | 4.7 | 7.9 | 49 | | BVP Cloud | 2.6 | 4.4 | 6.3 | 7.8 | 50 | | Sourcegraph | 5.7 | 6.6 | 5.8 | 9.4 | 51 | | Supreme Court Opinions | 6.1 | 8.0 | 7.4 | 8.5 | 52 | | **Average** | 4.72 | 6.73 | 6.04 | 8.42 | 53 | 54 | ## Getting Started 55 | 56 | Check out our [Quick Start Guide](getting-started/quickstart.md) to begin using dsRAG in your projects. 
57 | 58 | ## Community and Support 59 | 60 | - Join our [Discord](https://discord.gg/NTUVX9DmQ3) for community support 61 | - Fill out our [use case form](https://forms.gle/RQ5qFVReonSHDcCu5) if using dsRAG in production 62 | - Need professional help? Contact our [team](https://forms.gle/zbQwDJp7pBQKtqVT8) -------------------------------------------------------------------------------- /docs/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-Star-AI/dsRAG/b326c1d0930c9a80ff486ab0df5e771e788e3234/docs/logo.png -------------------------------------------------------------------------------- /dsrag/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | # Configure the root dsrag logger with a NullHandler to prevent "No handler found" warnings 4 | # This follows Python best practices for library logging 5 | # Users will need to configure their own handlers if they want to see dsrag logs 6 | logger = logging.getLogger("dsrag") 7 | logger.addHandler(logging.NullHandler()) -------------------------------------------------------------------------------- /dsrag/auto_query.py: -------------------------------------------------------------------------------- 1 | # NOTE: This is a legacy file and is not used in the current implementation. Will be removed in the future. 2 | 3 | import os 4 | from dsrag.utils.imports import instructor 5 | from anthropic import Anthropic 6 | from pydantic import BaseModel 7 | from typing import List 8 | 9 | SYSTEM_MESSAGE = """ 10 | You are a query generation system. Please generate one or more search queries (up to a maximum of {max_queries}) based on the provided user input. DO NOT generate the answer, just queries. 11 | 12 | Each of the queries you generate will be used to search a knowledge base for information that can be used to respond to the user input. Make sure each query is specific enough to return relevant information. If multiple pieces of information would be useful, you should generate multiple queries, one for each specific piece of information needed. 
13 | 14 | {auto_query_guidance} 15 | """.strip() 16 | 17 | 18 | def get_search_queries(user_input: str, auto_query_guidance: str = "", max_queries: int = 5): 19 | base_url = os.environ.get("DSRAG_ANTHROPIC_BASE_URL", None) 20 | if base_url is not None: 21 | client = instructor.from_anthropic(Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"], base_url=base_url)) 22 | else: 23 | client = instructor.from_anthropic(Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"])) 24 | 25 | class Queries(BaseModel): 26 | queries: List[str] 27 | 28 | resp = client.messages.create( 29 | model="claude-3-5-sonnet-20241022", 30 | max_tokens=400, 31 | temperature=0.2, 32 | system=SYSTEM_MESSAGE.format(max_queries=max_queries, auto_query_guidance=auto_query_guidance), 33 | messages=[ 34 | { 35 | "role": "user", 36 | "content": user_input 37 | } 38 | ], 39 | response_model=Queries, 40 | ) 41 | 42 | return resp.queries[:max_queries] -------------------------------------------------------------------------------- /dsrag/chat/__init__.py: -------------------------------------------------------------------------------- 1 | from dsrag.chat.chat import create_new_chat_thread, get_chat_thread_response 2 | from dsrag.chat.chat_types import ChatThreadParams, ChatResponseInput 3 | from dsrag.chat.citations import ResponseWithCitations, Citation 4 | 5 | __all__ = [ 6 | 'create_new_chat_thread', 7 | 'get_chat_thread_response', 8 | 'ChatThreadParams', 9 | 'ChatResponseInput', 10 | 'ResponseWithCitations', 11 | 'Citation' 12 | ] -------------------------------------------------------------------------------- /dsrag/chat/auto_query.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pydantic import BaseModel 3 | from typing import List 4 | from dsrag.chat.instructor_get_response import get_response 5 | 6 | SYSTEM_MESSAGE = """ 7 | You are a query generation system. Please generate one or more search queries (up to a maximum of {max_queries}) based on the provided user input. DO NOT generate the answer, just queries. You must specify the knowledge base you want to use for each query by providing the knowledge_base_id. 8 | 9 | Each of the queries you generate will be used to search a knowledge base for information that can be used to respond to the user input. Make sure each query is specific enough to return relevant information. If multiple pieces of information would be useful, you should generate multiple queries, one for each specific piece of information needed. 10 | 11 | {auto_query_guidance} 12 | 13 | Here are the knowledge bases you can search: 14 | {knowledge_base_descriptions} 15 | """.strip() 16 | 17 | class Query(BaseModel): 18 | query: str 19 | knowledge_base_id: str 20 | 21 | class Queries(BaseModel): 22 | queries: List[Query] 23 | 24 | 25 | def get_knowledge_base_descriptions_str(kb_info: list[dict]): 26 | kb_descriptions = [] 27 | for kb in kb_info: 28 | kb_descriptions.append(f"kb_id: {kb['id']}\ndescription: {kb['title']} - {kb['description']}") 29 | return "\n\n".join(kb_descriptions) 30 | 31 | def validate_queries(queries: List[Query], kb_info: list[dict]) -> List[dict]: 32 | """ 33 | Validates and potentially modifies search queries based on knowledge base availability. 
34 | 35 | Args: 36 | queries: List of Query objects to validate 37 | kb_info: List of available knowledge base information 38 | 39 | Returns: 40 | List of validated query dictionaries 41 | """ 42 | valid_kb_ids = {kb["id"] for kb in kb_info} 43 | validated_queries = [] 44 | 45 | for query in queries: 46 | if query.knowledge_base_id in valid_kb_ids: 47 | validated_queries.append({"query": query.query, "kb_id": query.knowledge_base_id}) 48 | else: 49 | # If invalid KB ID is found 50 | if len(kb_info) == 1: 51 | # If only one KB exists, use that 52 | validated_queries.append({"query": query.query, "kb_id": kb_info[0]["id"]}) 53 | else: 54 | # If multiple KBs exist, create a query for each KB 55 | for kb in kb_info: 56 | validated_queries.append({"query": query.query, "kb_id": kb["id"]}) 57 | 58 | return validated_queries 59 | 60 | def get_search_queries(chat_messages: list[dict], kb_info: list[dict], auto_query_guidance: str = "", max_queries: int = 5, auto_query_model: str = "gpt-4o-mini") -> List[dict]: 61 | """ 62 | Input: 63 | - chat_messages: list of dictionaries, where each dictionary has the keys "role" and "content". This should include the current user input as the final message. 64 | - kb_info: list of dictionaries, where each dictionary has the keys "id", "title", and "description". This should include information about the available knowledge bases. 65 | - auto_query_guidance: str, optional additional instructions for the auto_query system 66 | - max_queries: int, maximum number of queries to generate 67 | - auto_query_model: str, the model to use for generating queries 68 | 69 | Returns 70 | - queries: list of dictionaries, where each dictionary has the keys "query" and "knowledge_base_id" 71 | """ 72 | knowledge_base_descriptions = get_knowledge_base_descriptions_str(kb_info) 73 | 74 | system_message = SYSTEM_MESSAGE.format( 75 | max_queries=max_queries, 76 | auto_query_guidance=auto_query_guidance, 77 | knowledge_base_descriptions=knowledge_base_descriptions 78 | ) 79 | 80 | messages = [{"role": "system", "content": system_message}] + chat_messages 81 | 82 | # Try with initial model 83 | try: 84 | queries = get_response( 85 | messages=messages, 86 | model_name=auto_query_model, 87 | response_model=Queries, 88 | max_tokens=600, 89 | temperature=0.0 90 | ) 91 | except: 92 | # Fallback to a stronger model 93 | fallback_model = "claude-3-5-sonnet-20241022" if "gpt" in auto_query_model else "gpt-4o" 94 | queries = get_response( 95 | messages=messages, 96 | model_name=fallback_model, 97 | response_model=Queries, 98 | max_tokens=600, 99 | temperature=0.0 100 | ) 101 | 102 | # Validate and potentially modify the queries 103 | validated_queries = validate_queries(queries.queries[:max_queries], kb_info)[:max_queries] 104 | return validated_queries -------------------------------------------------------------------------------- /dsrag/chat/chat_types.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union, Literal 2 | from typing_extensions import TypedDict 3 | from pydantic import BaseModel 4 | 5 | class ChatThreadParams(TypedDict): 6 | kb_ids: Optional[list[str]] 7 | model: Optional[str] 8 | temperature: Optional[float] 9 | system_message: Optional[str] 10 | auto_query_model: Optional[str] 11 | auto_query_guidance: Optional[str] 12 | rse_params: Optional[dict] 13 | target_output_length: Optional[str] 14 | max_chat_history_tokens: Optional[int] 15 | 16 | class ChatResponseOutput(TypedDict): 17 | response: str 18 | 
metadata: dict 19 | 20 | class MetadataFilter(TypedDict): 21 | field: str 22 | operator: Literal['equals', 'not_equals', 'in', 'not_in', 'greater_than', 'less_than', 'greater_than_equals', 'less_than_equals'] 23 | value: Union[str, int, float, list[str], list[int], list[float]] 24 | 25 | class ChatResponseInput(BaseModel): 26 | user_input: str 27 | chat_thread_params: Optional[ChatThreadParams] = None 28 | metadata_filter: Optional[MetadataFilter] = None -------------------------------------------------------------------------------- /dsrag/chat/citations.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, Field 2 | from typing import Optional, List 3 | import instructor 4 | 5 | class Citation(BaseModel): 6 | source_index: int = Field(..., description="The index of the source where the information used to generate the response was found.") 7 | page_number: Optional[int] = Field(None, description="The page where the information was found. May only be None if page numbers are not provided.") 8 | cited_text: str = Field(..., description="The exact text containing the information used to generate the response.") 9 | 10 | class ResponseWithCitations(BaseModel): 11 | response: str = Field(..., description="The response to the user's question") 12 | citations: List[Citation] = Field(..., description="The citations used to generate the response") 13 | 14 | # Create a partial version of ResponseWithCitations for streaming 15 | PartialResponseWithCitations = instructor.Partial[ResponseWithCitations] 16 | 17 | def format_page_content(page_number: int, content: str) -> str: 18 | """Format a single page's content with page number annotations""" 19 | return f"\n{content}\n" 20 | 21 | def get_source_text(kb_id: str, doc_id: str, source_index: int, page_start: Optional[int], page_end: Optional[int], file_system) -> Optional[str]: 22 | """ 23 | Get the source text for a given document and page range with page number annotations. 24 | Returns None if no page content is available. 25 | """ 26 | if page_start is None or page_end is None: 27 | return None 28 | 29 | page_contents = file_system.load_page_content_range(kb_id, doc_id, page_start, page_end) 30 | if not page_contents: 31 | print(f"No page contents found for doc_id: {doc_id}, page_start: {page_start}, page_end: {page_end}") 32 | return None 33 | 34 | source_text = f"\n" 35 | for i, content in enumerate(page_contents): 36 | page_number = page_start + i 37 | source_text += format_page_content(page_number, content) + "\n" 38 | source_text += f"" 39 | 40 | return source_text 41 | 42 | def format_sources_for_context(search_results: list[dict], kb_id: str, file_system) -> tuple[str, list[str]]: 43 | """ 44 | Format all search results into a context string for the LLM. 45 | Handles both cases with and without page numbers.
46 | """ 47 | context_parts = [] 48 | all_doc_ids = [] 49 | 50 | for result in search_results: 51 | doc_id = result["doc_id"] 52 | source_index = result["source_index"] 53 | page_start = result.get("segment_page_start") 54 | page_end = result.get("segment_page_end") 55 | all_doc_ids.append(doc_id) 56 | 57 | full_page_source_text = None 58 | if page_start is not None and page_end is not None: 59 | # if we have page numbers, then we assume we have the page content - but this is not always the case 60 | full_page_source_text = get_source_text(kb_id, doc_id, source_index, page_start, page_end, file_system) 61 | 62 | if full_page_source_text: 63 | context_parts.append(full_page_source_text) 64 | else: 65 | result_content = result.get("content", "") 66 | context_parts.append(f"\n{result_content}\n") 67 | 68 | return "\n\n".join(context_parts), all_doc_ids 69 | 70 | def convert_elements_to_page_content(elements: list[dict], kb_id: str, doc_id: str, file_system) -> None: 71 | """ 72 | Convert elements to page content and save it using the page content methods. 73 | This should be called when a document is first added to the knowledge base. 74 | Only processes documents where elements have page numbers. 75 | """ 76 | # Check if this document has page numbers 77 | if not elements or "page_number" not in elements[0]: 78 | print(f"No page numbers found for document {doc_id}") 79 | return 80 | 81 | # Group elements by page 82 | pages = {} 83 | for element in elements: 84 | page_num = element["page_number"] 85 | if page_num not in pages: 86 | pages[page_num] = [] 87 | pages[page_num].append(element["content"]) 88 | 89 | # Save each page's content 90 | for page_num, contents in pages.items(): 91 | page_content = "\n".join(contents) 92 | file_system.save_page_content(kb_id, doc_id, page_num, page_content) -------------------------------------------------------------------------------- /dsrag/chat/model_names.py: -------------------------------------------------------------------------------- 1 | # NOTE: this is deprecated but remains for backwards compatibility. 2 | # The new way to specify model names for Chat is {provider}/{model_name} so we don't have to do this lookup (and maintain an up-to-date mapping here). 
3 | 4 | ANTHROPIC_MODEL_NAMES = [ 5 | "claude-3-haiku-20240307", 6 | "claude-3-5-haiku-20241022", 7 | "claude-3-5-haiku-latest", 8 | "claude-3-5-sonnet-20240620", 9 | "claude-3-5-sonnet-20241022", 10 | "claude-3-5-sonnet-latest", 11 | "claude-3-7-sonnet-latest", 12 | "claude-3-7-sonnet-20250219", 13 | ] 14 | 15 | GEMINI_MODEL_NAMES = [ 16 | "gemini-1.5-flash-002", 17 | "gemini-1.5-pro-002", 18 | "gemini-2.0-flash", 19 | "gemini-2.0-flash-lite", 20 | "gemini-2.0-flash-thinking-exp", 21 | "gemini-2.0-pro-exp", 22 | "gemini-2.5-pro-exp-03-25", 23 | "gemini-2.5-pro-preview-03-25", 24 | ] 25 | 26 | OPENAI_MODEL_NAMES = [ 27 | "gpt-4o-mini", 28 | "gpt-4o-mini-2024-07-18", 29 | "gpt-4o", 30 | "gpt-4o-2024-05-13", 31 | "gpt-4o-2024-08-06", 32 | "gpt-4o-2024-11-20", 33 | "o1", 34 | "o1-2024-12-17", 35 | "o1-preview", 36 | "o1-preview-2024-09-12", 37 | "o3-mini", 38 | "o3-mini-2025-01-31", 39 | "gpt-4.5-preview", 40 | "gpt-4.5-preview-2025-02-27", 41 | "chatgpt-4o-latest", 42 | ] -------------------------------------------------------------------------------- /dsrag/create_kb.py: -------------------------------------------------------------------------------- 1 | from dsrag.dsparse.file_parsing.non_vlm_file_parsing import extract_text_from_pdf, extract_text_from_docx 2 | from dsrag.knowledge_base import KnowledgeBase 3 | import os 4 | import time 5 | 6 | def create_kb_from_directory(kb_id: str, directory: str, title: str = None, description: str = "", language: str = 'en'): 7 | """ 8 | - kb_id is the name of the knowledge base 9 | - directory is the absolute path to the directory containing the documents 10 | - no support for manually defined chunk headers here, because they would have to be defined for each file in the directory 11 | 12 | Supported file types: .docx, .md, .txt, .pdf 13 | """ 14 | if not title: 15 | title = kb_id 16 | 17 | # create a new KB 18 | kb = KnowledgeBase(kb_id, title=title, description=description, language=language, exists_ok=False) 19 | 20 | # add documents 21 | for root, dirs, files in os.walk(directory): 22 | for file_name in files: 23 | if file_name.endswith(('.docx', '.md', '.txt', '.pdf')): 24 | try: 25 | file_path = os.path.join(root, file_name) 26 | clean_file_path = file_path.replace(directory, "") 27 | 28 | if file_name.endswith('.docx'): 29 | text = extract_text_from_docx(file_path) 30 | elif file_name.endswith('.pdf'): 31 | text, _ = extract_text_from_pdf(file_path) 32 | elif file_name.endswith('.md') or file_name.endswith('.txt'): 33 | with open(file_path, 'r') as f: 34 | text = f.read() 35 | 36 | kb.add_document(doc_id=clean_file_path, text=text) 37 | time.sleep(1) # pause for 1 second to avoid hitting API rate limits 38 | except: 39 | print (f"Error reading {file_name}") 40 | continue 41 | else: 42 | print (f"Unsupported file type: {file_name}") 43 | continue 44 | 45 | return kb 46 | 47 | def create_kb_from_file(kb_id: str, file_path: str, title: str = None, description: str = "", language: str = 'en'): 48 | """ 49 | - kb_id is the name of the knowledge base 50 | - file_path is the absolute path to the file containing the documents 51 | 52 | Supported file types: .docx, .md, .txt, .pdf 53 | """ 54 | if not title: 55 | title = kb_id 56 | 57 | # create a new KB 58 | kb = KnowledgeBase(kb_id, title=title, description=description, language=language, exists_ok=False) 59 | 60 | print (f'Creating KB with id {kb_id}...') 61 | 62 | file_name = os.path.basename(file_path) 63 | 64 | # add document 65 | if file_path.endswith(('.docx', '.md', '.txt', '.pdf')): 66 | #
define clean file path as just the file name here since we're not using a directory 67 | clean_file_path = file_name 68 | 69 | if file_path.endswith('.docx'): 70 | text = extract_text_from_docx(file_path) 71 | elif file_name.endswith('.pdf'): 72 | text, _ = extract_text_from_pdf(file_path) 73 | elif file_path.endswith('.md') or file_path.endswith('.txt'): 74 | with open(file_path, 'r') as f: 75 | text = f.read() 76 | 77 | kb.add_document(doc_id=clean_file_path, text=text) 78 | else: 79 | print (f"Unsupported file type: {file_name}") 80 | return 81 | 82 | return kb -------------------------------------------------------------------------------- /dsrag/custom_term_mapping.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import List, Dict 3 | from pydantic import BaseModel, Field 4 | from dsrag.chat.instructor_get_response import get_response 5 | 6 | FIND_TARGET_TERMS_PROMPT = """ 7 | Your task is to take a list of target terms and return a list of all instances of those terms in the provided text. The instances you return must be direct matches to the text, but can be fuzzy matches to the terms. In other words, you should allow for some variation in spelling and capitalization relative to the terms. For example, if the synonym is "capital of France", and the test has the phrase "Capitol of Frances", you should return the phrase "Capitol of Frances". 8 | 9 | 10 | {target_terms} 11 | 12 | 13 | 14 | {text} 15 | 16 | """.strip() 17 | 18 | class Terms(BaseModel): 19 | terms: List[str] = Field(description="A list of target terms (or variations thereof) that were found in the text") 20 | 21 | def find_target_terms_batch(text: str, target_terms: List[str]) -> List[str]: 22 | """Find fuzzy matches of target terms in a batch of text using LLM.""" 23 | target_terms_str = "\n".join(target_terms) 24 | terms = get_response( 25 | prompt=FIND_TARGET_TERMS_PROMPT.format(target_terms=target_terms_str, text=text), 26 | response_model=Terms, 27 | ).terms 28 | return terms 29 | 30 | def find_all_term_variations(chunks: List[str], target_terms: List[str]) -> List[str]: 31 | """Find all variations of target terms across all chunks, processing in batches.""" 32 | all_variations = set() 33 | batch_size = 15 # Process 15 chunks at a time 34 | 35 | for i in range(0, len(chunks), batch_size): 36 | batch_chunks = chunks[i:i + batch_size] 37 | batch_text = "\n\n".join(batch_chunks) 38 | variations = find_target_terms_batch(batch_text, target_terms) 39 | all_variations.update(variations) 40 | 41 | return list(all_variations) 42 | 43 | def annotate_chunk(chunk: str, key: str, terms: List[str]) -> str: 44 | """Annotate a single chunk with terms and their mapping.""" 45 | # Process instances in reverse order to avoid messing up string indices 46 | for term in terms: 47 | instances = re.finditer(re.escape(term), chunk) 48 | instances = list(instances) 49 | for instance in reversed(instances): 50 | start = instance.start() 51 | end = instance.end() 52 | chunk = chunk[:end] + f" ({key})" + chunk[end:] 53 | return chunk.strip() 54 | 55 | def annotate_chunks(chunks: List[str], custom_term_mapping: Dict[str, List[str]]) -> List[str]: 56 | """ 57 | Annotate chunks with custom term mappings. 
58 | 59 | Args: 60 | chunks: list[str] - list of all chunks in the document 61 | custom_term_mapping: dict - a dictionary of custom term mapping for the document 62 | - key: str - the term to map to 63 | - value: list[str] - the list of terms to map to the key 64 | 65 | Returns: 66 | list[str] - annotated chunks 67 | """ 68 | # First, find all variations for each key's terms 69 | term_variations = {} 70 | for key, target_terms in custom_term_mapping.items(): 71 | variations = find_all_term_variations(chunks, target_terms) 72 | term_variations[key] = list(set(variations + target_terms)) # Include original terms 73 | # remove an exact match of the key from the list of terms 74 | term_variations[key] = [term for term in term_variations[key] if term != key] 75 | 76 | # Then annotate each chunk with all terms 77 | annotated_chunks = [] 78 | for chunk in chunks: 79 | annotated_chunk = chunk 80 | for key, terms in term_variations.items(): 81 | annotated_chunk = annotate_chunk(annotated_chunk, key, terms) 82 | annotated_chunks.append(annotated_chunk) 83 | 84 | return annotated_chunks -------------------------------------------------------------------------------- /dsrag/database/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-Star-AI/dsRAG/b326c1d0930c9a80ff486ab0df5e771e788e3234/dsrag/database/__init__.py -------------------------------------------------------------------------------- /dsrag/database/chat_thread/basic_db.py: -------------------------------------------------------------------------------- 1 | import json 2 | from .db import ChatThreadDB 3 | import uuid 4 | 5 | class BasicChatThreadDB(ChatThreadDB): 6 | def __init__(self): 7 | try: 8 | with open("chat_thread_db.json", "r") as f: 9 | self.chat_threads = json.load(f) 10 | except FileNotFoundError: 11 | self.chat_threads = {} 12 | 13 | def create_chat_thread(self, chat_thread_params: dict) -> dict: 14 | """ 15 | chat_thread_params: dict with the following keys: 16 | - thread_id: str 17 | - kb_ids: list[str] 18 | - model: str 19 | - title: str 20 | - temperature: float 21 | - system_message: str 22 | - auto_query_guidance: str 23 | - target_output_length: str 24 | - max_chat_history_tokens: int 25 | 26 | Returns 27 | chat_thread: dict with the following keys: 28 | - params: dict 29 | - interactions: list[dict] 30 | """ 31 | 32 | chat_thread = { 33 | "params": chat_thread_params, 34 | "interactions": [] 35 | } 36 | 37 | self.chat_threads[chat_thread_params["thread_id"]] = chat_thread 38 | self.save() 39 | return chat_thread 40 | 41 | def list_chat_threads(self) -> list[dict]: 42 | return [self.chat_threads[thread_id] for thread_id in self.chat_threads] 43 | 44 | def get_chat_thread(self, thread_id: str) -> dict: 45 | return self.chat_threads.get(thread_id) 46 | 47 | def update_chat_thread(self, thread_id: str, chat_thread_params: dict) -> dict: 48 | self.chat_threads[thread_id]["params"] = chat_thread_params 49 | self.save() 50 | return chat_thread_params 51 | 52 | def delete_chat_thread(self, thread_id: str) -> dict: 53 | chat_thread = self.chat_threads.pop(thread_id, None) 54 | self.save() 55 | return chat_thread 56 | 57 | def add_interaction(self, thread_id: str, interaction: dict) -> dict: 58 | """ 59 | interaction should be a dict with the following keys: 60 | - user_input: dict 61 | - content: str 62 | - timestamp: str 63 | - model_response: dict 64 | - content: str 65 | - citations: list[dict] 66 | - doc_id: str 67 | - page_numbers: list[int] 68 | - 
cited_text: str 69 | - timestamp: str 70 | - status: str (pending, streaming, highlighting_citations, finished, failed) 71 | - relevant_segments: list[dict] 72 | - text: str 73 | - doc_id: str 74 | - kb_id: str 75 | - search_queries: list[dict] 76 | - query: str 77 | - kb_id: str 78 | """ 79 | message_id = str(uuid.uuid4()) 80 | interaction["message_id"] = message_id 81 | 82 | # Set default status if not provided 83 | if "model_response" in interaction and "status" not in interaction["model_response"]: 84 | interaction["model_response"]["status"] = "pending" 85 | 86 | self.chat_threads[thread_id]["interactions"].append(interaction) 87 | self.save() 88 | return interaction 89 | 90 | def update_interaction(self, thread_id: str, message_id: str, interaction_update: dict) -> dict: 91 | """ 92 | Updates an existing interaction in a chat thread. 93 | Only updates the fields provided in interaction_update. 94 | """ 95 | # Find the interaction with the matching message_id 96 | for i, interaction in enumerate(self.chat_threads[thread_id]["interactions"]): 97 | if interaction.get("message_id") == message_id: 98 | # Update only the fields provided in interaction_update 99 | if "model_response" in interaction_update: 100 | interaction["model_response"].update(interaction_update["model_response"]) 101 | 102 | # Ensure status field exists (for backward compatibility) 103 | if "status" not in interaction["model_response"]: 104 | interaction["model_response"]["status"] = "finished" 105 | 106 | # Save the changes 107 | self.save() 108 | return {"message_id": message_id, "updated": True} 109 | 110 | return {"message_id": message_id, "updated": False, "error": "Interaction not found"} 111 | 112 | def save(self): 113 | """ 114 | Save the ChatThreadDB to a JSON file. 115 | """ 116 | with open("chat_thread_db.json", "w") as f: 117 | json.dump(self.chat_threads, f) 118 | 119 | def load(self): 120 | with open("chat_thread_db.json", "r") as f: 121 | self.chat_threads = json.load(f) -------------------------------------------------------------------------------- /dsrag/database/chat_thread/db.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | class ChatThreadDB(ABC): 4 | @abstractmethod 5 | def create_chat_thread(self, chat_thread_params: dict) -> dict: 6 | """ 7 | Creates a chat thread in the database. 8 | """ 9 | pass 10 | 11 | @abstractmethod 12 | def list_chat_threads(self) -> list[dict]: 13 | """ 14 | Lists all chat threads in the database. 15 | """ 16 | pass 17 | 18 | @abstractmethod 19 | def get_chat_thread(self, thread_id: str) -> dict: 20 | """ 21 | Gets a chat thread by ID. 22 | """ 23 | pass 24 | 25 | @abstractmethod 26 | def update_chat_thread(self, thread_id: str, chat_thread_params: dict) -> dict: 27 | """ 28 | Updates a chat thread by ID. 29 | """ 30 | pass 31 | 32 | @abstractmethod 33 | def delete_chat_thread(self, thread_id: str) -> dict: 34 | """ 35 | Deletes a chat thread by ID. 36 | """ 37 | pass 38 | 39 | @abstractmethod 40 | def add_interaction(self, thread_id: str, interaction: dict) -> dict: 41 | """ 42 | Adds an interaction to a chat thread. 43 | """ 44 | pass 45 | 46 | @abstractmethod 47 | def update_interaction(self, thread_id: str, message_id: str, interaction_update: dict) -> dict: 48 | """ 49 | Updates an existing interaction in a chat thread. 50 | Only updates the fields provided in interaction_update. 
51 | """ 52 | pass -------------------------------------------------------------------------------- /dsrag/database/chunk/__init__.py: -------------------------------------------------------------------------------- 1 | from .db import ChunkDB 2 | from .types import FormattedDocument 3 | 4 | # Always import the basic DB as it has no dependencies 5 | from .basic_db import BasicChunkDB 6 | 7 | # Define what's in __all__ for "from dsrag.database.chunk import *" 8 | __all__ = [ 9 | "ChunkDB", 10 | "BasicChunkDB", 11 | "FormattedDocument" 12 | ] 13 | 14 | # Lazy load database modules to avoid importing all dependencies at once 15 | def __getattr__(name): 16 | """Lazily import database implementations only when accessed.""" 17 | if name == "SQLiteDB": 18 | from .sqlite_db import SQLiteDB 19 | return SQLiteDB 20 | elif name == "PostgresChunkDB": 21 | from .postgres_db import PostgresChunkDB 22 | return PostgresChunkDB 23 | elif name == "DynamoDB": 24 | from .dynamo_db import DynamoDB 25 | return DynamoDB 26 | else: 27 | raise AttributeError(f"module 'dsrag.database.chunk' has no attribute '{name}'") 28 | -------------------------------------------------------------------------------- /dsrag/database/chunk/db.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Any, Optional 3 | 4 | from dsrag.database.chunk.types import FormattedDocument 5 | 6 | 7 | class ChunkDB(ABC): 8 | subclasses = {} 9 | 10 | def __init_subclass__(cls, **kwargs): 11 | super().__init_subclass__(**kwargs) 12 | cls.subclasses[cls.__name__] = cls 13 | 14 | def to_dict(self): 15 | return { 16 | "subclass_name": self.__class__.__name__, 17 | } 18 | 19 | @classmethod 20 | def from_dict(cls, config) -> "ChunkDB": 21 | subclass_name = config.pop( 22 | "subclass_name", None 23 | ) # Remove subclass_name from config 24 | subclass = cls.subclasses.get(subclass_name) 25 | if subclass: 26 | return subclass(**config) # Pass the modified config without subclass_name 27 | else: 28 | raise ValueError(f"Unknown subclass: {subclass_name}") 29 | 30 | @abstractmethod 31 | def add_document(self, doc_id: str, chunks: dict[int, dict[str, Any]], supp_id: str = "", metadata: dict = {}) -> None: 32 | """ 33 | Store all chunks for a given document. 34 | """ 35 | pass 36 | 37 | @abstractmethod 38 | def remove_document(self, doc_id: str) -> None: 39 | """ 40 | Remove all chunks and metadata associated with a given document ID. 41 | """ 42 | pass 43 | 44 | @abstractmethod 45 | def get_chunk_text(self, doc_id: str, chunk_index: int) -> Optional[str]: 46 | """ 47 | Retrieve a specific chunk from a given document ID. 48 | """ 49 | pass 50 | 51 | @abstractmethod 52 | def get_is_visual(self, doc_id: str, chunk_index: int) -> Optional[bool]: 53 | """ 54 | Retrieve the is_visual flag of a specific chunk from a given document ID. 55 | """ 56 | pass 57 | 58 | @abstractmethod 59 | def get_chunk_page_numbers(self, doc_id: str, chunk_index: int) -> Optional[tuple[int, int]]: 60 | """ 61 | Retrieve the page numbers of a specific chunk from a given document ID. 62 | """ 63 | pass 64 | 65 | @abstractmethod 66 | def get_document(self, doc_id: str) -> Optional[FormattedDocument]: 67 | """ 68 | Retrieve all chunks from a given document ID. 69 | """ 70 | pass 71 | 72 | @abstractmethod 73 | def get_document_title(self, doc_id: str, chunk_index: int) -> Optional[str]: 74 | """ 75 | Retrieve the document title of a specific chunk from a given document ID. 
76 | """ 77 | pass 78 | 79 | @abstractmethod 80 | def get_document_summary(self, doc_id: str, chunk_index: int) -> Optional[str]: 81 | """ 82 | Retrieve the document summary of a specific chunk from a given document ID. 83 | """ 84 | pass 85 | 86 | @abstractmethod 87 | def get_section_title(self, doc_id: str, chunk_index: int) -> Optional[str]: 88 | """ 89 | Retrieve the section title of a specific chunk from a given document ID. 90 | """ 91 | pass 92 | 93 | @abstractmethod 94 | def get_section_summary(self, doc_id: str, chunk_index: int) -> Optional[str]: 95 | """ 96 | Retrieve the section summary of a specific chunk from a given document ID. 97 | """ 98 | pass 99 | 100 | @abstractmethod 101 | def get_all_doc_ids(self, supp_id: Optional[str] = None) -> list[str]: 102 | """ 103 | Retrieve all document IDs. 104 | """ 105 | pass 106 | 107 | @abstractmethod 108 | def delete(self) -> None: 109 | """ 110 | Delete the chunk database. 111 | """ 112 | pass 113 | -------------------------------------------------------------------------------- /dsrag/database/chunk/types.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from typing import Optional 3 | from typing_extensions import TypedDict 4 | 5 | class FormattedDocument(TypedDict): 6 | id: str 7 | title: str 8 | content: Optional[str] 9 | summary: Optional[str] 10 | created_on: Optional[datetime] 11 | supp_id: Optional[str] 12 | metadata: Optional[dict] 13 | chunk_count: int 14 | -------------------------------------------------------------------------------- /dsrag/database/vector/__init__.py: -------------------------------------------------------------------------------- 1 | from .db import VectorDB 2 | from .types import VectorSearchResult, Vector, ChunkMetadata 3 | 4 | # Always import the basic DB as it has no dependencies 5 | from .basic_db import BasicVectorDB 6 | 7 | # Define what's in __all__ for "from dsrag.database.vector import *" 8 | __all__ = [ 9 | "VectorDB", 10 | "BasicVectorDB", 11 | "VectorSearchResult", 12 | "Vector", 13 | "ChunkMetadata" 14 | ] 15 | 16 | # Lazy load database modules to avoid importing all dependencies at once 17 | def __getattr__(name): 18 | """Lazily import database implementations only when accessed.""" 19 | if name == "ChromaDB": 20 | from .chroma_db import ChromaDB 21 | return ChromaDB 22 | elif name == "WeaviateVectorDB": 23 | from .weaviate_db import WeaviateVectorDB 24 | return WeaviateVectorDB 25 | elif name == "QdrantVectorDB": 26 | from .qdrant_db import QdrantVectorDB 27 | return QdrantVectorDB 28 | elif name == "MilvusDB": 29 | from .milvus_db import MilvusDB 30 | return MilvusDB 31 | elif name == "PostgresVectorDB": 32 | from .postgres_db import PostgresVectorDB 33 | return PostgresVectorDB 34 | elif name == "PineconeDB": 35 | from .pinecone_db import PineconeDB 36 | return PineconeDB 37 | else: 38 | raise AttributeError(f"module 'dsrag.database.vector' has no attribute '{name}'") 39 | -------------------------------------------------------------------------------- /dsrag/database/vector/basic_db.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from dsrag.database.vector.db import VectorDB 3 | from typing import Sequence, Optional 4 | from dsrag.database.vector.types import ChunkMetadata, Vector, VectorSearchResult 5 | from sklearn.metrics.pairwise import cosine_similarity 6 | import os 7 | import numpy as np 8 | from dsrag.utils.imports import faiss 9 | 10 | 11 | class 
BasicVectorDB(VectorDB): 12 | def __init__( 13 | self, kb_id: str, storage_directory: str = "~/dsRAG", use_faiss: bool = False 14 | ) -> None: 15 | self.kb_id = kb_id 16 | self.storage_directory = storage_directory 17 | self.use_faiss = use_faiss 18 | self.vector_storage_path = os.path.join( 19 | self.storage_directory, "vector_storage", f"{kb_id}.pkl" 20 | ) 21 | self.load() 22 | 23 | def add_vectors( 24 | self, vectors: Sequence[Vector], metadata: Sequence[ChunkMetadata] 25 | ) -> None: 26 | try: 27 | assert len(vectors) == len(metadata) 28 | except AssertionError: 29 | raise ValueError( 30 | "Error in add_vectors: the number of vectors and metadata items must be the same." 31 | ) 32 | self.vectors.extend(vectors) 33 | self.metadata.extend(metadata) 34 | self.save() 35 | 36 | def search(self, query_vector, top_k=10, metadata_filter: Optional[dict] = None) -> list[VectorSearchResult]: 37 | if not self.vectors: 38 | return [] 39 | 40 | if self.use_faiss: 41 | try: 42 | return self.search_faiss(query_vector, top_k) 43 | except Exception as e: 44 | print(f"Faiss search failed: {e}. Falling back to numpy search.") 45 | return self._fallback_search(query_vector, top_k) 46 | else: 47 | return self._fallback_search(query_vector, top_k) 48 | 49 | def _fallback_search(self, query_vector, top_k=10) -> list[VectorSearchResult]: 50 | """Fallback search method using numpy when faiss is not available.""" 51 | similarities = cosine_similarity([query_vector], self.vectors)[0] 52 | indexed_similarities = sorted( 53 | enumerate(similarities), key=lambda x: x[1], reverse=True 54 | ) 55 | results: list[VectorSearchResult] = [] 56 | for i, similarity in indexed_similarities[:top_k]: 57 | result = VectorSearchResult( 58 | doc_id=None, 59 | vector=None, 60 | metadata=self.metadata[i], 61 | similarity=similarity, 62 | ) 63 | results.append(result) 64 | return results 65 | 66 | def search_faiss(self, query_vector, top_k=10) -> list[VectorSearchResult]: 67 | # Limit top_k to the number of vectors we have - Faiss doesn't automatically handle this 68 | top_k = min(top_k, len(self.vectors)) 69 | 70 | # faiss expects 2D arrays of vectors 71 | vectors_array = ( 72 | np.array(self.vectors).astype("float32").reshape(len(self.vectors), -1) 73 | ) 74 | query_vector_array = np.array(query_vector).astype("float32").reshape(1, -1) 75 | 76 | try: 77 | # Access nested modules step by step 78 | contrib = faiss.contrib 79 | exhaustive_search = contrib.exhaustive_search 80 | knn_func = exhaustive_search.knn 81 | 82 | _, I = knn_func(query_vector_array, vectors_array, top_k) 83 | except AttributeError: 84 | # If the nested module structure doesn't work, try direct import as a fallback 85 | try: 86 | from faiss.contrib.exhaustive_search import knn as direct_knn 87 | _, I = direct_knn(query_vector_array, vectors_array, top_k) 88 | except (ImportError, AttributeError): 89 | # If both approaches fail, raise an error 90 | raise ImportError( 91 | "Could not import faiss.contrib.exhaustive_search.knn. " 92 | "Please ensure faiss is installed correctly." 
93 | ) 94 | 95 | # I is a list of indices in the corpus_vectors array 96 | results: list[VectorSearchResult] = [] 97 | for i in I[0]: 98 | result = VectorSearchResult( 99 | doc_id=None, 100 | vector=None, 101 | metadata=self.metadata[i], 102 | similarity=cosine_similarity([query_vector], [self.vectors[i]])[0][0], 103 | ) 104 | results.append(result) 105 | return results 106 | 107 | def remove_document(self, doc_id): 108 | i = 0 109 | while i < len(self.metadata): 110 | if self.metadata[i]["doc_id"] == doc_id: 111 | del self.vectors[i] 112 | del self.metadata[i] 113 | else: 114 | i += 1 115 | self.save() 116 | 117 | def save(self): 118 | os.makedirs( 119 | os.path.dirname(self.vector_storage_path), exist_ok=True 120 | ) # Ensure the directory exists 121 | with open(self.vector_storage_path, "wb") as f: 122 | pickle.dump((self.vectors, self.metadata), f) 123 | 124 | def load(self): 125 | if os.path.exists(self.vector_storage_path): 126 | with open(self.vector_storage_path, "rb") as f: 127 | self.vectors, self.metadata = pickle.load(f) 128 | else: 129 | self.vectors = [] 130 | self.metadata = [] 131 | 132 | def delete(self): 133 | if os.path.exists(self.vector_storage_path): 134 | os.remove(self.vector_storage_path) 135 | 136 | def to_dict(self): 137 | return { 138 | **super().to_dict(), 139 | "kb_id": self.kb_id, 140 | "storage_directory": self.storage_directory, 141 | "use_faiss": self.use_faiss, 142 | } 143 | -------------------------------------------------------------------------------- /dsrag/database/vector/chroma_db.py: -------------------------------------------------------------------------------- 1 | from dsrag.database.vector.db import VectorDB 2 | from dsrag.database.vector.types import VectorSearchResult, MetadataFilter 3 | from typing import Optional 4 | import os 5 | import numpy as np 6 | from dsrag.utils.imports import LazyLoader 7 | 8 | # Lazy load chromadb 9 | chromadb = LazyLoader("chromadb") 10 | 11 | 12 | def format_metadata_filter(metadata_filter: MetadataFilter) -> dict: 13 | """ 14 | Format the metadata filter to be used in the ChromaDB query method. 15 | 16 | Args: 17 | metadata_filter (dict): The metadata filter. 18 | 19 | Returns: 20 | dict: The formatted metadata filter. 
21 | """ 22 | 23 | field = metadata_filter["field"] 24 | operator = metadata_filter["operator"] 25 | value = metadata_filter["value"] 26 | 27 | operator_mapping = { 28 | "equals": "$eq", 29 | "not_equals": "$ne", 30 | "in": "$in", 31 | "not_in": "$nin", 32 | "greater_than": "$gt", 33 | "less_than": "$lt", 34 | "greater_than_equals": "$gte", 35 | "less_than_equals": "$lte", 36 | } 37 | 38 | formatted_operator = operator_mapping[operator] 39 | formatted_metadata_filter = {field: {formatted_operator: value}} 40 | 41 | return formatted_metadata_filter 42 | 43 | 44 | class ChromaDB(VectorDB): 45 | 46 | def __init__(self, kb_id: str, storage_directory: str = "~/dsRAG"): 47 | self.kb_id = kb_id 48 | self.storage_directory = os.path.expanduser(storage_directory) 49 | self.vector_storage_path = os.path.join( 50 | self.storage_directory, "vector_storage" 51 | ) 52 | self.client = chromadb.PersistentClient(path=self.vector_storage_path) 53 | self.collection = self.client.get_or_create_collection( 54 | kb_id, metadata={"hnsw:space": "cosine"} 55 | ) 56 | 57 | def get_num_vectors(self): 58 | return self.collection.count() 59 | 60 | def add_vectors(self, vectors: list, metadata: list): 61 | 62 | # Convert NumPy arrays to lists 63 | vectors_as_lists = [vector.tolist() if isinstance(vector, np.ndarray) else vector for vector in vectors] 64 | try: 65 | assert len(vectors_as_lists) == len(metadata) 66 | except AssertionError: 67 | raise ValueError( 68 | "Error in add_vectors: the number of vectors and metadata items must be the same." 69 | ) 70 | 71 | # Create the ids from the metadata, defined as {metadata["doc_id"]}_{metadata["chunk_index"]} 72 | ids = [f"{meta['doc_id']}_{meta['chunk_index']}" for meta in metadata] 73 | 74 | self.collection.add(embeddings=vectors_as_lists, metadatas=metadata, ids=ids) 75 | 76 | def search(self, query_vector, top_k=10, metadata_filter: Optional[MetadataFilter] = None) -> list[VectorSearchResult]: 77 | 78 | num_vectors = self.get_num_vectors() 79 | if num_vectors == 0: 80 | return [] 81 | # raise ValueError('No vectors stored in the database.') 82 | 83 | if metadata_filter: 84 | formatted_metadata_filter = format_metadata_filter(metadata_filter) 85 | 86 | if isinstance(query_vector, np.ndarray): 87 | query_vector = query_vector.tolist() 88 | query_results = self.collection.query( 89 | query_embeddings=[query_vector], 90 | n_results=top_k, 91 | include=["distances", "metadatas"], 92 | where=formatted_metadata_filter if metadata_filter else None 93 | ) 94 | 95 | metadata = query_results["metadatas"][0] 96 | distances = query_results["distances"][0] 97 | 98 | results: list[VectorSearchResult] = [] 99 | for _, (distance, metadata) in enumerate(zip(distances, metadata)): 100 | results.append( 101 | VectorSearchResult( 102 | doc_id=metadata["doc_id"], 103 | vector=None, 104 | metadata=metadata, 105 | similarity=1 - distance, 106 | ) 107 | ) 108 | 109 | results = sorted(results, key=lambda x: x["similarity"], reverse=True) 110 | 111 | return results 112 | 113 | def remove_document(self, doc_id: str): 114 | data = self.collection.get(where={"doc_id": doc_id}) 115 | doc_ids = data["ids"] 116 | self.collection.delete(ids=doc_ids) 117 | 118 | def delete(self): 119 | self.client.delete_collection(name=self.kb_id) 120 | 121 | def to_dict(self): 122 | return { 123 | **super().to_dict(), 124 | "kb_id": self.kb_id, 125 | } 126 | -------------------------------------------------------------------------------- /dsrag/database/vector/db.py: 
-------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Sequence, Optional 3 | from dsrag.database.vector.types import ChunkMetadata, Vector, VectorSearchResult 4 | 5 | 6 | class VectorDB(ABC): 7 | subclasses = {} 8 | 9 | def __init_subclass__(cls, **kwargs): 10 | super().__init_subclass__(**kwargs) 11 | cls.subclasses[cls.__name__] = cls 12 | 13 | def to_dict(self): 14 | return { 15 | "subclass_name": self.__class__.__name__, 16 | } 17 | 18 | @classmethod 19 | def from_dict(cls, config): 20 | subclass_name = config.pop( 21 | "subclass_name", None 22 | ) # Remove subclass_name from config 23 | subclass = cls.subclasses.get(subclass_name) 24 | if subclass: 25 | return subclass(**config) # Pass the modified config without subclass_name 26 | else: 27 | raise ValueError(f"Unknown subclass: {subclass_name}") 28 | 29 | @abstractmethod 30 | def add_vectors( 31 | self, vectors: Sequence[Vector], metadata: Sequence[ChunkMetadata] 32 | ) -> None: 33 | """ 34 | Store a list of vectors with associated metadata. 35 | """ 36 | pass 37 | 38 | @abstractmethod 39 | def remove_document(self, doc_id) -> None: 40 | """ 41 | Remove all vectors and metadata associated with a given document ID. 42 | """ 43 | pass 44 | 45 | @abstractmethod 46 | def search(self, query_vector, top_k: int=10, metadata_filter: Optional[dict] = None) -> list[VectorSearchResult]: 47 | """ 48 | Retrieve the top-k closest vectors to a given query vector. 49 | - needs to return results as list of dictionaries in this format: 50 | { 51 | 'metadata': 52 | { 53 | 'doc_id': doc_id, 54 | 'chunk_index': chunk_index, 55 | 'chunk_header': chunk_header, 56 | 'chunk_text': chunk_text 57 | }, 58 | 'similarity': similarity, 59 | } 60 | """ 61 | pass 62 | 63 | @abstractmethod 64 | def delete(self) -> None: 65 | """ 66 | Delete the vector database. 67 | """ 68 | pass 69 | -------------------------------------------------------------------------------- /dsrag/database/vector/milvus_db.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional, Sequence 3 | 4 | from dsrag.database.vector import VectorSearchResult 5 | from dsrag.database.vector.db import VectorDB 6 | from dsrag.database.vector.types import MetadataFilter, Vector, ChunkMetadata 7 | from dsrag.utils.imports import LazyLoader 8 | 9 | # Lazy load pymilvus 10 | pymilvus = LazyLoader("pymilvus") 11 | 12 | def _convert_metadata_to_expr(metadata_filter: MetadataFilter) -> str: 13 | """ 14 | Convert the metadata filter to the expression used in the Milvus query method. 15 | 16 | Args: 17 | metadata_filter (dict): The metadata filter. 18 | 19 | Returns: 20 | str: The formatted expression. 
21 | """ 22 | 23 | if not metadata_filter: 24 | return "" 25 | field = metadata_filter["field"] 26 | operator = metadata_filter["operator"] 27 | value = metadata_filter["value"] 28 | 29 | operator_mapping = { 30 | "equals": "==", 31 | "not_equals": "!=", 32 | "in": "in", 33 | "not_in": "not in", 34 | "greater_than": ">", 35 | "less_than": "<", 36 | "greater_than_equals": ">=", 37 | "less_than_equals": "<=", 38 | } 39 | 40 | formatted_operator = operator_mapping[operator] 41 | if isinstance(value, str): 42 | value = f'"{value}"' 43 | formatted_metadata_filter = f"metadata['{field}'] {formatted_operator} {value}" 44 | 45 | return formatted_metadata_filter 46 | 47 | 48 | class MilvusDB(VectorDB): 49 | def __init__(self, kb_id: str, storage_directory: str = '~/dsRAG', 50 | dimension: int = 768): 51 | """ 52 | Initialize a MilvusDB instance. 53 | 54 | Args: 55 | kb_id (str): The knowledge base ID. 56 | storage_directory (str, optional): The storage directory. Defaults to '~/dsRAG'. 57 | dimension (int, optional): The dimension of the vectors. Defaults to 768. 58 | """ 59 | self.kb_id = kb_id 60 | self.storage_directory = os.path.expanduser(storage_directory) 61 | self.collection_name = kb_id 62 | self.dimension = dimension 63 | 64 | # Initialize Milvus client 65 | self.client = pymilvus.MilvusClient(uri='http://localhost:19530') 66 | 67 | # Create collection if it doesn't exist 68 | try: 69 | if not self.client.has_collection(self.collection_name): 70 | self._create_collection(self.collection_name, dimension) 71 | except Exception as e: 72 | print(f"Error initializing Milvus: {e}") 73 | raise 74 | 75 | def _create_collection(self, collection_name: str, dimension: int = 768): 76 | """ 77 | Create a collection in Milvus. 78 | 79 | Args: 80 | collection_name (str): The name of the collection. 81 | dimension (int, optional): The dimension of the vectors. Defaults to 768. 
82 | """ 83 | schema = { 84 | "id": {"data_type": pymilvus.DataType.VARCHAR, "is_primary": True, "max_length": 100}, 85 | "doc_id": {"data_type": pymilvus.DataType.VARCHAR, "max_length": 100}, 86 | "chunk_index": {"data_type": pymilvus.DataType.INT64}, 87 | "text": {"data_type": pymilvus.DataType.VARCHAR, "max_length": 65535}, 88 | "embedding": {"data_type": pymilvus.DataType.FLOAT_VECTOR, "dim": dimension} 89 | } 90 | 91 | self.client.create_collection( 92 | collection_name=collection_name, 93 | schema=schema, 94 | index_params={ 95 | "field_name": "embedding", 96 | "index_type": "HNSW", 97 | "metric_type": "COSINE", 98 | "params": {"M": 8, "efConstruction": 64} 99 | } 100 | ) 101 | 102 | def add_vectors(self, vectors: Sequence[Vector], metadata: Sequence[ChunkMetadata]): 103 | try: 104 | assert len(vectors) == len(metadata) 105 | except AssertionError: 106 | raise ValueError( 107 | 'Error in add_vectors: the number of vectors and metadata items must be the same.') 108 | 109 | # Create the ids from the metadata, defined as {metadata["doc_id"]}_{metadata["chunk_index"]} 110 | ids = [f"{meta['doc_id']}_{meta['chunk_index']}" for meta in metadata] 111 | 112 | data = [] 113 | for i, vector in enumerate(vectors): 114 | data.append({ 115 | 'doc_id': ids[i], 116 | 'vector': vector, 117 | 'metadata': metadata[i] 118 | }) 119 | 120 | self.client.upsert( 121 | collection_name=self.kb_id, 122 | data=data 123 | ) 124 | 125 | def search(self, query_vector, top_k: int=10, metadata_filter: Optional[dict] = None) -> list[VectorSearchResult]: 126 | query_results = self.client.search( 127 | collection_name=self.kb_id, 128 | data=[query_vector], 129 | filter=_convert_metadata_to_expr(metadata_filter), 130 | limit=top_k, 131 | output_fields=["*"] 132 | )[0] 133 | results: list[VectorSearchResult] = [] 134 | for res in query_results: 135 | results.append( 136 | VectorSearchResult( 137 | doc_id=res['entity']['doc_id'], 138 | metadata=res['entity']['metadata'], 139 | similarity=res['distance'], 140 | vector=res['entity']['vector'], 141 | ) 142 | ) 143 | 144 | results = sorted(results, key=lambda x: x['similarity'], reverse=True) 145 | 146 | return results 147 | 148 | def remove_document(self, doc_id): 149 | self.client.delete( 150 | collection_name=self.kb_id, 151 | filter=f'metadata["doc_id"] == "{doc_id}"' 152 | ) 153 | 154 | def get_num_vectors(self): 155 | res = self.client.query( 156 | collection_name=self.kb_id, 157 | output_fields=["count(*)"] 158 | ) 159 | return res[0]["count(*)"] 160 | 161 | def delete(self): 162 | self.client.drop_collection(collection_name=self.kb_id) 163 | 164 | def to_dict(self): 165 | return { 166 | **super().to_dict(), 167 | 'kb_id': self.kb_id, 168 | } 169 | -------------------------------------------------------------------------------- /dsrag/database/vector/types.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Sequence, Union 2 | from typing_extensions import TypedDict 3 | 4 | 5 | class ChunkMetadata(TypedDict): 6 | doc_id: str 7 | chunk_text: str 8 | chunk_index: int 9 | chunk_header: str 10 | 11 | 12 | Vector = Union[Sequence[float], Sequence[int]] 13 | 14 | 15 | class VectorSearchResult(TypedDict): 16 | doc_id: Optional[str] 17 | vector: Optional[Vector] 18 | metadata: ChunkMetadata 19 | similarity: float 20 | 21 | class MetadataFilter(TypedDict): 22 | field: str 23 | operator: str # Can be one of the following: 'equals', 'not_equals', 'in', 'not_in', 'greater_than', 'less_than', 'greater_than_equals', 
'less_than_equals' 24 | value: Union[str, int, float, list[str], list[int], list[float]] -------------------------------------------------------------------------------- /dsrag/dsparse/MANIFEST.in: -------------------------------------------------------------------------------- 1 | exclude tests/* -------------------------------------------------------------------------------- /dsrag/dsparse/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | # Configure the dsparse logger with a NullHandler to prevent "No handler found" warnings 4 | logger = logging.getLogger("dsrag.dsparse") 5 | logger.addHandler(logging.NullHandler()) -------------------------------------------------------------------------------- /dsrag/dsparse/file_parsing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-Star-AI/dsRAG/b326c1d0930c9a80ff486ab0df5e771e788e3234/dsrag/dsparse/file_parsing/__init__.py -------------------------------------------------------------------------------- /dsrag/dsparse/file_parsing/element_types.py: -------------------------------------------------------------------------------- 1 | from ..models.types import ElementType 2 | 3 | 4 | def get_visual_elements_as_str(elements: list[ElementType]) -> str: 5 | visual_elements = [element["name"] for element in elements if element["is_visual"]] 6 | last_element = visual_elements[-1] 7 | if len(visual_elements) > 1: 8 | visual_elements[-1] = f"and {last_element}" 9 | return ", ".join(visual_elements) 10 | 11 | def get_non_visual_elements_as_str(elements: list[ElementType]) -> str: 12 | non_visual_elements = [element["name"] for element in elements if not element["is_visual"]] 13 | last_element = non_visual_elements[-1] 14 | if len(non_visual_elements) > 1: 15 | non_visual_elements[-1] = f"and {last_element}" 16 | return ", ".join(non_visual_elements) 17 | 18 | def get_num_visual_elements(elements: list[ElementType]) -> int: 19 | return len([element for element in elements if element["is_visual"]]) 20 | 21 | def get_num_non_visual_elements(elements: list[ElementType]) -> int: 22 | return len([element for element in elements if not element["is_visual"]]) 23 | 24 | def get_element_description_block(elements: list[ElementType]) -> str: 25 | element_blocks = [] 26 | for element in elements: 27 | element_block = ELEMENT_PROMPT.format( 28 | element_name=element["name"], 29 | element_instructions=element["instructions"], 30 | ) 31 | element_blocks.append(element_block) 32 | return "\n".join(element_blocks) 33 | 34 | ELEMENT_PROMPT = """ 35 | - {element_name} 36 | - {element_instructions} 37 | """.strip() 38 | 39 | default_element_types = [ 40 | { 41 | "name": "NarrativeText", 42 | "instructions": "This is the main text content of the page, including paragraphs, lists, titles, and any other text content that is not part of one of the other more specialized element types. Not all pages have narrative text, but most do. Be sure to use Markdown formatting for the text content. This includes using tags like # for headers, * for lists, etc. Make sure your header tags are properly nested and that your lists are properly formatted.", 43 | "is_visual": False, 44 | }, 45 | { 46 | "name": "Figure", 47 | "instructions": "This covers charts, graphs, diagrams, etc. Associated titles, legends, axis titles, etc. 
should be considered to be part of the figure.", 48 | "is_visual": True, 49 | }, 50 | { 51 | "name": "Image", 52 | "instructions": "This is any visual content on the page that doesn't fit into any of the other categories. This could include photos, illustrations, etc. Any title or captions associated with the image should be considered part of the image.", 53 | "is_visual": True, 54 | }, 55 | { 56 | "name": "Table", 57 | "instructions": "This covers any tabular data arrangement on the page, including simple and complex tables. Any titles, captions, or notes associated with the table should be considered part of the table element.", 58 | "is_visual": True, 59 | }, 60 | { 61 | "name": "Equation", 62 | "instructions": "This covers mathematical equations, formulas, and expressions that appear on the page. Associated equation numbers or labels should be considered part of the equation.", 63 | "is_visual": True, 64 | }, 65 | { 66 | "name": "Header", 67 | "instructions": "This is the header of the page, which would be located at the very top of the page and may include things like page numbers. The text in a Header element is usually in a very small font size. This is NOT the same thing as a Markdown header or heading. You should never use more than one header element per page. Not all pages have a header. Note that headers are not the same as titles or subtitles within the main text content of the page. Those should be included in NarrativeText elements.", 68 | "is_visual": False, 69 | }, 70 | { 71 | "name": "Footnote", 72 | "instructions": "Footnotes should always be included as a separate element from the main text content as they aren't part of the main linear reading flow of the page. Not all pages have footnotes.", 73 | "is_visual": False, 74 | }, 75 | { 76 | "name": "Footer", 77 | "instructions": "This is the footer of the page, which would be located at the very bottom of the page. You should never use more than one footer element per page.
Not all pages have a footer, but when they do it is always the very last element on the page.", 78 | "is_visual": False, 79 | } 80 | ] 81 | 82 | -------------------------------------------------------------------------------- /dsrag/dsparse/file_parsing/non_vlm_file_parsing.py: -------------------------------------------------------------------------------- 1 | import PyPDF2 2 | import docx2txt 3 | 4 | def extract_text_from_pdf(file_path: str) -> tuple[str, list]: 5 | with open(file_path, 'rb') as file: 6 | # Create a PDF reader object 7 | pdf_reader = PyPDF2.PdfReader(file) 8 | 9 | # Initialize an empty string to store the extracted text 10 | extracted_text = "" 11 | 12 | # Loop through all the pages in the PDF file 13 | pages = [] 14 | for page_num in range(len(pdf_reader.pages)): 15 | # Extract the text from the current page 16 | page_text = pdf_reader.pages[page_num].extract_text() 17 | pages.append(page_text) 18 | 19 | # Add the extracted text to the final text 20 | extracted_text += page_text 21 | 22 | return extracted_text, pages 23 | 24 | def extract_text_from_docx(file_path: str) -> str: 25 | return docx2txt.process(file_path) 26 | 27 | def parse_file_no_vlm(file_path: str) -> str: 28 | pdf_pages = None 29 | if file_path.endswith(".pdf"): 30 | text, pdf_pages = extract_text_from_pdf(file_path) 31 | elif file_path.endswith(".docx"): 32 | text = extract_text_from_docx(file_path) 33 | elif file_path.endswith(".txt") or file_path.endswith(".md"): 34 | with open(file_path, "r") as file: 35 | text = file.read() 36 | else: 37 | raise ValueError( 38 | "Unsupported file format. Only .txt, .md, .pdf and .docx files are supported." 39 | ) 40 | 41 | return text, pdf_pages -------------------------------------------------------------------------------- /dsrag/dsparse/file_parsing/vlm.py: -------------------------------------------------------------------------------- 1 | import PIL.Image 2 | import os 3 | import io 4 | from ..utils.imports import genai, vertexai 5 | 6 | def make_llm_call_gemini(image_path: str, system_message: str, model: str = "gemini-2.0-flash", response_schema: dict = None, max_tokens: int = 4000, temperature: float = 0.5) -> str: 7 | genai.configure(api_key=os.environ["GEMINI_API_KEY"]) 8 | generation_config = { 9 | "temperature": temperature, 10 | "response_mime_type": "application/json", 11 | "max_output_tokens": max_tokens 12 | } 13 | if response_schema is not None: 14 | generation_config["response_schema"] = response_schema 15 | 16 | model = genai.GenerativeModel(model) 17 | try: 18 | image = PIL.Image.open(image_path) 19 | # Compress image if needed 20 | compressed_image = compress_image(image) 21 | response = model.generate_content( 22 | [ 23 | compressed_image, 24 | system_message 25 | ], 26 | generation_config=generation_config 27 | ) 28 | return response.text 29 | finally: 30 | # Ensure image is closed even if an error occurs 31 | if 'image' in locals(): 32 | image.close() 33 | 34 | def make_llm_call_vertex(image_path: str, system_message: str, model: str, project_id: str, location: str, response_schema: dict = None, max_tokens: int = 4000, temperature: float = 0.5) -> str: 35 | """ 36 | This function calls the Vertex AI Gemini API (not to be confused with the Gemini API) with an image and a system message and returns the response text. 
37 | """ 38 | vertexai.init(project=project_id, location=location) 39 | model = vertexai.generative_models.GenerativeModel(model) 40 | 41 | if response_schema is not None: 42 | generation_config = vertexai.generative_models.GenerationConfig(temperature=temperature, max_output_tokens=max_tokens, response_mime_type="application/json", response_schema=response_schema) 43 | else: 44 | generation_config = vertexai.generative_models.GenerationConfig(temperature=temperature, max_output_tokens=max_tokens) 45 | 46 | response = model.generate_content( 47 | [ 48 | vertexai.generative_models.Part.from_image(vertexai.generative_models.Image.load_from_file(image_path)), 49 | system_message, 50 | ], 51 | generation_config=generation_config, 52 | ) 53 | return response.text 54 | 55 | def compress_image(image: PIL.Image.Image, max_size_bytes: int = 1097152, quality: int = 85) -> tuple[PIL.Image.Image, int]: 56 | """ 57 | Compress image if it exceeds file size while maintaining aspect ratio. 58 | 59 | Args: 60 | image: PIL Image object 61 | max_size_bytes: Maximum file size in bytes (default ~1MB) 62 | quality: Initial JPEG quality (0-100) 63 | 64 | Returns: 65 | Tuple of (compressed PIL Image object, final quality used) 66 | """ 67 | output = io.BytesIO() 68 | 69 | # Initial compression 70 | image.save(output, format='JPEG', quality=quality) 71 | 72 | # Reduce quality if file is too large 73 | while output.tell() > max_size_bytes and quality > 10: 74 | output = io.BytesIO() 75 | quality -= 5 76 | image.save(output, format='JPEG', quality=quality) 77 | 78 | # If reducing quality didn't work, reduce dimensions 79 | if output.tell() > max_size_bytes: 80 | while output.tell() > max_size_bytes: 81 | width, height = image.size 82 | image = image.resize((int(width*0.9), int(height*0.9)), PIL.Image.Resampling.LANCZOS) 83 | output = io.BytesIO() 84 | image.save(output, format='JPEG', quality=quality) 85 | 86 | # Convert back to PIL Image 87 | output.seek(0) 88 | compressed_image = PIL.Image.open(output) 89 | compressed_image.load() # This is important to ensure the BytesIO can be closed 90 | return compressed_image -------------------------------------------------------------------------------- /dsrag/dsparse/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-Star-AI/dsRAG/b326c1d0930c9a80ff486ab0df5e771e788e3234/dsrag/dsparse/models/__init__.py -------------------------------------------------------------------------------- /dsrag/dsparse/models/types.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, TypedDict 2 | from ..file_parsing.file_system import FileSystem 3 | 4 | class ElementType(TypedDict): 5 | name: str 6 | instructions: str 7 | is_visual: bool 8 | 9 | class Element(TypedDict): 10 | type: str 11 | content: str 12 | page_number: Optional[int] 13 | 14 | class Line(TypedDict): 15 | content: str 16 | element_type: str 17 | page_number: Optional[int] 18 | is_visual: Optional[bool] 19 | 20 | class Section(TypedDict): 21 | title: str 22 | start: int 23 | end: int 24 | content: str 25 | 26 | class Chunk(TypedDict): 27 | line_start: int 28 | line_end: int 29 | content: str 30 | page_start: int 31 | page_end: int 32 | section_index: int 33 | is_visual: bool 34 | 35 | class VLMConfig(TypedDict): 36 | provider: Optional[str] 37 | model: Optional[str] 38 | project_id: Optional[str] 39 | location: Optional[str] 40 | exclude_elements: Optional[list[str]] 41 | 
element_types: Optional[list[ElementType]] 42 | max_workers: Optional[int] = None 43 | 44 | class SemanticSectioningConfig(TypedDict): 45 | use_semantic_sectioning: Optional[bool] 46 | llm_provider: Optional[str] 47 | model: Optional[str] 48 | language: Optional[str] 49 | 50 | class ChunkingConfig(TypedDict): 51 | chunk_size: Optional[int] 52 | min_length_for_chunking: Optional[int] 53 | 54 | class FileParsingConfig(TypedDict): 55 | use_vlm: Optional[bool] 56 | vlm_config: Optional[VLMConfig] 57 | always_save_page_images: Optional[bool] -------------------------------------------------------------------------------- /dsrag/dsparse/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0.0", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "dsparse" 7 | version = "0.0.1" 8 | description = "Multi-modal file parsing and chunking" 9 | readme = "README.md" 10 | authors = [{ name = "Zach McCormick", email = "zach@d-star.ai" }, { name = "Nick McCormick", email = "nick@d-star.ai" }] 11 | license = { file = "LICENSE" } 12 | classifiers=[ 13 | "Development Status :: 4 - Beta", 14 | "Intended Audience :: Developers", 15 | "Intended Audience :: Science/Research", 16 | "License :: OSI Approved :: MIT License", 17 | "Operating System :: OS Independent", 18 | "Programming Language :: Python :: 3.9", 19 | "Programming Language :: Python :: 3.10", 20 | "Programming Language :: Python :: 3.11", 21 | "Programming Language :: Python :: 3.12", 22 | "Topic :: Database", 23 | "Topic :: Software Development", 24 | "Topic :: Software Development :: Libraries", 25 | "Topic :: Software Development :: Libraries :: Application Frameworks", 26 | "Topic :: Software Development :: Libraries :: Python Modules" 27 | ] 28 | requires-python = ">=3.9" 29 | dynamic = ["dependencies"] 30 | 31 | [tool.setuptools.dynamic] 32 | dependencies = {file = ["requirements.txt"]} 33 | 34 | [project.optional-dependencies] 35 | s3 = ["boto3>=1.34.142", "botocore>=1.34.142"] 36 | vertexai = ["vertexai>=1.70.0"] 37 | instructor = ["instructor>=0.4.0"] 38 | openai = ["openai>=1.0.0"] 39 | anthropic = ["anthropic>=0.5.0"] 40 | google-generativeai = ["google-generativeai>=0.3.0"] 41 | 42 | # Convenience groups 43 | all-models = [ 44 | "instructor>=0.4.0", 45 | "openai>=1.0.0", 46 | "anthropic>=0.5.0", 47 | "google-generativeai>=0.3.0", 48 | "vertexai>=1.70.0" 49 | ] 50 | 51 | # Complete installation with all optional dependencies 52 | all = [ 53 | "boto3>=1.34.142", 54 | "botocore>=1.34.142", 55 | "instructor>=0.4.0", 56 | "openai>=1.0.0", 57 | "anthropic>=0.5.0", 58 | "google-generativeai>=0.3.0", 59 | "vertexai>=1.70.0" 60 | ] 61 | 62 | [project.urls] 63 | Homepage = "https://github.com/D-Star-AI/dsRAG/blob/main/dsrag/dsparse/README.md" 64 | Documentation = "https://github.com/D-Star-AI/dsRAG/blob/main/dsrag/dsparse/README.md" 65 | Contact = "https://github.com/D-Star-AI/dsRAG/blob/main/dsrag/dsparse/README.md" 66 | 67 | [tool.setuptools.packages.find] 68 | where = ["."] 69 | include = ["dsparse"] -------------------------------------------------------------------------------- /dsrag/dsparse/requirements.in: -------------------------------------------------------------------------------- 1 | # dsparse specific dependencies 2 | anthropic>=0.37.1 3 | docx2txt>=0.8 4 | google-generativeai>=0.8.3 5 | instructor>=1.7.0 # Required for streaming functionality 6 | langchain-text-splitters>=0.3.0 7 | langchain-core>=0.3.0 8 | 
openai>=1.52.2 9 | pdf2image>=1.16.0 10 | pillow>=10.0.0 11 | pydantic>=2.8.2 12 | PyPDF2>=3.0.1 13 | # tenacity and jiter are unpinned to allow instructor to select compatible versions -------------------------------------------------------------------------------- /dsrag/dsparse/requirements.txt: -------------------------------------------------------------------------------- 1 | # 2 | # This file is autogenerated by pip-compile with Python 3.12 3 | # by the following command: 4 | # 5 | # pip-compile requirements.in 6 | # 7 | aiohappyeyeballs==2.4.3 8 | # via aiohttp 9 | aiohttp==3.10.10 10 | # via instructor 11 | aiosignal==1.3.1 12 | # via aiohttp 13 | annotated-types==0.7.0 14 | # via pydantic 15 | anthropic==0.37.1 16 | # via -r requirements.in 17 | anyio==4.6.2.post1 18 | # via 19 | # anthropic 20 | # httpx 21 | # openai 22 | attrs==24.2.0 23 | # via aiohttp 24 | cachetools==5.5.0 25 | # via google-auth 26 | certifi==2024.8.30 27 | # via 28 | # httpcore 29 | # httpx 30 | # requests 31 | charset-normalizer==3.4.0 32 | # via requests 33 | click==8.1.7 34 | # via typer 35 | distro==1.9.0 36 | # via 37 | # anthropic 38 | # openai 39 | docstring-parser==0.16 40 | # via instructor 41 | docx2txt==0.8 42 | # via -r requirements.in 43 | filelock==3.16.1 44 | # via huggingface-hub 45 | frozenlist==1.5.0 46 | # via 47 | # aiohttp 48 | # aiosignal 49 | fsspec==2024.10.0 50 | # via huggingface-hub 51 | google-ai-generativelanguage==0.6.10 52 | # via google-generativeai 53 | google-api-core[grpc]==2.21.0 54 | # via 55 | # google-ai-generativelanguage 56 | # google-api-python-client 57 | # google-generativeai 58 | google-api-python-client==2.149.0 59 | # via google-generativeai 60 | google-auth==2.35.0 61 | # via 62 | # google-ai-generativelanguage 63 | # google-api-core 64 | # google-api-python-client 65 | # google-auth-httplib2 66 | # google-generativeai 67 | google-auth-httplib2==0.2.0 68 | # via google-api-python-client 69 | google-generativeai==0.8.3 70 | # via -r requirements.in 71 | googleapis-common-protos==1.65.0 72 | # via 73 | # google-api-core 74 | # grpcio-status 75 | grpcio==1.67.0 76 | # via 77 | # google-api-core 78 | # grpcio-status 79 | grpcio-status==1.67.0 80 | # via google-api-core 81 | h11==0.14.0 82 | # via httpcore 83 | httpcore==1.0.6 84 | # via httpx 85 | httplib2==0.22.0 86 | # via 87 | # google-api-python-client 88 | # google-auth-httplib2 89 | httpx==0.27.2 90 | # via 91 | # anthropic 92 | # langsmith 93 | # openai 94 | huggingface-hub==0.26.1 95 | # via tokenizers 96 | idna==3.10 97 | # via 98 | # anyio 99 | # httpx 100 | # requests 101 | # yarl 102 | instructor==1.7.3 103 | # via -r requirements.in 104 | jinja2==3.1.6 105 | # via instructor 106 | jiter==0.8.2 107 | # via 108 | # anthropic 109 | # instructor 110 | # openai 111 | jsonpatch==1.33 112 | # via langchain-core 113 | jsonpointer==3.0.0 114 | # via jsonpatch 115 | langchain-core==0.3.42 116 | # via 117 | # -r requirements.in 118 | # langchain-text-splitters 119 | langchain-text-splitters==0.3.6 120 | # via -r requirements.in 121 | langsmith==0.1.137 122 | # via langchain-core 123 | markdown-it-py==3.0.0 124 | # via rich 125 | markupsafe==3.0.2 126 | # via jinja2 127 | mdurl==0.1.2 128 | # via markdown-it-py 129 | multidict==6.1.0 130 | # via 131 | # aiohttp 132 | # yarl 133 | openai==1.52.2 134 | # via 135 | # -r requirements.in 136 | # instructor 137 | orjson==3.10.10 138 | # via langsmith 139 | packaging==24.1 140 | # via 141 | # huggingface-hub 142 | # langchain-core 143 | pdf2image==1.16.0 144 | # via -r 
requirements.in 145 | pillow==10.0.0 146 | # via 147 | # -r requirements.in 148 | # pdf2image 149 | propcache==0.2.0 150 | # via yarl 151 | proto-plus==1.25.0 152 | # via 153 | # google-ai-generativelanguage 154 | # google-api-core 155 | protobuf==5.28.3 156 | # via 157 | # google-ai-generativelanguage 158 | # google-api-core 159 | # google-generativeai 160 | # googleapis-common-protos 161 | # grpcio-status 162 | # proto-plus 163 | pyasn1==0.6.1 164 | # via 165 | # pyasn1-modules 166 | # rsa 167 | pyasn1-modules==0.4.1 168 | # via google-auth 169 | pydantic==2.8.2 170 | # via 171 | # -r requirements.in 172 | # anthropic 173 | # google-generativeai 174 | # instructor 175 | # langchain-core 176 | # langsmith 177 | # openai 178 | pydantic-core==2.20.1 179 | # via 180 | # instructor 181 | # pydantic 182 | pygments==2.18.0 183 | # via rich 184 | pyparsing==3.2.0 185 | # via httplib2 186 | pypdf2==3.0.1 187 | # via -r requirements.in 188 | pyyaml==6.0.2 189 | # via 190 | # huggingface-hub 191 | # langchain-core 192 | requests==2.32.3 193 | # via 194 | # google-api-core 195 | # huggingface-hub 196 | # instructor 197 | # langsmith 198 | # requests-toolbelt 199 | requests-toolbelt==1.0.0 200 | # via langsmith 201 | rich==13.9.3 202 | # via 203 | # instructor 204 | # typer 205 | rsa==4.9 206 | # via google-auth 207 | shellingham==1.5.4 208 | # via typer 209 | sniffio==1.3.1 210 | # via 211 | # anthropic 212 | # anyio 213 | # httpx 214 | # openai 215 | tenacity==9.0.0 216 | # via 217 | # instructor 218 | # langchain-core 219 | tokenizers==0.20.1 220 | # via anthropic 221 | tqdm==4.66.5 222 | # via 223 | # google-generativeai 224 | # huggingface-hub 225 | # openai 226 | typer==0.12.5 227 | # via instructor 228 | typing-extensions==4.12.2 229 | # via 230 | # anthropic 231 | # google-generativeai 232 | # huggingface-hub 233 | # langchain-core 234 | # openai 235 | # pydantic 236 | # pydantic-core 237 | # typer 238 | uritemplate==4.1.1 239 | # via google-api-python-client 240 | urllib3==2.2.3 241 | # via requests 242 | yarl==1.16.0 243 | # via aiohttp 244 | -------------------------------------------------------------------------------- /dsrag/dsparse/sectioning_and_chunking/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-Star-AI/dsRAG/b326c1d0930c9a80ff486ab0df5e771e788e3234/dsrag/dsparse/sectioning_and_chunking/__init__.py -------------------------------------------------------------------------------- /dsrag/dsparse/tests/integration/test_vlm_file_parsing.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import unittest 4 | import shutil 5 | 6 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../..'))) 7 | sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..')) 8 | from dsparse.main import parse_and_chunk 9 | from dsparse.models.types import Chunk, Section 10 | from dsparse.file_parsing.file_system import LocalFileSystem 11 | 12 | 13 | class TestVLMFileParsing(unittest.TestCase): 14 | 15 | @classmethod 16 | def setUpClass(self): 17 | self.save_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../data/dsparse_output')) 18 | self.file_system = LocalFileSystem(base_path=self.save_path) 19 | self.kb_id = "test_kb" 20 | self.doc_id = "test_doc" 21 | self.test_data_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../../tests/data/mck_energy_first_5_pages.pdf')) 22 | try: 23 | 
shutil.rmtree(self.save_path) 24 | except: 25 | pass 26 | 27 | def test__parse_and_chunk_vlm(self): 28 | 29 | vlm_config = { 30 | "provider": "gemini", 31 | "model": "gemini-2.0-flash", 32 | } 33 | semantic_sectioning_config = { 34 | "llm_provider": "openai", 35 | "model": "gpt-4o-mini", 36 | "language": "en", 37 | } 38 | file_parsing_config = { 39 | "use_vlm": True, 40 | "vlm_config": vlm_config, 41 | "always_save_page_images": True 42 | } 43 | sections, chunks = parse_and_chunk( 44 | kb_id=self.kb_id, 45 | doc_id=self.doc_id, 46 | file_path=self.test_data_path, 47 | file_parsing_config=file_parsing_config, 48 | semantic_sectioning_config=semantic_sectioning_config, 49 | chunking_config={}, 50 | file_system=self.file_system, 51 | ) 52 | 53 | # Make sure the sections and chunks were created, and are the correct types 54 | self.assertTrue(len(sections) > 0) 55 | self.assertTrue(len(chunks) > 0) 56 | self.assertEqual(type(sections), list) 57 | self.assertEqual(type(chunks), list) 58 | 59 | for key, expected_type in Section.__annotations__.items(): 60 | self.assertIsInstance(sections[0][key], expected_type) 61 | 62 | for key, expected_type in Chunk.__annotations__.items(): 63 | self.assertIsInstance(chunks[0][key], expected_type) 64 | 65 | self.assertTrue(len(sections[0]["title"]) > 0) 66 | self.assertTrue(len(sections[0]["content"]) > 0) 67 | 68 | # Make sure the elements.json file was created 69 | self.assertTrue(os.path.exists(os.path.join(self.save_path, self.kb_id, self.doc_id, "elements.json"))) 70 | 71 | # Delete the save path 72 | try: 73 | shutil.rmtree(self.save_path) 74 | except: 75 | pass 76 | 77 | 78 | def test_non_vlm_file_parsing(self): 79 | 80 | semantic_sectioning_config = { 81 | "llm_provider": "openai", 82 | "model": "gpt-4o-mini", 83 | "language": "en", 84 | } 85 | file_parsing_config = { 86 | "use_vlm": False, 87 | "always_save_page_images": True, 88 | } 89 | sections, chunks = parse_and_chunk( 90 | kb_id=self.kb_id, 91 | doc_id=self.doc_id, 92 | file_path=self.test_data_path, 93 | file_parsing_config=file_parsing_config, 94 | semantic_sectioning_config=semantic_sectioning_config, 95 | chunking_config={}, 96 | file_system=self.file_system, 97 | ) 98 | 99 | # Make sure the sections and chunks were created, and are the correct types 100 | self.assertTrue(len(sections) > 0) 101 | self.assertTrue(len(chunks) > 0) 102 | self.assertEqual(type(sections), list) 103 | self.assertEqual(type(chunks), list) 104 | 105 | for key, expected_type in Section.__annotations__.items(): 106 | self.assertIsInstance(sections[0][key], expected_type) 107 | 108 | for key, expected_type in Chunk.__annotations__.items(): 109 | self.assertIsInstance(chunks[0][key], expected_type) 110 | 111 | self.assertTrue(len(sections[0]["title"]) > 0) 112 | self.assertTrue(len(sections[0]["content"]) > 0) 113 | 114 | 115 | @classmethod 116 | def tearDownClass(self): 117 | # Delete the save path 118 | try: 119 | shutil.rmtree(self.save_path) 120 | except: 121 | pass 122 | 123 | if __name__ == "__main__": 124 | unittest.main() -------------------------------------------------------------------------------- /dsrag/dsparse/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility functions for dsparse. 
3 | """ -------------------------------------------------------------------------------- /dsrag/dsparse/utils/imports.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for lazy imports of optional dependencies. 3 | """ 4 | import importlib 5 | 6 | class LazyLoader: 7 | """ 8 | Lazily import a module only when its attributes are accessed. 9 | 10 | This allows optional dependencies to be imported only when actually used, 11 | rather than at module import time. 12 | 13 | Usage: 14 | # Instead of: import instructor 15 | instructor = LazyLoader("instructor") 16 | 17 | # Then use instructor as normal - it will only be imported when accessed 18 | # If the module is not installed, a helpful error message is shown 19 | """ 20 | def __init__(self, module_name, package_name=None): 21 | """ 22 | Initialize a lazy loader for a module. 23 | 24 | Args: 25 | module_name: The name of the module to import 26 | package_name: Optional package name for pip install instructions 27 | (defaults to module_name if not provided) 28 | """ 29 | self._module_name = module_name 30 | self._package_name = package_name or module_name 31 | self._module = None 32 | 33 | def __getattr__(self, name): 34 | """Called when an attribute is accessed.""" 35 | if self._module is None: 36 | try: 37 | # Use importlib.import_module instead of __import__ for better handling of nested modules 38 | self._module = importlib.import_module(self._module_name) 39 | except ImportError: 40 | raise ImportError( 41 | f"The '{self._module_name}' module is required but not installed. " 42 | f"Please install it with: pip install {self._package_name} " 43 | f"or pip install dsparse[{self._package_name}]" 44 | ) 45 | 46 | # Try to get the attribute from the module 47 | try: 48 | return getattr(self._module, name) 49 | except AttributeError: 50 | # If the attribute is not found, it might be a nested module 51 | try: 52 | # Try to import the nested module 53 | nested_module = importlib.import_module(f"{self._module_name}.{name}") 54 | # Cache it on the module for future access 55 | setattr(self._module, name, nested_module) 56 | return nested_module 57 | except ImportError: 58 | # If that fails, re-raise the original AttributeError 59 | raise AttributeError( 60 | f"Module '{self._module_name}' has no attribute '{name}'" 61 | ) 62 | 63 | # Create lazy loaders for dependencies used in dsparse 64 | instructor = LazyLoader("instructor") 65 | openai = LazyLoader("openai") 66 | anthropic = LazyLoader("anthropic") 67 | genai = LazyLoader("google.generativeai", "google-generativeai") 68 | vertexai = LazyLoader("vertexai") 69 | boto3 = LazyLoader("boto3") -------------------------------------------------------------------------------- /dsrag/metadata.py: -------------------------------------------------------------------------------- 1 | # Metadata storage handling 2 | import os 3 | from decimal import Decimal 4 | import json 5 | from typing import Any 6 | from abc import ABC, abstractmethod 7 | from dsrag.utils.imports import boto3 8 | 9 | class MetadataStorage(ABC): 10 | 11 | def __init__(self) -> None: 12 | pass 13 | 14 | @abstractmethod 15 | def load(self) -> dict: 16 | pass 17 | 18 | @abstractmethod 19 | def save(self, kb: dict) -> None: 20 | pass 21 | 22 | @abstractmethod 23 | def delete(self) -> None: 24 | pass 25 | 26 | class LocalMetadataStorage(MetadataStorage): 27 | 28 | def __init__(self, storage_directory: str) -> None: 29 | super().__init__() 30 | self.storage_directory = storage_directory 31 | 
32 | def get_metadata_path(self, kb_id: str) -> str: 33 | return os.path.join(self.storage_directory, "metadata", f"{kb_id}.json") 34 | 35 | def kb_exists(self, kb_id: str) -> bool: 36 | metadata_path = self.get_metadata_path(kb_id) 37 | return os.path.exists(metadata_path) 38 | 39 | def load(self, kb_id: str) -> dict: 40 | metadata_path = self.get_metadata_path(kb_id) 41 | with open(metadata_path, "r") as f: 42 | data = json.load(f) 43 | return data 44 | 45 | def save(self, full_data: dict, kb_id: str): 46 | metadata_path = self.get_metadata_path(kb_id) 47 | metadata_dir = os.path.join(self.storage_directory, "metadata") 48 | if not os.path.exists(metadata_dir): 49 | os.makedirs(metadata_dir) 50 | 51 | with open(metadata_path, "w") as f: 52 | json.dump(full_data, f, indent=4) 53 | 54 | def delete(self, kb_id: str): 55 | metadata_path = self.get_metadata_path(kb_id) 56 | os.remove(metadata_path) 57 | 58 | 59 | 60 | def convert_numbers_to_decimal(obj: Any) -> Any: 61 | """ 62 | Recursively traverse the object, converting all integers and floats to Decimals 63 | """ 64 | if isinstance(obj, dict): 65 | return {k: convert_numbers_to_decimal(v) for k, v in obj.items()} 66 | elif isinstance(obj, list): 67 | return [convert_numbers_to_decimal(item) for item in obj] 68 | elif isinstance(obj, bool): 69 | return obj # Leave booleans as is 70 | elif isinstance(obj, (int, float)): 71 | return Decimal(str(obj)) 72 | else: 73 | return obj 74 | 75 | 76 | def convert_decimal_to_numbers(obj: Any) -> Any: 77 | """ 78 | Recursively traverse the object, converting all Decimals to integers or floats 79 | """ 80 | if isinstance(obj, dict): 81 | return {k: convert_decimal_to_numbers(v) for k, v in obj.items()} 82 | elif isinstance(obj, list): 83 | return [convert_decimal_to_numbers(item) for item in obj] 84 | elif isinstance(obj, Decimal): 85 | if obj == Decimal('1') or obj == Decimal('0'): 86 | # Convert to bool 87 | return bool(obj) 88 | elif obj % 1 == 0: 89 | # Convert to int 90 | return int(obj) 91 | else: 92 | # Convert to float 93 | return float(obj) 94 | else: 95 | return obj 96 | 97 | 98 | class DynamoDBMetadataStorage(MetadataStorage): 99 | 100 | def __init__(self, table_name: str, region_name: str, access_key: str, secret_key: str) -> None: 101 | self.table_name = table_name 102 | self.region_name = region_name 103 | self.access_key = access_key 104 | self.secret_key = secret_key 105 | 106 | def create_dynamo_client(self): 107 | dynamodb_client = boto3.resource( 108 | 'dynamodb', 109 | region_name=self.region_name, 110 | aws_access_key_id=self.access_key, 111 | aws_secret_access_key=self.secret_key, 112 | ) 113 | return dynamodb_client 114 | 115 | def create_table(self): 116 | dynamodb_client = self.create_dynamo_client() 117 | try: 118 | dynamodb_client.create_table( 119 | TableName="metadata_storage", 120 | KeySchema=[ 121 | { 122 | 'AttributeName': 'kb_id', 123 | 'KeyType': 'HASH' # Partition key 124 | } 125 | ], 126 | AttributeDefinitions=[ 127 | { 128 | 'AttributeName': 'kb_id', 129 | 'AttributeType': 'S' # 'S' for String 130 | } 131 | ], 132 | BillingMode='PAY_PER_REQUEST' 133 | ) 134 | except Exception as e: 135 | # Probably the table already exists 136 | print (e) 137 | 138 | def kb_exists(self, kb_id: str) -> bool: 139 | dynamodb_client = self.create_dynamo_client() 140 | table = dynamodb_client.Table(self.table_name) 141 | response = table.get_item(Key={'kb_id': kb_id}) 142 | return 'Item' in response 143 | 144 | def load(self, kb_id: str) -> dict: 145 | # Read the data from the dynamo table 
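# The lookup below fetches the item keyed by kb_id, separates the stored
# metadata into its top-level fields and the "components" dict, and then runs
# convert_decimal_to_numbers() so the DynamoDB Decimal values written by
# save() come back as ints, floats, and booleans.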
146 | dynamodb_client = self.create_dynamo_client() 147 | table = dynamodb_client.Table(self.table_name) 148 | response = table.get_item(Key={'kb_id': kb_id}) 149 | data = response.get('Item', {}).get('metadata', {}) 150 | kb_metadata = { 151 | key: value for key, value in data.items() if key != "components" 152 | } 153 | components = data.get("components", {}) 154 | full_data = {**kb_metadata, "components": components} 155 | converted_data = convert_decimal_to_numbers(full_data) 156 | return converted_data 157 | 158 | def save(self, full_data: dict, kb_id: str) -> None: 159 | 160 | # Check if any of the items are a float or int, and convert them to Decimal 161 | converted_data = convert_numbers_to_decimal(full_data) 162 | 163 | # Upload this data to the dynamo table, where the kb_id is the primary key 164 | dynamodb_client = self.create_dynamo_client() 165 | table = dynamodb_client.Table(self.table_name) 166 | table.put_item(Item={'kb_id': kb_id, 'metadata': converted_data}) 167 | 168 | def delete(self, kb_id: str) -> None: 169 | dynamodb_client = self.create_dynamo_client() 170 | table = dynamodb_client.Table(self.table_name) 171 | table.delete_item(Key={'kb_id': kb_id}) 172 | 173 | -------------------------------------------------------------------------------- /dsrag/reranker.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | import os 3 | from dsrag.utils.imports import cohere, voyageai 4 | from scipy.stats import beta 5 | 6 | 7 | class Reranker(ABC): 8 | subclasses = {} 9 | 10 | def __init_subclass__(cls, **kwargs): 11 | super().__init_subclass__(**kwargs) 12 | cls.subclasses[cls.__name__] = cls 13 | 14 | def to_dict(self): 15 | return { 16 | 'subclass_name': self.__class__.__name__, 17 | } 18 | 19 | @classmethod 20 | def from_dict(cls, config): 21 | subclass_name = config.pop('subclass_name', None) # Remove subclass_name from config 22 | subclass = cls.subclasses.get(subclass_name) 23 | if subclass: 24 | return subclass(**config) # Pass the modified config without subclass_name 25 | else: 26 | raise ValueError(f"Unknown subclass: {subclass_name}") 27 | 28 | @abstractmethod 29 | def rerank_search_results(self, query: str, search_results: list) -> list: 30 | pass 31 | 32 | class CohereReranker(Reranker): 33 | def __init__(self, model: str = "rerank-english-v3.0"): 34 | self.model = model 35 | cohere_api_key = os.environ['CO_API_KEY'] 36 | base_url = os.environ.get("DSRAG_COHERE_BASE_URL", None) 37 | if base_url is not None: 38 | self.client = cohere.Client(api_key=cohere_api_key) 39 | else: 40 | self.client = cohere.Client(api_key=cohere_api_key) 41 | 42 | def transform(self, x): 43 | """ 44 | transformation function to map the absolute relevance value to a value that is more uniformly distributed between 0 and 1 45 | - this is critical for the new version of RSE to work properly, because it utilizes the absolute relevance values to calculate the similarity scores 46 | """ 47 | a, b = 0.4, 0.4 # These can be adjusted to change the distribution shape 48 | return beta.cdf(x, a, b) 49 | 50 | def rerank_search_results(self, query: str, search_results: list) -> list: 51 | """ 52 | Use Cohere Rerank API to rerank the search results 53 | """ 54 | documents = [] 55 | for result in search_results: 56 | documents.append(f"{result['metadata']['chunk_header']}\n\n{result['metadata']['chunk_text']}") 57 | 58 | reranked_results = self.client.rerank(model=self.model, query=query, documents=documents) 59 | results = 
reranked_results.results 60 | reranked_indices = [result.index for result in results] 61 | reranked_similarity_scores = [result.relevance_score for result in results] 62 | reranked_search_results = [search_results[i] for i in reranked_indices] 63 | for i, result in enumerate(reranked_search_results): 64 | result['similarity'] = self.transform(reranked_similarity_scores[i]) 65 | return reranked_search_results 66 | 67 | def to_dict(self): 68 | base_dict = super().to_dict() 69 | base_dict.update({ 70 | 'model': self.model 71 | }) 72 | return base_dict 73 | 74 | class VoyageReranker(Reranker): 75 | def __init__(self, model: str = "rerank-2"): 76 | self.model = model 77 | voyage_api_key = os.environ['VOYAGE_API_KEY'] 78 | self.client = voyageai.Client(api_key=voyage_api_key) 79 | 80 | def transform(self, x): 81 | """ 82 | transformation function to map the absolute relevance value to a value that is more uniformly distributed between 0 and 1 83 | - this is critical for the new version of RSE to work properly, because it utilizes the absolute relevance values to calculate the similarity scores 84 | """ 85 | a, b = 0.5, 1.8 # These can be adjusted to change the distribution shape 86 | return beta.cdf(x, a, b) 87 | 88 | def rerank_search_results(self, query: str, search_results: list) -> list: 89 | """ 90 | Use Voyage Rerank API to rerank the search results 91 | """ 92 | documents = [] 93 | for result in search_results: 94 | documents.append(f"{result['metadata']['chunk_header']}\n\n{result['metadata']['chunk_text']}") 95 | 96 | reranked_results = self.client.rerank(model=self.model, query=query, documents=documents) 97 | results = reranked_results.results 98 | reranked_indices = [result.index for result in results] 99 | reranked_similarity_scores = [result.relevance_score for result in results] 100 | reranked_search_results = [search_results[i] for i in reranked_indices] 101 | for i, result in enumerate(reranked_search_results): 102 | result['similarity'] = self.transform(reranked_similarity_scores[i]) 103 | return reranked_search_results 104 | 105 | def to_dict(self): 106 | base_dict = super().to_dict() 107 | base_dict.update({ 108 | 'model': self.model 109 | }) 110 | return base_dict 111 | 112 | class NoReranker(Reranker): 113 | def __init__(self, ignore_absolute_relevance: bool = False): 114 | """ 115 | - ignore_absolute_relevance: if True, the reranker will override the absolute relevance values and assign a default similarity score to each chunk. This is useful when using an embedding model where the absolute relevance values are not reliable or meaningful. 116 | """ 117 | self.ignore_absolute_relevance = ignore_absolute_relevance 118 | 119 | def rerank_search_results(self, query: str, search_results: list) -> list: 120 | if self.ignore_absolute_relevance: 121 | for result in search_results: 122 | result['similarity'] = 0.8 # default similarity score (represents a moderately relevant chunk) 123 | return search_results 124 | 125 | def to_dict(self): 126 | base_dict = super().to_dict() 127 | base_dict.update({ 128 | 'ignore_absolute_relevance': self.ignore_absolute_relevance, 129 | }) 130 | return base_dict -------------------------------------------------------------------------------- /dsrag/utils/imports.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for lazy imports of optional dependencies. 3 | """ 4 | import importlib 5 | 6 | class LazyLoader: 7 | """ 8 | Lazily import a module only when its attributes are accessed. 
9 | 10 | This allows optional dependencies to be imported only when actually used, 11 | rather than at module import time. 12 | 13 | Usage: 14 | # Instead of: import chromadb 15 | chromadb = LazyLoader("chromadb") 16 | 17 | # Then use chromadb as normal - it will only be imported when accessed 18 | # If the module is not installed, a helpful error message is shown 19 | """ 20 | def __init__(self, module_name, package_name=None): 21 | """ 22 | Initialize a lazy loader for a module. 23 | 24 | Args: 25 | module_name: The name of the module to import 26 | package_name: Optional package name for pip install instructions 27 | (defaults to module_name if not provided) 28 | """ 29 | self._module_name = module_name 30 | self._package_name = package_name or module_name 31 | self._module = None 32 | 33 | def __getattr__(self, name): 34 | """Called when an attribute is accessed.""" 35 | if self._module is None: 36 | try: 37 | # Use importlib.import_module instead of __import__ for better handling of nested modules 38 | self._module = importlib.import_module(self._module_name) 39 | except ImportError: 40 | raise ImportError( 41 | f"The '{self._module_name}' module is required but not installed. " 42 | f"Please install it with: pip install {self._package_name} " 43 | f"or pip install dsrag[{self._package_name}]" 44 | ) 45 | 46 | # Try to get the attribute from the module 47 | try: 48 | return getattr(self._module, name) 49 | except AttributeError: 50 | # If the attribute is not found, it might be a nested module 51 | try: 52 | # Try to import the nested module 53 | nested_module = importlib.import_module(f"{self._module_name}.{name}") 54 | # Cache it on the module for future access 55 | setattr(self._module, name, nested_module) 56 | return nested_module 57 | except ImportError: 58 | # If that fails, re-raise the original AttributeError 59 | raise AttributeError( 60 | f"Module '{self._module_name}' has no attribute '{name}'" 61 | ) 62 | 63 | # Create lazy loaders for commonly used optional dependencies 64 | instructor = LazyLoader("instructor") 65 | openai = LazyLoader("openai") 66 | cohere = LazyLoader("cohere") 67 | voyageai = LazyLoader("voyageai") 68 | ollama = LazyLoader("ollama") 69 | anthropic = LazyLoader("anthropic") 70 | genai = LazyLoader("google.generativeai", "google-generativeai") 71 | genai_new = LazyLoader("google.genai", "google-genai") 72 | boto3 = LazyLoader("boto3") 73 | faiss = LazyLoader("faiss", "faiss-cpu") 74 | psycopg2 = LazyLoader("psycopg2", "psycopg2-binary") 75 | pgvector = LazyLoader("pgvector") -------------------------------------------------------------------------------- /eval/financebench/dsrag_financebench_create_kb.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | # add ../../ to sys.path 5 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))) 6 | 7 | # import the necessary modules 8 | from dsrag.knowledge_base import KnowledgeBase 9 | from dsrag.embedding import CohereEmbedding 10 | 11 | document_directory = "" # path to the directory containing the documents, links to which are contained in the `tests/data/financebench_sample_150.csv` file 12 | kb_id = "finance_bench" 13 | 14 | """ 15 | Note: to run this, you'll need to first download all of the documents (~80 of them, mostly 10-Ks and 10-Qs) and convert them to txt files. You'll want to make sure they all have consistent and meaningful file names, such as costco_2023_10k.txt, to match the performance we quote. 
16 | """ 17 | 18 | # create a new KnowledgeBase object 19 | embedding_model = CohereEmbedding() 20 | kb = KnowledgeBase(kb_id, embedding_model=embedding_model, exists_ok=False) 21 | 22 | for file_name in os.listdir(document_directory): 23 | with open(os.path.join(document_directory, file_name), "r") as f: 24 | text = f.read() 25 | kb.add_document(doc_id=file_name, text=text) 26 | print (f"Added {file_name} to the knowledge base.") -------------------------------------------------------------------------------- /eval/financebench/dsrag_financebench_eval.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import os 3 | import sys 4 | 5 | # add ../../ to sys.path 6 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))) 7 | 8 | # import the necessary modules 9 | from dsrag.knowledge_base import KnowledgeBase 10 | from dsrag.auto_query import get_search_queries 11 | from dsrag.llm import OpenAIChatAPI 12 | 13 | AUTO_QUERY_GUIDANCE = """ 14 | The knowledge base contains SEC filings for publicly traded companies, like 10-Ks, 10-Qs, and 8-Ks. Keep this in mind when generating search queries. The things you search for should be things that are likely to be found in these documents. 15 | 16 | When deciding what to search for, first consider the pieces of information that will be needed to answer the question. Then, consider what to search for to find those pieces of information. For example, if the question asks what the change in revenue was from 2019 to 2020, you would want to search for the 2019 and 2020 revenue numbers in two separate search queries, since those are the two separate pieces of information needed. You should also think about where you are most likely to find the information you're looking for. If you're looking for assets and liabilities, you may want to search for the balance sheet, for example. 17 | """.strip() 18 | 19 | RESPONSE_SYSTEM_MESSAGE = """ 20 | You are a response generation system. Please generate a response to the user input based on the provided context. Your response should be as concise as possible while still fully answering the user's question. 
21 | 22 | CONTEXT 23 | {context} 24 | """.strip() 25 | 26 | def get_response(question: str, context: str): 27 | client = OpenAIChatAPI(model="gpt-4-turbo", temperature=0.0) 28 | chat_messages = [{"role": "system", "content": RESPONSE_SYSTEM_MESSAGE.format(context=context)}, {"role": "user", "content": question}] 29 | return client.make_llm_call(chat_messages) 30 | 31 | # Get the absolute path of the script file 32 | script_dir = os.path.dirname(os.path.abspath(__file__)) 33 | 34 | # Construct the absolute path to the dataset file 35 | dataset_file_path = os.path.join(script_dir, "../../tests/data/financebench_sample_150.csv") 36 | 37 | # Read in the data 38 | df = pd.read_csv(dataset_file_path) 39 | 40 | # get the questions and answers from the question column - turn into lists 41 | questions = df.question.tolist() 42 | answers = df.answer.tolist() 43 | 44 | # load the knowledge base 45 | kb = KnowledgeBase("finance_bench") 46 | 47 | # set parameters for relevant segment extraction - slightly longer context than default 48 | rse_params = { 49 | 'max_length': 12, 50 | 'overall_max_length': 30, 51 | 'overall_max_length_extension': 6, 52 | } 53 | 54 | # adjust range if you only want to run a subset of the questions (there are 150 total) 55 | for i in range(150): 56 | print (f"Question {i+1}") 57 | question = questions[i] 58 | answer = answers[i] 59 | search_queries = get_search_queries(question, max_queries=6, auto_query_guidance=AUTO_QUERY_GUIDANCE) 60 | relevant_segments = kb.query(search_queries) 61 | context = "\n\n".join([segment['text'] for segment in relevant_segments]) 62 | response = get_response(question, context) 63 | print (f"\nQuestion: {question}") 64 | print (f"\nSearch queries: {search_queries}") 65 | print (f"\nModel response: {response}") 66 | print (f"\nGround truth answer: {answer}") 67 | print ("\n---\n") -------------------------------------------------------------------------------- /eval/rse_eval.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../'))) 5 | 6 | from dsrag.create_kb import create_kb_from_file 7 | from dsrag.knowledge_base import KnowledgeBase 8 | from dsrag.reranker import NoReranker, CohereReranker, VoyageReranker 9 | from dsrag.embedding import OpenAIEmbedding, CohereEmbedding 10 | from dsrag.dsparse.file_parsing.non_vlm_file_parsing import extract_text_from_pdf 11 | 12 | 13 | def test_create_kb_from_file(): 14 | cleanup() # delete the KnowledgeBase object if it exists so we can start fresh 15 | 16 | # Get the absolute path of the script file 17 | script_dir = os.path.dirname(os.path.abspath(__file__)) 18 | 19 | # Construct the absolute path to the dataset file 20 | file_path = os.path.join(script_dir, "../tests/data/levels_of_agi.pdf") 21 | 22 | kb_id = "levels_of_agi" 23 | #kb = create_kb_from_file(kb_id, file_path) 24 | kb = KnowledgeBase(kb_id=kb_id, exists_ok=False) 25 | text = extract_text_from_pdf(file_path) 26 | 27 | auto_context_config = { 28 | 'use_generated_title': True, 29 | 'get_document_summary': True, 30 | 'get_section_summaries': True, 31 | } 32 | 33 | kb.add_document(file_path, text, auto_context_config=auto_context_config) 34 | 35 | def cleanup(): 36 | kb = KnowledgeBase(kb_id="levels_of_agi", exists_ok=True) 37 | kb.delete() 38 | 39 | """ 40 | This script is for doing qualitative evaluation of RSE 41 | """ 42 | 43 | 44 | if __name__ == "__main__": 45 | # create the KnowledgeBase 46 | try: 
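# If creating the knowledge base raises a ValueError (for example because a
# KB with this ID already exists and exists_ok=False), the error is printed
# and the script simply loads the existing KB with exists_ok=True further below.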
47 | test_create_kb_from_file() 48 | except ValueError as e: 49 | print(e) 50 | 51 | rse_params = { 52 | 'max_length': 20, 53 | 'overall_max_length': 30, 54 | 'minimum_value': 0.5, 55 | 'irrelevant_chunk_penalty': 0.2, 56 | 'overall_max_length_extension': 5, 57 | 'decay_rate': 30, 58 | 'top_k_for_document_selection': 7, 59 | } 60 | 61 | #reranker = NoReranker(ignore_absolute_relevance=True) 62 | reranker = CohereReranker() 63 | #reranker = VoyageReranker() 64 | 65 | # load the KnowledgeBase and query it 66 | kb = KnowledgeBase(kb_id="levels_of_agi", exists_ok=True, reranker=reranker) 67 | #print (kb.chunk_db.get_all_doc_ids()) 68 | 69 | #search_queries = ["What are the levels of AGI?"] 70 | #search_queries = ["Who is the president of the United States?"] 71 | #search_queries = ["AI"] 72 | #search_queries = ["What is the difference between AGI and ASI?"] 73 | #search_queries = ["How does autonomy factor into AGI?"] 74 | #search_queries = ["Self-driving cars"] 75 | search_queries = ["Methodology for determining levels of AGI"] 76 | #search_queries = ["Principles for defining levels of AGI"] 77 | #search_queries = ["What is Autonomy Level 3"] 78 | #search_queries = ["Use of existing AI benchmarks like Big-bench and HELM"] 79 | #search_queries = ["Introduction", "Conclusion"] 80 | #search_queries = ["References"] 81 | 82 | relevant_segments = kb.query(search_queries=search_queries, rse_params=rse_params) 83 | 84 | print () 85 | for segment in relevant_segments: 86 | #print (len(segment["text"])) 87 | #print (segment["score"]) 88 | print (segment["text"]) 89 | print ("---\n") -------------------------------------------------------------------------------- /eval/vlm_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from typing import Dict, List 4 | import PIL.Image 5 | import google.generativeai as genai 6 | 7 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../"))) 8 | from dsrag.knowledge_base import KnowledgeBase 9 | 10 | 11 | def create_kb_and_add_document(kb_id: str, file_path: str, document_title: str): 12 | kb = KnowledgeBase(kb_id=kb_id, exists_ok=True) 13 | kb.delete() 14 | 15 | kb = KnowledgeBase(kb_id=kb_id, exists_ok=False) 16 | kb.add_document( 17 | doc_id="test_doc_1", 18 | file_path=file_path, 19 | document_title=document_title, 20 | file_parsing_config={ 21 | "use_vlm": True, 22 | "vlm_config": { 23 | "provider": "gemini", 24 | "model": "gemini-1.5-flash-002", 25 | } 26 | } 27 | ) 28 | 29 | return kb 30 | 31 | def get_response(user_input: str, search_results: List[Dict], model_name: str = "gemini-1.5-pro-002"): 32 | """ 33 | Given a user input and potentially multimodal search results, generate a response. 34 | 35 | We'll use the Gemini API since we already have code for it. 36 | """ 37 | genai.configure(api_key=os.environ["GEMINI_API_KEY"]) 38 | generation_config = { 39 | "temperature": 0.0, 40 | "max_output_tokens": 2000 41 | } 42 | 43 | model = genai.GenerativeModel(model_name) 44 | 45 | all_image_paths = [] 46 | for search_result in search_results: 47 | if type(search_result["content"]) == list: 48 | all_image_paths += search_result["content"] 49 | 50 | images = [PIL.Image.open(image_path) for image_path in all_image_paths] 51 | system_message = "Please cite the page number of the document used to answer the question." 
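# Assemble the multimodal request: the retrieved page images come first,
# followed by the user's question and the citation instruction, and the full
# list is passed as `contents` to generate_content(). The PIL images are
# closed after the call so no file handles are leaked.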
52 | generation_input = images + [user_input, system_message] # user input is the last element 53 | response = model.generate_content( 54 | contents=generation_input, 55 | generation_config=generation_config 56 | ) 57 | [image.close() for image in images] 58 | return response.text 59 | 60 | 61 | kb_id = "state_of_ai_flash" 62 | 63 | """ 64 | # create or load KB 65 | file_path = "/Users/zach/Code/test_docs/state_of_ai_report_2024.pdf" 66 | document_title = "State of AI Report 2024" 67 | kb = create_kb_and_add_document(kb_id=kb_id, file_path=file_path, document_title=document_title) 68 | 69 | """ 70 | 71 | kb = KnowledgeBase(kb_id=kb_id, exists_ok=True) 72 | 73 | #query = "What is the McKinsey Energy Report about?" 74 | #query = "How does the oil and gas industry compare to other industries in terms of its value prop to employees?" 75 | 76 | query = "Who is Nathan Benaich?" # page 2 77 | #query = "Which country had the most AI publications in 2024?" # page 84 78 | #query = "Did frontier labs increase or decrease publications in 2024? By how much?" # page 84 79 | #query = "Who has the most H100s? How many do they have?" # page 92 80 | #query = "How many H100s does Lambda have?" # page 92 81 | 82 | rse_params = { 83 | "minimum_value": 0.5, 84 | "irrelevant_chunk_penalty": 0.2, 85 | } 86 | 87 | search_results = kb.query(search_queries=[query], rse_params=rse_params, return_mode="page_images") 88 | print (search_results) 89 | 90 | response = get_response(user_input=query, search_results=search_results) 91 | print(response) -------------------------------------------------------------------------------- /examples/example_kb_data/chunk_storage/levels_of_agi.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-Star-AI/dsRAG/b326c1d0930c9a80ff486ab0df5e771e788e3234/examples/example_kb_data/chunk_storage/levels_of_agi.pkl -------------------------------------------------------------------------------- /examples/example_kb_data/chunk_storage/nike_10k.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-Star-AI/dsRAG/b326c1d0930c9a80ff486ab0df5e771e788e3234/examples/example_kb_data/chunk_storage/nike_10k.pkl -------------------------------------------------------------------------------- /examples/example_kb_data/metadata/levels_of_agi.json: -------------------------------------------------------------------------------- 1 | { 2 | "title": "", 3 | "description": "", 4 | "language": "en", 5 | "chunk_size": 800, 6 | "components": { 7 | "embedding_model": { 8 | "subclass_name": "OpenAIEmbedding", 9 | "dimension": 768, 10 | "model": "text-embedding-3-small" 11 | }, 12 | "reranker": { 13 | "subclass_name": "CohereReranker", 14 | "model": "rerank-english-v3.0" 15 | }, 16 | "auto_context_model": { 17 | "subclass_name": "AnthropicChatAPI", 18 | "model": "claude-3-haiku-20240307", 19 | "temperature": 0.2, 20 | "max_tokens": 1000 21 | }, 22 | "vector_db": { 23 | "subclass_name": "BasicVectorDB", 24 | "kb_id": "levels_of_agi", 25 | "storage_directory": "example_kb_data", 26 | "use_faiss": true 27 | }, 28 | "chunk_db": { 29 | "subclass_name": "BasicChunkDB", 30 | "kb_id": "levels_of_agi", 31 | "storage_directory": "example_kb_data" 32 | } 33 | } 34 | } -------------------------------------------------------------------------------- /examples/example_kb_data/metadata/nike_10k.json: -------------------------------------------------------------------------------- 1 | { 2 | "title": 
"", 3 | "description": "", 4 | "language": "en", 5 | "chunk_size": 800, 6 | "components": { 7 | "embedding_model": { 8 | "subclass_name": "OpenAIEmbedding", 9 | "dimension": 768, 10 | "model": "text-embedding-3-small" 11 | }, 12 | "reranker": { 13 | "subclass_name": "CohereReranker", 14 | "model": "rerank-english-v3.0" 15 | }, 16 | "auto_context_model": { 17 | "subclass_name": "AnthropicChatAPI", 18 | "model": "claude-3-haiku-20240307", 19 | "temperature": 0.2, 20 | "max_tokens": 1000 21 | }, 22 | "vector_db": { 23 | "subclass_name": "BasicVectorDB", 24 | "kb_id": "nike_10k", 25 | "storage_directory": "example_kb_data", 26 | "use_faiss": true 27 | }, 28 | "chunk_db": { 29 | "subclass_name": "BasicChunkDB", 30 | "kb_id": "nike_10k", 31 | "storage_directory": "example_kb_data" 32 | }, 33 | "file_system": { 34 | "subclass_name": "LocalFileSystem", 35 | "base_path": "example_kb_data" 36 | } 37 | } 38 | } -------------------------------------------------------------------------------- /examples/example_kb_data/vector_storage/levels_of_agi.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-Star-AI/dsRAG/b326c1d0930c9a80ff486ab0df5e771e788e3234/examples/example_kb_data/vector_storage/levels_of_agi.pkl -------------------------------------------------------------------------------- /examples/example_kb_data/vector_storage/nike_10k.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-Star-AI/dsRAG/b326c1d0930c9a80ff486ab0df5e771e788e3234/examples/example_kb_data/vector_storage/nike_10k.pkl -------------------------------------------------------------------------------- /integrations/langchain_retriever.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Union 2 | 3 | from langchain_core.documents import Document 4 | from langchain_core.retrievers import BaseRetriever 5 | from langchain_core.callbacks import CallbackManagerForRetrieverRun 6 | 7 | from dsrag.knowledge_base import KnowledgeBase 8 | 9 | class DsRAGLangchainRetriever(BaseRetriever): 10 | """ 11 | A retriever for the DsRAG project that uses the KnowledgeBase class to retrieve documents. 12 | """ 13 | 14 | kb_id: str = "" 15 | rse_params: Union[Dict, str] = "balanced" 16 | 17 | def __init__(self, kb_id: str, rse_params: Union[Dict, str] = "balanced"): 18 | super().__init__() 19 | self.kb_id = kb_id 20 | self.rse_params = rse_params 21 | 22 | 23 | """ 24 | Initializes a DsRAGLangchainRetriever instance. 25 | 26 | Args: 27 | kb_id: A KnowledgeBase id to use for retrieving documents. 
28 | rse_params: The RSE parameters to use for querying 29 | """ 30 | 31 | 32 | def _get_relevant_documents(self, query: str, *, run_manager: CallbackManagerForRetrieverRun) -> List[Document]: 33 | 34 | kb = KnowledgeBase(kb_id=self.kb_id, exists_ok=True) 35 | 36 | segment_info = kb.query([query], rse_params=self.rse_params) 37 | # Create a Document object for each segment 38 | documents = [] 39 | for segment in segment_info: 40 | document = Document( 41 | page_content=segment["text"], 42 | metadata={"doc_id": segment["doc_id"], "chunk_start": segment["chunk_start"], "chunk_end": segment["chunk_end"]}, 43 | ) 44 | documents.append(document) 45 | 46 | return documents 47 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: dsRAG 2 | site_description: Documentation for the dsRAG library 3 | site_url: https://d-star-ai.github.io/dsRAG/ 4 | theme: 5 | name: material 6 | logo: logo.png 7 | favicon: logo.png 8 | features: 9 | - navigation.tabs 10 | - navigation.sections 11 | - toc.integrate 12 | - search.suggest 13 | - search.highlight 14 | - content.tabs.link 15 | - content.code.annotation 16 | - content.code.copy 17 | language: en 18 | palette: 19 | - scheme: default 20 | toggle: 21 | icon: material/toggle-switch-off-outline 22 | name: Switch to dark mode 23 | primary: teal 24 | accent: purple 25 | - scheme: slate 26 | toggle: 27 | icon: material/toggle-switch 28 | name: Switch to light mode 29 | primary: teal 30 | accent: lime 31 | 32 | nav: 33 | - Home: index.md 34 | - Getting Started: 35 | - Installation: getting-started/installation.md 36 | - Quick Start: getting-started/quickstart.md 37 | - Concepts: 38 | - Overview: concepts/overview.md 39 | - Knowledge Bases: concepts/knowledge-bases.md 40 | - Components: concepts/components.md 41 | - Configuration: concepts/config.md 42 | - Chat: concepts/chat.md 43 | - Citations: concepts/citations.md 44 | - VLM File Parsing: concepts/vlm.md 45 | - Logging: concepts/logging.md 46 | - API: 47 | - KnowledgeBase: api/knowledge_base.md 48 | - Chat: api/chat.md 49 | - Community: 50 | - Contributing: community/contributing.md 51 | - Support: community/support.md 52 | - Professional Services: community/professional-services.md 53 | 54 | markdown_extensions: 55 | - pymdownx.highlight: 56 | anchor_linenums: true 57 | - pymdownx.inlinehilite 58 | - pymdownx.snippets 59 | - admonition 60 | - pymdownx.arithmatex: 61 | generic: true 62 | - footnotes 63 | - pymdownx.details 64 | - pymdownx.superfences 65 | - pymdownx.mark 66 | - attr_list 67 | - pymdownx.emoji: 68 | emoji_index: !!python/name:material.extensions.emoji.twemoji 69 | emoji_generator: !!python/name:material.extensions.emoji.to_svg 70 | 71 | plugins: 72 | - search 73 | - mkdocstrings: 74 | handlers: 75 | python: 76 | paths: [dsrag] 77 | options: 78 | docstring_style: google 79 | show_source: true 80 | show_root_heading: true 81 | show_root_full_path: true 82 | show_object_full_path: false 83 | show_category_heading: true 84 | show_if_no_docstring: false 85 | show_signature_annotations: true 86 | separate_signature: true 87 | line_length: 80 88 | merge_init_into_class: true 89 | docstring_section_style: spacy 90 | 91 | copyright: | 92 | © 2025 dsRAG Contributors -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | 
requires = ["setuptools>=61.0.0", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "dsrag" 7 | version = "0.5.3" 8 | description = "State-of-the-art RAG pipeline from D-Star AI" 9 | readme = "README.md" 10 | authors = [{ name = "Zach McCormick", email = "zach@d-star.ai" }, { name = "Nick McCormick", email = "nick@d-star.ai" }] 11 | license = "MIT" 12 | classifiers=[ 13 | "Development Status :: 4 - Beta", 14 | "Intended Audience :: Developers", 15 | "Intended Audience :: Science/Research", 16 | "Operating System :: OS Independent", 17 | "Programming Language :: Python :: 3.9", 18 | "Programming Language :: Python :: 3.10", 19 | "Programming Language :: Python :: 3.11", 20 | "Programming Language :: Python :: 3.12", 21 | "Topic :: Database", 22 | "Topic :: Software Development", 23 | "Topic :: Software Development :: Libraries", 24 | "Topic :: Software Development :: Libraries :: Application Frameworks", 25 | "Topic :: Software Development :: Libraries :: Python Modules" 26 | ] 27 | requires-python = ">=3.9" 28 | # Core dependencies are included directly (matching requirements.in) 29 | dependencies = [ 30 | "langchain-text-splitters>=0.3.0", 31 | "langchain-core>=0.3.0", 32 | "pydantic>=2.8.2", 33 | "numpy>=1.20.0", 34 | "pandas>=2.0.0", 35 | "tiktoken>=0.5.0", 36 | "tqdm>=4.65.0", 37 | "requests>=2.30.0", 38 | "python-dotenv>=1.0.0", 39 | "typing-extensions>=4.5.0", 40 | "instructor>=1.7.0", 41 | "scipy>=1.10.0", 42 | "scikit-learn>=1.2.0", 43 | "pillow>=10.0.0", 44 | "pdf2image>=1.16.0", 45 | "docx2txt>=0.8", 46 | "PyPDF2>=3.0.1" 47 | ] 48 | 49 | [project.urls] 50 | Homepage = "https://github.com/D-Star-AI/dsRAG" 51 | Documentation = "https://github.com/D-Star-AI/dsRAG" 52 | Contact = "https://github.com/D-Star-AI/dsRAG" 53 | 54 | [project.optional-dependencies] 55 | # Database optional dependencies 56 | faiss = ["faiss-cpu>=1.8.0"] 57 | chroma = ["chromadb>=0.5.5"] 58 | weaviate = ["weaviate-client>=4.6.0"] 59 | qdrant = ["qdrant-client>=1.8.0"] 60 | milvus = ["pymilvus>=2.3.5"] 61 | pinecone = ["pinecone>=3.0.0"] 62 | postgres = ["psycopg2-binary>=2.9.0", "pgvector>=0.2.0"] 63 | boto3 = ["boto3>=1.28.0"] 64 | 65 | # LLM/embedding/reranker optional dependencies 66 | openai = ["openai>=1.52.2"] 67 | cohere = ["cohere>=4.0.0"] 68 | voyageai = ["voyageai>=0.1.0"] 69 | ollama = ["ollama>=0.1.0"] 70 | anthropic = ["anthropic>=0.37.1"] 71 | google-generativeai = ["google-generativeai>=0.8.3"] 72 | 73 | # Convenience groups 74 | all-dbs = [ 75 | "dsrag[faiss,chroma,weaviate,qdrant,milvus,pinecone,postgres,boto3]" 76 | ] 77 | 78 | all-models = [ 79 | "dsrag[openai,cohere,voyageai,ollama,anthropic,google-generativeai]" 80 | ] 81 | 82 | # Complete installation with all optional dependencies 83 | all = [ 84 | "dsrag[all-dbs,all-models]" 85 | ] 86 | 87 | [tool.setuptools.packages.find] 88 | where = ["."] 89 | include = ["dsrag", "dsrag.*"] 90 | exclude = ["dsrag.dsparse.tests", "dsrag.dsparse.tests.*", "dsrag.dsparse.dist", "dsrag.dsparse.dist.*"] -------------------------------------------------------------------------------- /requirements.in: -------------------------------------------------------------------------------- 1 | # Core dependencies for dsRAG 2 | # Generate requirements.txt with: pip-compile requirements.in 3 | 4 | # Core functionality - only include direct dependencies here, not transitive ones 5 | langchain-text-splitters>=0.3.0 6 | langchain-core>=0.3.0 # Updated to newer version for tenacity compatibility 7 | pydantic>=2.8.2 # Version compatible 
with all dependencies 8 | numpy>=1.20.0 9 | pandas>=2.0.0 10 | tiktoken>=0.5.0 11 | tqdm>=4.65.0 12 | requests>=2.30.0 13 | python-dotenv>=1.0.0 14 | typing-extensions>=4.5.0 15 | instructor>=1.7.0 # Required for streaming functionality 16 | # tenacity and jiter are unpinned to allow instructor to select compatible versions 17 | scipy>=1.10.0 18 | scikit-learn>=1.2.0 19 | pillow>=10.0.0 20 | pdf2image>=1.16.0 21 | docx2txt>=0.8 22 | PyPDF2>=3.0.1 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # 2 | # This file is autogenerated by pip-compile with Python 3.12 3 | # by the following command: 4 | # 5 | # pip-compile requirements.in 6 | # 7 | aiohappyeyeballs==2.4.6 8 | # via aiohttp 9 | aiohttp==3.11.13 10 | # via instructor 11 | aiosignal==1.3.1 12 | # via aiohttp 13 | annotated-types==0.7.0 14 | # via pydantic 15 | anyio==4.8.0 16 | # via 17 | # httpx 18 | # openai 19 | attrs==23.2.0 20 | # via aiohttp 21 | certifi==2025.1.31 22 | # via 23 | # httpcore 24 | # httpx 25 | # requests 26 | charset-normalizer==3.4.1 27 | # via requests 28 | click==8.1.7 29 | # via typer 30 | distro==1.9.0 31 | # via openai 32 | docstring-parser==0.16 33 | # via instructor 34 | docx2txt==0.8 35 | # via -r requirements.in 36 | frozenlist==1.5.0 37 | # via 38 | # aiohttp 39 | # aiosignal 40 | h11==0.14.0 41 | # via httpcore 42 | httpcore==1.0.7 43 | # via httpx 44 | httpx==0.27.0 45 | # via 46 | # langsmith 47 | # openai 48 | idna==3.10 49 | # via 50 | # anyio 51 | # httpx 52 | # requests 53 | # yarl 54 | instructor==1.7.3 55 | # via -r requirements.in 56 | jinja2==3.1.6 57 | # via instructor 58 | jiter==0.8.2 59 | # via 60 | # instructor 61 | # openai 62 | joblib==1.4.2 63 | # via scikit-learn 64 | jsonpatch==1.33 65 | # via langchain-core 66 | jsonpointer==3.0.0 67 | # via jsonpatch 68 | langchain-core==0.3.42 69 | # via 70 | # -r requirements.in 71 | # langchain-text-splitters 72 | langchain-text-splitters==0.3.6 73 | # via -r requirements.in 74 | langsmith==0.1.147 75 | # via langchain-core 76 | markdown-it-py==3.0.0 77 | # via rich 78 | markupsafe==3.0.2 79 | # via jinja2 80 | mdurl==0.1.2 81 | # via markdown-it-py 82 | multidict==6.1.0 83 | # via 84 | # aiohttp 85 | # yarl 86 | numpy==1.26.4 87 | # via 88 | # -r requirements.in 89 | # pandas 90 | # scikit-learn 91 | # scipy 92 | openai==1.64.0 93 | # via instructor 94 | orjson==3.10.15 95 | # via langsmith 96 | packaging==24.1 97 | # via langchain-core 98 | pandas==2.2.3 99 | # via -r requirements.in 100 | pdf2image==1.16.0 101 | # via -r requirements.in 102 | pillow==10.0.0 103 | # via 104 | # -r requirements.in 105 | # pdf2image 106 | propcache==0.3.0 107 | # via 108 | # aiohttp 109 | # yarl 110 | pydantic==2.8.2 111 | # via 112 | # -r requirements.in 113 | # instructor 114 | # langchain-core 115 | # langsmith 116 | # openai 117 | pydantic-core==2.20.1 118 | # via 119 | # instructor 120 | # pydantic 121 | pygments==2.18.0 122 | # via rich 123 | pypdf2==3.0.1 124 | # via -r requirements.in 125 | python-dateutil==2.9.0.post0 126 | # via pandas 127 | python-dotenv==1.0.1 128 | # via -r requirements.in 129 | pytz==2025.1 130 | # via pandas 131 | pyyaml==6.0.1 132 | # via langchain-core 133 | regex==2024.5.15 134 | # via tiktoken 135 | requests==2.32.3 136 | # via 137 | # -r requirements.in 138 | # instructor 139 | # langsmith 140 | # requests-toolbelt 141 | # tiktoken 142 | requests-toolbelt==1.0.0 143 | # via langsmith 144 | 
rich==13.7.1 145 | # via 146 | # instructor 147 | # typer 148 | scikit-learn==1.6.1 149 | # via -r requirements.in 150 | scipy==1.15.2 151 | # via 152 | # -r requirements.in 153 | # scikit-learn 154 | shellingham==1.5.4 155 | # via typer 156 | six==1.16.0 157 | # via python-dateutil 158 | sniffio==1.3.1 159 | # via 160 | # anyio 161 | # httpx 162 | # openai 163 | tenacity==9.0.0 164 | # via 165 | # instructor 166 | # langchain-core 167 | threadpoolctl==3.5.0 168 | # via scikit-learn 169 | tiktoken==0.9.0 170 | # via -r requirements.in 171 | tqdm==4.67.1 172 | # via 173 | # -r requirements.in 174 | # openai 175 | typer==0.12.3 176 | # via instructor 177 | typing-extensions==4.12.2 178 | # via 179 | # -r requirements.in 180 | # anyio 181 | # langchain-core 182 | # openai 183 | # pydantic 184 | # pydantic-core 185 | # typer 186 | tzdata==2024.1 187 | # via pandas 188 | urllib3==2.3.0 189 | # via requests 190 | yarl==1.18.3 191 | # via aiohttp 192 | -------------------------------------------------------------------------------- /tests/data/levels_of_agi.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-Star-AI/dsRAG/b326c1d0930c9a80ff486ab0df5e771e788e3234/tests/data/levels_of_agi.pdf -------------------------------------------------------------------------------- /tests/data/mck_energy_first_5_pages.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-Star-AI/dsRAG/b326c1d0930c9a80ff486ab0df5e771e788e3234/tests/data/mck_energy_first_5_pages.pdf -------------------------------------------------------------------------------- /tests/data/page_7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-Star-AI/dsRAG/b326c1d0930c9a80ff486ab0df5e771e788e3234/tests/data/page_7.png -------------------------------------------------------------------------------- /tests/integration/test_chunk_and_segment_headers.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import unittest 4 | 5 | # add ../../ to sys.path 6 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))) 7 | 8 | from dsrag.dsparse.file_parsing.non_vlm_file_parsing import extract_text_from_pdf 9 | from dsrag.knowledge_base import KnowledgeBase 10 | 11 | 12 | class TestChunkAndSegmentHeaders(unittest.TestCase): 13 | def test_all_context_included(self): 14 | self.cleanup() # delete the KnowledgeBase object if it exists so we can start fresh 15 | 16 | # Get the absolute path of the script file 17 | script_dir = os.path.dirname(os.path.abspath(__file__)) 18 | 19 | # Construct the absolute path to the dataset file 20 | file_path = os.path.join(script_dir, "../data/levels_of_agi.pdf") 21 | 22 | # extract text from the pdf 23 | text, _ = extract_text_from_pdf(file_path) 24 | 25 | kb_id = "levels_of_agi" 26 | kb = KnowledgeBase(kb_id=kb_id, exists_ok=False) 27 | 28 | # set configuration for auto context and semantic sectioning and add the document to the KnowledgeBase 29 | auto_context_config = { 30 | 'use_generated_title': True, 31 | 'get_document_summary': True, 32 | 'get_section_summaries': True, 33 | } 34 | semantic_sectioning_config = { 35 | 'use_semantic_sectioning': True, 36 | } 37 | kb.add_document(doc_id="doc_1", text=text, auto_context_config=auto_context_config, semantic_sectioning_config=semantic_sectioning_config) 38 | 39 | # verify that 
the chunk headers are as expected by running searches and looking at the chunk_header in the results 40 | search_query = "What are the Levels of AGI and why were they defined this way?" 41 | search_results = kb._search(search_query, top_k=1) 42 | chunk_header = search_results[0]['metadata']['chunk_header'] 43 | print (f"\nCHUNK HEADER\n{chunk_header}") 44 | 45 | # verify that the segment headers are as expected 46 | segment_header = kb._get_segment_header(doc_id="doc_1", chunk_index=0) 47 | 48 | assert "Section context" in chunk_header 49 | assert "AGI" in segment_header 50 | 51 | # delete the KnowledgeBase object 52 | kb.delete() 53 | 54 | def test_no_context_included(self): 55 | self.cleanup() # delete the KnowledgeBase object if it exists so we can start fresh 56 | 57 | # Get the absolute path of the script file 58 | script_dir = os.path.dirname(os.path.abspath(__file__)) 59 | 60 | # Construct the absolute path to the dataset file 61 | file_path = os.path.join(script_dir, "../data/levels_of_agi.pdf") 62 | 63 | # extract text from the pdf 64 | text, _ = extract_text_from_pdf(file_path) 65 | 66 | kb_id = "levels_of_agi" 67 | kb = KnowledgeBase(kb_id=kb_id, exists_ok=False) 68 | 69 | # set configuration for auto context and semantic sectioning and add the document to the KnowledgeBase 70 | auto_context_config = { 71 | 'use_generated_title': False, 72 | 'get_document_summary': False, 73 | 'get_section_summaries': False, 74 | } 75 | semantic_sectioning_config = { 76 | 'use_semantic_sectioning': False, 77 | } 78 | kb.add_document(doc_id="doc_1", text=text, auto_context_config=auto_context_config, semantic_sectioning_config=semantic_sectioning_config) 79 | 80 | # verify that the chunk headers are as expected by running searches and looking at the chunk_header in the results 81 | search_query = "What are the Levels of AGI and why were they defined this way?" 
82 | search_results = kb._search(search_query, top_k=1) 83 | chunk_header = search_results[0]['metadata']['chunk_header'] 84 | print (f"\nCHUNK HEADER\n{chunk_header}") 85 | 86 | # verify that the segment headers are as expected 87 | segment_header = kb._get_segment_header(doc_id="doc_1", chunk_index=0) 88 | 89 | assert "Section context" not in chunk_header 90 | assert "AGI" not in segment_header 91 | 92 | # delete the KnowledgeBase object 93 | kb.delete() 94 | 95 | 96 | def cleanup(self): 97 | kb = KnowledgeBase(kb_id="levels_of_agi", exists_ok=True) 98 | kb.delete() 99 | 100 | if __name__ == "__main__": 101 | unittest.main() -------------------------------------------------------------------------------- /tests/integration/test_create_kb_and_query.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import unittest 4 | 5 | # add ../../ to sys.path 6 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))) 7 | 8 | from dsrag.create_kb import create_kb_from_file 9 | from dsrag.knowledge_base import KnowledgeBase 10 | 11 | 12 | class TestCreateKB(unittest.TestCase): 13 | def test_create_kb_from_file(self): 14 | self.cleanup() # delete the KnowledgeBase object if it exists so we can start fresh 15 | 16 | # Get the absolute path of the script file 17 | script_dir = os.path.dirname(os.path.abspath(__file__)) 18 | 19 | # Construct the absolute path to the dataset file 20 | file_path = os.path.join(script_dir, "../data/levels_of_agi.pdf") 21 | 22 | kb_id = "levels_of_agi" 23 | kb = create_kb_from_file(kb_id, file_path) 24 | 25 | # verify that the document is in the chunk db 26 | doc_ids = kb.chunk_db.get_all_doc_ids() 27 | self.assertEqual(len(doc_ids), 1) 28 | 29 | # assert that the document title and summary are present and reasonable 30 | document_title = kb.chunk_db.get_document_title(doc_id="levels_of_agi.pdf", chunk_index=0) 31 | self.assertIn("agi", document_title.lower()) 32 | document_summary = kb.chunk_db.get_document_summary(doc_id="levels_of_agi.pdf", chunk_index=0) 33 | self.assertIn("agi", document_summary.lower()) 34 | 35 | # run a query and verify results are returned 36 | search_queries = ["What are the levels of AGI?", "What is the highest level of AGI?"] 37 | segment_info = kb.query(search_queries) 38 | self.assertGreater(len(segment_info[0]), 0) 39 | 40 | # delete the KnowledgeBase object 41 | kb.delete() 42 | 43 | def cleanup(self): 44 | kb = KnowledgeBase(kb_id="levels_of_agi", exists_ok=True) 45 | kb.delete() 46 | 47 | if __name__ == "__main__": 48 | unittest.main() -------------------------------------------------------------------------------- /tests/integration/test_create_kb_and_query_chromadb.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import unittest 4 | 5 | # add ../../ to sys.path 6 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) 7 | 8 | from dsrag.knowledge_base import KnowledgeBase 9 | from dsrag.database.vector.chroma_db import ChromaDB 10 | 11 | @unittest.skipIf(os.environ.get('GITHUB_ACTIONS') == 'true', "ChromaDB is not available on GitHub Actions") 12 | class TestCreateKB(unittest.TestCase): 13 | def test__001_create_kb_and_query(self): 14 | self.cleanup() # delete the KnowledgeBase object if it exists so we can start fresh 15 | 16 | # Get the absolute path of the script file 17 | script_dir = os.path.dirname(os.path.abspath(__file__)) 18 | 19 | # Construct 
the absolute path to the dataset file 20 | file_path = os.path.join(script_dir, "../data/levels_of_agi.pdf") 21 | 22 | kb_id = "levels_of_agi" 23 | 24 | vector_db = ChromaDB(kb_id=kb_id) 25 | kb = KnowledgeBase(kb_id=kb_id, vector_db=vector_db, exists_ok=False) 26 | kb.reranker.model = "rerank-english-v3.0" 27 | kb.add_document( 28 | doc_id="levels_of_agi.pdf", 29 | document_title="Levels of AGI", 30 | file_path=file_path, 31 | semantic_sectioning_config={"use_semantic_sectioning": False}, 32 | auto_context_config={ 33 | "use_generated_title": False, 34 | "get_document_summary": False 35 | } 36 | ) 37 | 38 | # verify that the document is in the chunk db 39 | self.assertEqual(len(kb.chunk_db.get_all_doc_ids()), 1) 40 | 41 | # run a query and verify results are returned 42 | search_queries = [ 43 | "What are the levels of AGI?", 44 | "What is the highest level of AGI?", 45 | ] 46 | segment_info = kb.query(search_queries) 47 | self.assertGreater(len(segment_info[0]), 0) 48 | 49 | # Assert that the chunk page start and end are correct 50 | self.assertEqual(segment_info[0]["chunk_page_start"], 5) 51 | self.assertEqual(segment_info[0]["chunk_page_end"], 8) 52 | 53 | def test__002_test_query_with_metadata_filter(self): 54 | 55 | kb = KnowledgeBase(kb_id="levels_of_agi", exists_ok=True) 56 | document_text = "AGI has many levels. The highest level of AGI is level 5." 57 | kb.add_document( 58 | doc_id="test_doc", 59 | document_title="AGI Test Document", 60 | text=document_text, 61 | semantic_sectioning_config={"use_semantic_sectioning": False}, 62 | auto_context_config={ 63 | "use_generated_title": False, # No need to generate a title for this document for testing purposes 64 | "get_document_summary": False 65 | } 66 | ) 67 | 68 | search_queries = [ 69 | "What is the highest level of AGI?", 70 | ] 71 | metadata_filter = {"field": "doc_id", "operator": "equals", "value": "test_doc"} 72 | segment_info = kb.query(search_queries, metadata_filter=metadata_filter, rse_params={"minimum_value": 0}) 73 | self.assertEqual(len(segment_info), 1) 74 | self.assertEqual(segment_info[0]["doc_id"], "test_doc") 75 | self.assertEqual(segment_info[0]["chunk_page_start"], None) 76 | self.assertEqual(segment_info[0]["chunk_page_end"], None) 77 | 78 | 79 | def test__003_remove_document(self): 80 | kb = KnowledgeBase(kb_id="levels_of_agi", exists_ok=True) 81 | kb.delete_document(doc_id="levels_of_agi.pdf") 82 | kb.delete_document(doc_id="test_doc") 83 | self.assertEqual(len(kb.chunk_db.get_all_doc_ids()), 0) 84 | 85 | # delete the KnowledgeBase object 86 | kb.delete() 87 | 88 | def test__004_query_empty_kb(self): 89 | kb = KnowledgeBase(kb_id="levels_of_agi", exists_ok=True) 90 | results = kb.query(["What are the levels of AGI?"]) 91 | self.assertEqual(len(results), 0) 92 | kb.delete() 93 | 94 | def cleanup(self): 95 | kb = KnowledgeBase(kb_id="levels_of_agi", exists_ok=True) 96 | kb.delete() 97 | 98 | 99 | if __name__ == "__main__": 100 | unittest.main() 101 | -------------------------------------------------------------------------------- /tests/integration/test_create_kb_and_query_weaviate.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import unittest 4 | 5 | # add ../../ to sys.path 6 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) 7 | 8 | from dsrag.knowledge_base import KnowledgeBase 9 | from dsrag.database.vector import WeaviateVectorDB 10 | 11 | 12 | @unittest.skipIf(os.environ.get('GITHUB_ACTIONS') == 
'true', "Weaviate is not available on GitHub Actions") 13 | class TestCreateKB(unittest.TestCase): 14 | def test_create_kb_and_query(self): 15 | self.cleanup() # delete the KnowledgeBase object if it exists so we can start fresh 16 | 17 | # Get the absolute path of the script file 18 | script_dir = os.path.dirname(os.path.abspath(__file__)) 19 | 20 | # Construct the absolute path to the dataset file 21 | file_path = os.path.join(script_dir, "../data/levels_of_agi.pdf") 22 | 23 | kb_id = "levels_of_agi" 24 | 25 | vector_db = WeaviateVectorDB(kb_id=kb_id, use_embedded_weaviate=True) 26 | kb = KnowledgeBase(kb_id=kb_id, vector_db=vector_db, exists_ok=False) 27 | kb.reranker.model = "rerank-english-v3.0" 28 | kb.add_document( 29 | doc_id="levels_of_agi.pdf", 30 | document_title="Levels of AGI", 31 | file_path=file_path, 32 | semantic_sectioning_config={"use_semantic_sectioning": False}, 33 | auto_context_config={ 34 | "use_generated_title": False, 35 | "get_document_summary": False 36 | } 37 | ) 38 | 39 | # verify that the document is in the chunk db 40 | self.assertEqual(len(kb.chunk_db.get_all_doc_ids()), 1) 41 | 42 | # run a query and verify results are returned 43 | search_queries = [ 44 | "What are the levels of AGI?", 45 | "What is the highest level of AGI?", 46 | ] 47 | segment_info = kb.query(search_queries) 48 | self.assertGreater(len(segment_info[0]), 0) 49 | # Assert that the chunk page start and end are correct 50 | self.assertEqual(segment_info[0]["chunk_page_start"], 5) 51 | self.assertEqual(segment_info[0]["chunk_page_end"], 8) 52 | 53 | # delete the KnowledgeBase object 54 | kb.delete() 55 | 56 | def cleanup(self): 57 | kb = KnowledgeBase(kb_id="levels_of_agi", exists_ok=True) 58 | kb.delete() 59 | 60 | 61 | if __name__ == "__main__": 62 | unittest.main() 63 | -------------------------------------------------------------------------------- /tests/integration/test_custom_embedding_subclass.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | from openai import OpenAI 4 | import unittest 5 | 6 | # add ../../ to sys.path 7 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))) 8 | 9 | from dsrag.embedding import Embedding 10 | from dsrag.knowledge_base import KnowledgeBase 11 | 12 | class CustomEmbedding(Embedding): 13 | def __init__(self, model: str = "text-embedding-3-small", dimension: int = 768): 14 | super().__init__(dimension) 15 | self.model = model 16 | self.client = OpenAI() 17 | 18 | def get_embeddings(self, text, input_type=None): 19 | response = self.client.embeddings.create(input=text, model=self.model, dimensions=int(self.dimension)) 20 | embeddings = [embedding_item.embedding for embedding_item in response.data] 21 | return embeddings[0] if isinstance(text, str) else embeddings 22 | 23 | def to_dict(self): 24 | base_dict = super().to_dict() 25 | base_dict.update({ 26 | 'model': self.model 27 | }) 28 | return base_dict 29 | 30 | 31 | class TestCustomEmbeddingSubclass(unittest.TestCase): 32 | def cleanup(self): 33 | kb = KnowledgeBase(kb_id="test_kb", exists_ok=True) 34 | kb.delete() 35 | 36 | def test_custom_embedding_subclass(self): 37 | # cleanup the knowledge base 38 | self.cleanup() 39 | 40 | # initialize a KnowledgeBase object with the custom embedding 41 | kb_id = "test_kb" 42 | kb = KnowledgeBase(kb_id, embedding_model=CustomEmbedding(model="text-embedding-3-large", dimension=1024), exists_ok=False) 43 | 44 | # load the knowledge base 45 | kb = KnowledgeBase(kb_id, 
exists_ok=True) 46 | 47 | # verify that the embedding model is set correctly 48 | self.assertEqual(kb.embedding_model.model, "text-embedding-3-large") 49 | self.assertEqual(kb.embedding_model.dimension, 1024) 50 | 51 | # delete the knowledge base 52 | self.cleanup() 53 | 54 | if __name__ == "__main__": 55 | unittest.main() -------------------------------------------------------------------------------- /tests/integration/test_custom_langchain_retriever.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import unittest 4 | 5 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))) 6 | 7 | from dsrag.create_kb import create_kb_from_file 8 | from dsrag.knowledge_base import KnowledgeBase 9 | from integrations.langchain_retriever import DsRAGLangchainRetriever 10 | from langchain_core.documents import Document 11 | 12 | 13 | class TestDsRAGLangchainRetriever(unittest.TestCase): 14 | 15 | def setUp(self): 16 | 17 | """ We have to create a knowledge base from a file before we can run the retriever tests.""" 18 | 19 | # Get the absolute path of the script file 20 | script_dir = os.path.dirname(os.path.abspath(__file__)) 21 | 22 | # Construct the absolute path to the dataset file 23 | file_path = os.path.join(script_dir, "../data/levels_of_agi.pdf") 24 | 25 | kb_id = "levels_of_agi" 26 | create_kb_from_file(kb_id, file_path) 27 | self.kb_id = kb_id 28 | 29 | def test_create_kb_and_query(self): 30 | 31 | retriever = DsRAGLangchainRetriever(self.kb_id) 32 | documents = retriever.invoke("What are the levels of AGI?") 33 | 34 | # Make sure results are returned 35 | self.assertTrue(len(documents) > 0) 36 | 37 | # Make sure the results are Document objects, and that they have page content and metadata 38 | for doc in documents: 39 | assert isinstance(doc, Document) 40 | assert doc.page_content 41 | assert doc.metadata 42 | 43 | def tearDown(self): 44 | # Delete the KnowledgeBase object after the test runs 45 | kb = KnowledgeBase(kb_id="levels_of_agi", exists_ok=True) 46 | kb.delete() 47 | 48 | 49 | if __name__ == "__main__": 50 | unittest.main() -------------------------------------------------------------------------------- /tests/integration/test_save_and_load.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import unittest 4 | import shutil 5 | from unittest.mock import patch 6 | 7 | # add ../../ to sys.path 8 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))) 9 | 10 | from dsrag.knowledge_base import KnowledgeBase 11 | from dsrag.llm import OpenAIChatAPI 12 | from dsrag.embedding import CohereEmbedding 13 | from dsrag.database.vector import BasicVectorDB # Import necessary DBs for override 14 | from dsrag.database.chunk import BasicChunkDB 15 | from dsrag.dsparse.file_parsing.file_system import LocalFileSystem 16 | 17 | class TestSaveAndLoad(unittest.TestCase): 18 | def cleanup(self): 19 | # Use a specific kb_id for testing 20 | kb = KnowledgeBase(kb_id="test_kb_override", exists_ok=True) 21 | # Try deleting, ignore if it doesn't exist 22 | try: 23 | kb.delete() 24 | except FileNotFoundError: 25 | pass 26 | # Ensure the base KB used in other tests is also cleaned up 27 | kb_main = KnowledgeBase(kb_id="test_kb", exists_ok=True) 28 | try: 29 | kb_main.delete() 30 | except FileNotFoundError: 31 | pass 32 | 33 | def test_save_and_load(self): 34 | self.cleanup() 35 | 36 | # initialize a KnowledgeBase object 37 
| auto_context_model = OpenAIChatAPI(model="gpt-4o-mini") 38 | embedding_model = CohereEmbedding(model="embed-english-v3.0") 39 | kb_id = "test_kb" 40 | kb = KnowledgeBase(kb_id=kb_id, auto_context_model=auto_context_model, embedding_model=embedding_model, exists_ok=False) 41 | 42 | # load the KnowledgeBase object 43 | kb1 = KnowledgeBase(kb_id=kb_id) 44 | 45 | # verify that the KnowledgeBase object has the right parameters 46 | self.assertEqual(kb1.auto_context_model.model, "gpt-4o-mini") 47 | self.assertEqual(kb1.embedding_model.model, "embed-english-v3.0") 48 | 49 | # delete the KnowledgeBase object 50 | kb1.delete() 51 | 52 | # Check if the underlying storage path exists after deletion 53 | chunk_db_path = os.path.join(os.path.expanduser("~/dsRAG"), "chunk_data", kb_id) 54 | self.assertFalse(os.path.exists(chunk_db_path), f"Chunk DB path {chunk_db_path} should not exist after delete.") 55 | 56 | def test_load_override_warning(self): 57 | """Test that warnings are logged when overriding components during load.""" 58 | self.cleanup() 59 | kb_id = "test_kb_override" 60 | 61 | # 1. Create initial KB to save config 62 | kb_initial = KnowledgeBase(kb_id=kb_id, exists_ok=False) 63 | 64 | # Mock DB/FS instances for overriding 65 | mock_vector_db = BasicVectorDB(kb_id=kb_id, storage_directory="~/dsRAG_temp_v") 66 | mock_chunk_db = BasicChunkDB(kb_id=kb_id, storage_directory="~/dsRAG_temp_c") 67 | mock_file_system = LocalFileSystem(base_path="~/dsRAG_temp_f") 68 | 69 | # 2. Patch logging.warning 70 | with patch('logging.warning') as mock_warning: 71 | # 3. Load KB again with overrides 72 | kb_loaded = KnowledgeBase( 73 | kb_id=kb_id, 74 | vector_db=mock_vector_db, 75 | chunk_db=mock_chunk_db, 76 | file_system=mock_file_system, 77 | exists_ok=True # This ensures _load is called 78 | ) 79 | 80 | # 4. Assertions 81 | self.assertEqual(mock_warning.call_count, 3, "Expected 3 warning calls (vector_db, chunk_db, file_system)") 82 | 83 | expected_calls = [ 84 | f"Overriding stored vector_db for KB '{kb_id}' during load.", 85 | f"Overriding stored chunk_db for KB '{kb_id}' during load.", 86 | f"Overriding stored file_system for KB '{kb_id}' during load.", 87 | ] 88 | 89 | call_args_list = [call.args[0] for call in mock_warning.call_args_list] 90 | call_kwargs_list = [call.kwargs for call in mock_warning.call_args_list] 91 | 92 | # Check if all expected messages are present in the actual calls 93 | for expected_msg in expected_calls: 94 | self.assertTrue(any(expected_msg in str(arg) for arg in call_args_list), f"Expected warning message '{expected_msg}' not found.") 95 | 96 | # Check the 'extra' parameter in all calls 97 | expected_extra = {'kb_id': kb_id} 98 | for kwargs in call_kwargs_list: 99 | self.assertIn('extra', kwargs, "Warning call missing 'extra' parameter.") 100 | self.assertEqual(kwargs['extra'], expected_extra, f"Warning call 'extra' parameter mismatch. Expected {expected_extra}, got {kwargs['extra']}") 101 | 102 | 103 | # 5. 
Cleanup 104 | kb_loaded.delete() 105 | # Explicitly try to remove temp dirs if they exist (though kb.delete should handle them) 106 | temp_dirs = ["~/dsRAG_temp_v", "~/dsRAG_temp_c", "~/dsRAG_temp_f"] 107 | for temp_dir in temp_dirs: 108 | expanded_path = os.path.expanduser(temp_dir) 109 | if os.path.exists(expanded_path): 110 | try: 111 | shutil.rmtree(expanded_path) # Use shutil.rmtree 112 | # print(f"Cleaned up temp dir: {expanded_path}") # Optional: uncomment for debug 113 | except OSError as e: 114 | print(f"Error removing temp dir {expanded_path}: {e}") # Log error if removal fails 115 | 116 | 117 | if __name__ == "__main__": 118 | unittest.main() -------------------------------------------------------------------------------- /tests/unit/test_chunking.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import unittest 4 | 5 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) 6 | from dsrag.dsparse.sectioning_and_chunking.chunking import chunk_document, chunk_sub_section, find_lines_in_range 7 | 8 | class TestChunking(unittest.TestCase): 9 | def setUp(self): 10 | self.sections = [ 11 | { 12 | 'title': 'Section 1', 13 | 'start': 0, 14 | 'end': 3, 15 | 'content': 'This is the first section of the document.' 16 | }, 17 | { 18 | 'title': 'Section 2', 19 | 'start': 4, 20 | 'end': 7, 21 | 'content': 'This is the second section of the document.' 22 | } 23 | ] 24 | 25 | self.document_lines = [ 26 | { 27 | 'content': 'This is the first line of the document. And here is another sentence.', 28 | 'element_type': 'NarrativeText', 29 | 'page_number': 1, 30 | 'is_visual': False 31 | }, 32 | { 33 | 'content': 'This is the second line of the document.', 34 | 'element_type': 'NarrativeText', 35 | 'page_number': 1, 36 | 'is_visual': False 37 | }, 38 | { 39 | 'content': 'This is the third line of the document........', 40 | 'element_type': 'NarrativeText', 41 | 'page_number': 1, 42 | 'is_visual': False 43 | }, 44 | { 45 | 'content': 'This is the fourth line of the document.', 46 | 'element_type': 'NarrativeText', 47 | 'page_number': 1, 48 | 'is_visual': False 49 | }, 50 | { 51 | 'content': 'This is the fifth line of the document. With another sentence.', 52 | 'element_type': 'NarrativeText', 53 | 'page_number': 1, 54 | 'is_visual': False 55 | }, 56 | { 57 | 'content': 'This is the sixth line of the document.', 58 | 'element_type': 'NarrativeText', 59 | 'page_number': 1, 60 | 'is_visual': False 61 | }, 62 | { 63 | 'content': 'This is the seventh line of the document.', 64 | 'element_type': 'NarrativeText', 65 | 'page_number': 1, 66 | 'is_visual': False 67 | }, 68 | { 69 | 'content': 'This is the eighth line of the document. 
And here is another sentence that is a bit longer', 70 | 'element_type': 'NarrativeText', 71 | 'page_number': 1, 72 | 'is_visual': False 73 | } 74 | ] 75 | 76 | def test__chunk_document(self): 77 | chunk_size = 90 78 | min_length_for_chunking = 120 79 | chunks = chunk_document(self.sections, self.document_lines, chunk_size, min_length_for_chunking) 80 | assert len(chunks[-1]["content"]) > 45 81 | 82 | def test__chunk_sub_section(self): 83 | chunk_size = 90 84 | chunks_text, chunk_line_indices = chunk_sub_section(4, 7, self.document_lines, chunk_size) 85 | assert len(chunks_text) == 3 86 | assert chunk_line_indices == [(4, 4), (5, 6), (7, 7)] 87 | 88 | def test__find_lines_in_range(self): 89 | # Test find_lines_in_range 90 | # (line_idx, start_char, end_char) 91 | line_char_ranges = [ 92 | (0, 0, 49), 93 | (1, 50, 99), 94 | (2, 100, 149), 95 | (3, 150, 199), 96 | (4, 200, 249), 97 | (5, 250, 299), 98 | (6, 300, 349), 99 | (7, 350, 399) 100 | ] 101 | 102 | chunk_start = 50 103 | chunk_end = 82 104 | line_start = 0 105 | line_end = 0 106 | chunk_line_start, chunk_line_end = find_lines_in_range(chunk_start, chunk_end, line_char_ranges, line_start, line_end) 107 | assert chunk_line_start == 1 108 | assert chunk_line_end == 1 109 | 110 | chunk_start = 50 111 | chunk_end = 150 112 | line_start = 0 113 | line_end = 0 114 | chunk_line_start, chunk_line_end = find_lines_in_range(chunk_start, chunk_end, line_char_ranges, line_start, line_end) 115 | assert chunk_line_start == 1 116 | assert chunk_line_end == 3 117 | 118 | if __name__ == "__main__": 119 | unittest.main() -------------------------------------------------------------------------------- /tests/unit/test_embedding.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import unittest 4 | 5 | sys.path.append(os.path.join(os.path.dirname(__file__), "../..")) 6 | 7 | from dsrag.embedding import ( 8 | OpenAIEmbedding, 9 | CohereEmbedding, 10 | OllamaEmbedding, 11 | Embedding, 12 | ) 13 | 14 | 15 | class TestEmbedding(unittest.TestCase): 16 | def test__get_embeddings_openai(self): 17 | input_text = "Hello, world!" 18 | model = "text-embedding-3-small" 19 | dimension = 768 20 | embedding_provider = OpenAIEmbedding(model, dimension) 21 | embedding = embedding_provider.get_embeddings(input_text) 22 | self.assertEqual(len(embedding), dimension) 23 | 24 | def test__get_embeddings_cohere(self): 25 | input_text = "Hello, world!" 26 | model = "embed-english-v3.0" 27 | embedding_provider = CohereEmbedding(model) 28 | embedding = embedding_provider.get_embeddings(input_text, input_type="query") 29 | self.assertEqual(len(embedding), 1024) 30 | 31 | def test__get_embeddings_ollama(self): 32 | input_text = "Hello, world!" 
33 | model = "llama3" 34 | dimension = 4096 35 | try: 36 | embedding_provider = OllamaEmbedding(model, dimension) 37 | embedding = embedding_provider.get_embeddings(input_text) 38 | except Exception as e: 39 | print (e) 40 | if e.__class__.__name__ == "ConnectError": 41 | print("Connection failed") 42 | return 43 | self.assertEqual(len(embedding), dimension) 44 | 45 | def test__get_embeddings_openai_with_list(self): 46 | input_texts = ["Hello, world!", "Goodbye, world!"] 47 | model = "text-embedding-3-small" 48 | dimension = 768 49 | embedding_provider = OpenAIEmbedding(model, dimension) 50 | embeddings = embedding_provider.get_embeddings(input_texts) 51 | self.assertEqual(len(embeddings), 2) 52 | self.assertTrue(all(len(embed) == dimension for embed in embeddings)) 53 | 54 | def test__get_embeddings_cohere_with_list(self): 55 | input_texts = ["Hello, world!", "Goodbye, world!"] 56 | model = "embed-english-v3.0" 57 | embedding_provider = CohereEmbedding(model) 58 | embeddings = embedding_provider.get_embeddings(input_texts, input_type="query") 59 | assert len(embeddings) == 2 60 | assert all(len(embed) == 1024 for embed in embeddings) 61 | self.assertEqual(len(embeddings), 2) 62 | self.assertTrue(all(len(embed) == 1024 for embed in embeddings)) 63 | 64 | def test__get_embeddings_ollama_with_list_llama2(self): 65 | input_texts = ["Hello, world!", "Goodbye, world!"] 66 | model = "llama2" 67 | dimension = 4096 68 | try: 69 | embedding_provider = OllamaEmbedding(model, dimension) 70 | embeddings = embedding_provider.get_embeddings(input_texts) 71 | except Exception as e: 72 | print (e) 73 | if e.__class__.__name__ == "ConnectError": 74 | print ("Connection failed") 75 | return 76 | self.assertEqual(len(embeddings), 2) 77 | self.assertTrue(all(len(embed) == dimension for embed in embeddings)) 78 | 79 | def test__get_embeddings_ollama_with_list_minilm(self): 80 | input_texts = ["Hello, world!", "Goodbye, world!"] 81 | model = "all-minilm" 82 | dimension = 384 83 | try: 84 | embedding_provider = OllamaEmbedding(model, dimension) 85 | embeddings = embedding_provider.get_embeddings(input_texts) 86 | except Exception as e: 87 | print (e) 88 | if e.__class__.__name__ == "ConnectError": 89 | print ("Connection failed") 90 | return 91 | self.assertEqual(len(embeddings), 2) 92 | self.assertTrue(all(len(embed) == dimension for embed in embeddings)) 93 | 94 | def test__get_embeddings_ollama_with_list_nomic(self): 95 | input_texts = ["Hello, world!", "Goodbye, world!"] 96 | model = "nomic-embed-text" 97 | dimension = 768 98 | try: 99 | embedding_provider = OllamaEmbedding(model, dimension) 100 | embeddings = embedding_provider.get_embeddings(input_texts) 101 | except Exception as e: 102 | print (e) 103 | if e.__class__.__name__ == "ConnectError": 104 | print ("Connection failed") 105 | return 106 | self.assertEqual(len(embeddings), 2) 107 | self.assertTrue(all(len(embed) == dimension for embed in embeddings)) 108 | 109 | def test__initialize_from_config(self): 110 | config = { 111 | 'subclass_name': 'OpenAIEmbedding', 112 | 'model': 'text-embedding-3-small', 113 | 'dimension': 1024 114 | } 115 | embedding_instance = Embedding.from_dict(config) 116 | self.assertIsInstance(embedding_instance, OpenAIEmbedding) 117 | self.assertEqual(embedding_instance.model, 'text-embedding-3-small') 118 | self.assertEqual(embedding_instance.dimension, 1024) 119 | 120 | 121 | if __name__ == "__main__": 122 | unittest.main() -------------------------------------------------------------------------------- 
/tests/unit/test_file_system.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import unittest 4 | import dotenv 5 | 6 | dotenv.load_dotenv() 7 | 8 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) 9 | from dsrag.dsparse.file_parsing.file_system import LocalFileSystem, S3FileSystem 10 | from dsrag.knowledge_base import KnowledgeBase 11 | 12 | 13 | class TestFileSystem(unittest.TestCase): 14 | 15 | @classmethod 16 | def setUpClass(self): 17 | self.base_path = "~/dsrag_test_file_system" 18 | self.local_file_system = LocalFileSystem(base_path=self.base_path) 19 | self.s3_file_system = S3FileSystem( 20 | base_path=self.base_path, 21 | bucket_name=os.environ["AWS_S3_BUCKET_NAME"], 22 | region_name=os.environ["AWS_S3_REGION"], 23 | access_key=os.environ["AWS_S3_ACCESS_KEY"], 24 | secret_key=os.environ["AWS_S3_SECRET_KEY"] 25 | ) 26 | self.kb_id = "test_file_system_kb" 27 | self.kb = KnowledgeBase(kb_id=self.kb_id, file_system=self.local_file_system) 28 | 29 | def test__001_load_kb(self): 30 | kb = KnowledgeBase(kb_id=self.kb_id) 31 | # Assert the class 32 | assert isinstance(kb.file_system, LocalFileSystem) 33 | 34 | def test__002_overwrite_file_system(self): 35 | kb = KnowledgeBase(kb_id=self.kb_id, file_system=self.s3_file_system) 36 | # Assert the class. Providing a file_system when loading an existing KB overrides the stored one 37 | assert isinstance(kb.file_system, S3FileSystem) 38 | 39 | @classmethod 40 | def tearDownClass(self): 41 | try: 42 | os.system(f"rm -rf {self.base_path}") 43 | except: 44 | pass 45 | 46 | 47 | if __name__ == "__main__": 48 | unittest.main() -------------------------------------------------------------------------------- /tests/unit/test_llm_class.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import unittest 4 | 5 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) 6 | from dsrag.llm import OpenAIChatAPI, AnthropicChatAPI, OllamaAPI, LLM, GeminiAPI 7 | 8 | 9 | class TestLLM(unittest.TestCase): 10 | def test__openai_chat_api(self): 11 | chat_api = OpenAIChatAPI() 12 | chat_messages = [ 13 | {"role": "user", "content": "Hello, how are you?"}, 14 | ] 15 | response = chat_api.make_llm_call(chat_messages) 16 | self.assertIsInstance(response, str) 17 | self.assertGreater(len(response), 0) 18 | 19 | def test__anthropic_chat_api(self): 20 | chat_api = AnthropicChatAPI() 21 | chat_messages = [ 22 | {"role": "user", "content": "What's the weather like today?"}, 23 | ] 24 | response = chat_api.make_llm_call(chat_messages) 25 | self.assertIsInstance(response, str) 26 | self.assertGreater(len(response), 0) 27 | 28 | def test__ollama_api(self): 29 | try: 30 | chat_api = OllamaAPI() 31 | except Exception as e: 32 | print(e) 33 | if e.__class__.__name__ == "ConnectError": 34 | print ("Connection failed") 35 | return 36 | chat_messages = [ 37 | {"role": "user", "content": "Who's your daddy?"}, 38 | ] 39 | response = chat_api.make_llm_call(chat_messages) 40 | self.assertIsInstance(response, str) 41 | self.assertGreater(len(response), 0) 42 | 43 | def test__gemini_api(self): 44 | try: 45 | chat_api = GeminiAPI() 46 | except ValueError as e: 47 | # Skip test if API key is not set 48 | if "GEMINI_API_KEY environment variable not set" in str(e): 49 | print(f"Skipping Gemini test: {e}") 50 | return 51 | else: 52 | raise # Re-raise other ValueErrors 53 | except Exception as e: 54 | # Handle other potential
exceptions during init (e.g., genai configuration issues) 55 | print(f"Skipping Gemini test due to initialization error: {e}") 56 | return 57 | 58 | chat_messages = [ 59 | {"role": "user", "content": "Explain the concept of RAG briefly."}, 60 | ] 61 | try: 62 | response = chat_api.make_llm_call(chat_messages) 63 | self.assertIsInstance(response, str) 64 | self.assertGreater(len(response), 0) 65 | except Exception as e: 66 | # Catch potential API call errors (e.g., connection, authentication within genai) 67 | self.fail(f"GeminiAPI make_llm_call failed with exception: {e}") 68 | 69 | def test__save_and_load_from_dict(self): 70 | chat_api = OpenAIChatAPI(temperature=0.5, max_tokens=2000) 71 | config = chat_api.to_dict() 72 | chat_api_loaded = LLM.from_dict(config) 73 | self.assertEqual(chat_api_loaded.model, chat_api.model) 74 | self.assertEqual(chat_api_loaded.temperature, chat_api.temperature) 75 | self.assertEqual(chat_api_loaded.max_tokens, chat_api.max_tokens) 76 | 77 | if __name__ == "__main__": 78 | unittest.main() -------------------------------------------------------------------------------- /tests/unit/test_reranker.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import unittest 4 | sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) 5 | 6 | from dsrag.reranker import Reranker, CohereReranker, NoReranker 7 | 8 | 9 | class TestReranker(unittest.TestCase): 10 | def test_rerank_search_results_cohere(self): 11 | query = "Hello, world!" 12 | search_results = [ 13 | { 14 | "metadata": { 15 | "chunk_header": "", 16 | "chunk_text": "Hello, world!" 17 | } 18 | }, 19 | { 20 | "metadata": { 21 | "chunk_header": "", 22 | "chunk_text": "Goodbye, world!" 23 | } 24 | } 25 | ] 26 | reranker = CohereReranker() 27 | reranked_search_results = reranker.rerank_search_results(query, search_results) 28 | self.assertEqual(len(reranked_search_results), 2) 29 | self.assertEqual(reranked_search_results[0]["metadata"]["chunk_text"], "Hello, world!") 30 | self.assertEqual(reranked_search_results[1]["metadata"]["chunk_text"], "Goodbye, world!") 31 | 32 | def test_save_and_load_from_dict(self): 33 | reranker = CohereReranker(model="rerank-english-v3.0") 34 | config = reranker.to_dict() 35 | reranker_instance = Reranker.from_dict(config) 36 | self.assertEqual(reranker_instance.model, "rerank-english-v3.0") 37 | 38 | def test_no_reranker(self): 39 | reranker = NoReranker() 40 | query = "Hello, world!" 41 | search_results = [ 42 | { 43 | "metadata": { 44 | "chunk_header": "", 45 | "chunk_text": "Hello, world!" 46 | } 47 | }, 48 | { 49 | "metadata": { 50 | "chunk_header": "", 51 | "chunk_text": "Goodbye, world!" 52 | } 53 | } 54 | ] 55 | reranked_search_results = reranker.rerank_search_results(query, search_results) 56 | self.assertEqual(len(reranked_search_results), 2) 57 | self.assertEqual(reranked_search_results[0]["metadata"]["chunk_text"], "Hello, world!") 58 | self.assertEqual(reranked_search_results[1]["metadata"]["chunk_text"], "Goodbye, world!") 59 | 60 | 61 | if __name__ == "__main__": 62 | unittest.main() --------------------------------------------------------------------------------
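The unit and integration tests above exercise the core public workflow end to end: build a KnowledgeBase, add a document, query it, and optionally expose it to LangChain through the bundled retriever. As a reading aid, here is a minimal usage sketch assembled only from the calls shown in those tests; it is illustrative rather than part of the repository, the kb_id and file path are placeholders, and it assumes the same provider API keys (e.g. OPENAI_API_KEY, CO_API_KEY) that the CI workflows export.

```python
import os

from dsrag.knowledge_base import KnowledgeBase
from integrations.langchain_retriever import DsRAGLangchainRetriever

# Create (or load) a knowledge base and add a local PDF to it.
# "my_kb" and the file path are placeholders for this sketch.
kb = KnowledgeBase(kb_id="my_kb", exists_ok=True)
kb.add_document(
    doc_id="levels_of_agi.pdf",
    document_title="Levels of AGI",
    file_path=os.path.join("tests", "data", "levels_of_agi.pdf"),
)

# Query the knowledge base directly. Each result is a segment dict carrying
# the retrieved text plus doc_id and chunk/page boundaries, as asserted in
# the integration tests.
segments = kb.query(["What are the levels of AGI?"])
for segment in segments:
    print(segment["doc_id"], segment["chunk_page_start"], segment["chunk_page_end"])
    print(segment["text"][:200])

# The same knowledge base can also be queried through LangChain via the
# retriever integration shipped in integrations/langchain_retriever.py.
retriever = DsRAGLangchainRetriever(kb_id="my_kb")
documents = retriever.invoke("What are the levels of AGI?")
print(f"{len(documents)} documents retrieved")

# Clean up, mirroring what the tests do in their teardown methods.
kb.delete()
```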