├── .github
│   └── workflows
│       ├── deploy-docs.yml
│       ├── manual-workflow.yml
│       └── run-tests.yml
├── .gitignore
├── LICENSE
├── MANIFEST.in
├── README.md
├── docs
│   ├── api
│   │   ├── chat.md
│   │   └── knowledge_base.md
│   ├── community
│   │   ├── contributing.md
│   │   ├── professional-services.md
│   │   └── support.md
│   ├── concepts
│   │   ├── chat.md
│   │   ├── citations.md
│   │   ├── components.md
│   │   ├── config.md
│   │   ├── knowledge-bases.md
│   │   ├── logging.md
│   │   ├── overview.md
│   │   └── vlm.md
│   ├── getting-started
│   │   ├── installation.md
│   │   └── quickstart.md
│   ├── index.md
│   └── logo.png
├── dsrag
│   ├── __init__.py
│   ├── add_document.py
│   ├── auto_context.py
│   ├── auto_query.py
│   ├── chat
│   │   ├── __init__.py
│   │   ├── auto_query.py
│   │   ├── chat.py
│   │   ├── chat_types.py
│   │   ├── citations.py
│   │   ├── instructor_get_response.py
│   │   └── model_names.py
│   ├── create_kb.py
│   ├── custom_term_mapping.py
│   ├── database
│   │   ├── __init__.py
│   │   ├── chat_thread
│   │   │   ├── basic_db.py
│   │   │   ├── db.py
│   │   │   └── sqlite_db.py
│   │   ├── chunk
│   │   │   ├── __init__.py
│   │   │   ├── basic_db.py
│   │   │   ├── db.py
│   │   │   ├── dynamo_db.py
│   │   │   ├── postgres_db.py
│   │   │   ├── sqlite_db.py
│   │   │   └── types.py
│   │   └── vector
│   │       ├── __init__.py
│   │       ├── basic_db.py
│   │       ├── chroma_db.py
│   │       ├── db.py
│   │       ├── milvus_db.py
│   │       ├── pinecone_db.py
│   │       ├── postgres_db.py
│   │       ├── qdrant_db.py
│   │       ├── types.py
│   │       └── weaviate_db.py
│   ├── dsparse
│   │   ├── MANIFEST.in
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── file_parsing
│   │   │   ├── __init__.py
│   │   │   ├── element_types.py
│   │   │   ├── file_system.py
│   │   │   ├── non_vlm_file_parsing.py
│   │   │   ├── vlm.py
│   │   │   └── vlm_file_parsing.py
│   │   ├── main.py
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   └── types.py
│   │   ├── pyproject.toml
│   │   ├── requirements.in
│   │   ├── requirements.txt
│   │   ├── sectioning_and_chunking
│   │   │   ├── __init__.py
│   │   │   ├── chunking.py
│   │   │   └── semantic_sectioning.py
│   │   ├── tests
│   │   │   ├── data
│   │   │   │   ├── chunks.json
│   │   │   │   ├── document_lines.json
│   │   │   │   ├── elements.json
│   │   │   │   └── sections.json
│   │   │   ├── integration
│   │   │   │   ├── test_semantic_sectioning.py
│   │   │   │   └── test_vlm_file_parsing.py
│   │   │   └── unit
│   │   │       ├── test_chunking.py
│   │   │       ├── test_file_system.py
│   │   │       └── test_semantic_sectioning.py
│   │   └── utils
│   │       ├── __init__.py
│   │       └── imports.py
│   ├── embedding.py
│   ├── knowledge_base.py
│   ├── llm.py
│   ├── metadata.py
│   ├── reranker.py
│   ├── rse.py
│   └── utils
│       └── imports.py
├── eval
│   ├── financebench
│   │   ├── dsrag_financebench_create_kb.py
│   │   └── dsrag_financebench_eval.py
│   ├── rse_eval.py
│   ├── streaming_chat_test.py
│   └── vlm_eval.py
├── examples
│   ├── dsRAG_motivation.ipynb
│   ├── example_kb_data
│   │   ├── chunk_storage
│   │   │   ├── levels_of_agi.pkl
│   │   │   └── nike_10k.pkl
│   │   ├── metadata
│   │   │   ├── levels_of_agi.json
│   │   │   └── nike_10k.json
│   │   └── vector_storage
│   │       ├── levels_of_agi.pkl
│   │       └── nike_10k.pkl
│   └── logging_example.py
├── integrations
│   └── langchain_retriever.py
├── mkdocs.yml
├── pyproject.toml
├── requirements.in
├── requirements.txt
└── tests
    ├── data
    │   ├── chunks.json
    │   ├── document_lines.json
    │   ├── elements.json
    │   ├── financebench_sample_150.csv
    │   ├── les_miserables.txt
    │   ├── levels_of_agi.pdf
    │   ├── mck_energy_first_5_pages.pdf
    │   ├── nike_2023_annual_report.txt
    │   ├── page_7.png
    │   └── sections.json
    ├── integration
    │   ├── test_chat_w_citations.py
    │   ├── test_chunk_and_segment_headers.py
    │   ├── test_create_kb_and_query.py
    │   ├── test_create_kb_and_query_chromadb.py
    │   ├── test_create_kb_and_query_weaviate.py
    │   ├── test_custom_embedding_subclass.py
    │   ├── test_custom_langchain_retriever.py
    │   ├── test_save_and_load.py
    │   └── test_vlm_file_parsing.py
    └── unit
        ├── test_auto_context.py
        ├── test_chunk_db.py
        ├── test_chunking.py
        ├── test_embedding.py
        ├── test_file_system.py
        ├── test_llm_class.py
        ├── test_reranker.py
        └── test_vector_db.py
/.github/workflows/deploy-docs.yml:
--------------------------------------------------------------------------------
1 | name: Deploy Documentation
2 | on:
3 | push:
4 | branches:
5 | - main
6 | workflow_dispatch:
7 |
8 | # Sets permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages
9 | permissions:
10 | contents: read
11 | pages: write
12 | id-token: write
13 |
14 | # Allow only one concurrent deployment, skipping runs queued between the run in-progress and latest queued.
15 | concurrency:
16 | group: "pages"
17 | cancel-in-progress: false
18 |
19 | jobs:
20 | build:
21 | runs-on: ubuntu-latest
22 | steps:
23 | - uses: actions/checkout@v4
24 |
25 | - name: Setup Pages
26 | uses: actions/configure-pages@v5
27 |
28 | - uses: actions/setup-python@v4
29 | with:
30 | python-version: '3.12'
31 |
32 | - name: Install dependencies
33 | run: |
34 | python -m pip install --upgrade pip
35 | pip install mkdocs-material
36 | pip install mkdocstrings[python]
37 | # Install the package in development mode
38 | pip install -e .
39 |
40 | - name: Build with MkDocs
41 | run: mkdocs build
42 |
43 | - name: Upload artifact
44 | uses: actions/upload-pages-artifact@v3
45 | with:
46 | path: ./site
47 |
48 | # Deployment job
49 | deploy:
50 | environment:
51 | name: github-pages
52 | url: ${{ steps.deployment.outputs.page_url }}
53 | runs-on: ubuntu-latest
54 | needs: build
55 | steps:
56 | - name: Deploy to GitHub Pages
57 | id: deployment
58 | uses: actions/deploy-pages@v4
--------------------------------------------------------------------------------
/.github/workflows/manual-workflow.yml:
--------------------------------------------------------------------------------
1 | name: Manual Test Run
2 |
3 | on:
4 | workflow_dispatch:
5 | inputs:
6 | repository:
7 | description: 'Repository to run tests on (e.g., user/repo)'
8 | required: true
9 | branch:
10 | description: 'Branch to run tests on'
11 | required: true
12 |
13 | jobs:
14 | run-tests:
15 | runs-on: ubuntu-latest
16 |
17 | env:
18 | ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
19 | CO_API_KEY: ${{ secrets.CO_API_KEY }}
20 | OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
21 | VOYAGE_API_KEY: ${{ secrets.VOYAGE_API_KEY }}
22 |
23 | steps:
24 | - name: Checkout code
25 | uses: actions/checkout@v2
26 |
27 | - name: Set up Python
28 | uses: actions/setup-python@v2
29 | with:
30 | python-version: '3.x'
31 |
32 | - name: Install dependencies
33 | run: |
34 | python -m pip install --upgrade pip
35 | pip install -e ".[all-dbs,all-models]" # Install package with optional dependencies
36 | pip install pytest
37 |
38 | - name: Install dsParse dependencies
39 | run: |
40 | cd dsrag/dsparse && pip install -r requirements.txt
41 |
42 | - name: Install additional test dependencies
43 | run: |
44 | pip install vertexai psycopg2-binary
45 |
46 | - name: Run unit tests
47 | run: |
48 | pytest tests/unit
49 |
50 | - name: Run integration tests
51 | run: |
52 | pytest tests/integration
53 |
--------------------------------------------------------------------------------
/.github/workflows/run-tests.yml:
--------------------------------------------------------------------------------
1 | name: Run Tests
2 |
3 | on:
4 | pull_request_review:
5 | types: [submitted]
6 |
7 | jobs:
8 | run-tests:
9 | if: github.event.review.state == 'approved'
10 | runs-on: ubuntu-latest
11 |
12 | env:
13 | ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
14 | CO_API_KEY: ${{ secrets.CO_API_KEY }}
15 | OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
16 | VOYAGE_API_KEY: ${{ secrets.VOYAGE_API_KEY }}
17 | AWS_S3_BUCKET_NAME: ${{ secrets.AWS_S3_BUCKET_NAME }}
18 | AWS_S3_REGION: ${{ secrets.AWS_S3_REGION }}
19 | AWS_S3_ACCESS_KEY: ${{ secrets.AWS_S3_ACCESS_KEY }}
20 | AWS_S3_SECRET_KEY: ${{ secrets.AWS_S3_SECRET_KEY }}
21 | GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}
22 | AWS_REGION: ${{ secrets.AWS_REGION }}
23 | AWS_DYNAMO_ACCESS_KEY: ${{ secrets.AWS_DYNAMO_ACCESS_KEY }}
24 | AWS_DYNAMO_SECRET_KEY: ${{ secrets.AWS_DYNAMO_SECRET_KEY }}
25 | PINECONE_API_KEY: ${{ secrets.PINECONE_API_KEY }}
26 | POSTGRES_DB: ${{ secrets.POSTGRES_DB }}
27 | POSTGRES_USER: ${{ secrets.POSTGRES_USER }}
28 | POSTGRES_PASSWORD: ${{ secrets.POSTGRES_PASSWORD }}
29 | POSTGRES_HOST: ${{ secrets.POSTGRES_HOST }}
30 |
31 | steps:
32 | - name: Checkout code
33 | uses: actions/checkout@v2
34 |
35 | - name: Set up Python
36 | uses: actions/setup-python@v2
37 | with:
38 | python-version: '3.12'
39 |
40 | - name: Install Poppler
41 | run: sudo apt-get install -y poppler-utils
42 |
43 | - name: Install dependencies
44 | run: |
45 | python -m pip install --upgrade pip
46 | pip install -e ".[all-dbs,all-models]" # Install package with optional dependencies
47 | pip install pytest
48 |
49 | - name: Install dsParse dependencies
50 | run: |
51 | cd dsrag/dsparse && pip install -r requirements.txt
52 |
53 | - name: Install additional test dependencies
54 | run: |
55 | pip install vertexai psycopg2-binary
56 |
57 | - name: Run unit tests
58 | run: |
59 | pytest tests/unit
60 |
61 | - name: Run integration tests
62 | run: |
63 | pytest tests/integration
64 |
65 | - name: Run dsParse unit tests
66 | run: |
67 | pytest dsrag/dsparse/tests/unit
68 |
69 | - name: Run dsParse integration tests
70 | run: |
71 | pytest dsrag/dsparse/tests/integration
72 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | .idea/
161 |
162 | # Embedded Weaviate data
163 | weaviate
164 |
165 | # DS_Store files
166 | .DS_Store
167 |
168 | .python-version
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 D-Star
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | exclude dsrag/dsparse/requirements.txt
2 | exclude dsrag/dsparse/requirements.in
3 | exclude dsrag/dsparse/MANIFEST.in
4 | exclude dsrag/dsparse/README.md
5 | exclude dsrag/dsparse/pyproject.toml
6 |
--------------------------------------------------------------------------------
/docs/api/chat.md:
--------------------------------------------------------------------------------
1 | # Chat
2 |
3 | The Chat module provides functionality for managing chat-based interactions with knowledge bases. It handles chat thread creation, message history, and generating responses using knowledge base search.
4 |
5 | ### Public Methods
6 |
7 | #### create_new_chat_thread
8 |
9 | ```python
10 | def create_new_chat_thread(chat_thread_params: ChatThreadParams, chat_thread_db: ChatThreadDB) -> str
11 | ```
12 |
13 | Create a new chat thread in the database.
14 |
15 | **Arguments**:
16 | - `chat_thread_params`: Parameters for the chat thread. Example:
17 | ```python
18 | {
19 | # Knowledge base IDs to use
20 | "kb_ids": ["kb1", "kb2"],
21 |
22 | # LLM model to use
23 | "model": "gpt-4o-mini",
24 |
25 | # Temperature for LLM sampling
26 | "temperature": 0.2,
27 |
28 | # System message for LLM
29 | "system_message": "You are a helpful assistant",
30 |
31 | # Model for auto-query generation
32 | "auto_query_model": "gpt-4o-mini",
33 |
34 | # Guidance for auto-query generation
35 | "auto_query_guidance": "",
36 |
37 | # Target response length (short/medium/long)
38 | "target_output_length": "medium",
39 |
40 | # Maximum tokens in chat history
41 | "max_chat_history_tokens": 8000,
42 |
43 | # Optional supplementary ID
44 | "supp_id": ""
45 | }
46 | ```
47 | - `chat_thread_db`: Database instance for storing chat threads.
48 |
49 | **Returns**: Unique identifier (str) for the created chat thread.
50 |
51 | #### get_chat_thread_response
52 |
53 | ```python
54 | def get_chat_thread_response(thread_id: str, get_response_input: ChatResponseInput,
55 | chat_thread_db: ChatThreadDB, knowledge_bases: dict, stream: bool = False)
56 | ```
57 |
58 | Get a response for a chat thread using knowledge base search, with optional streaming support.
59 |
60 | **Arguments**:
61 | - `thread_id`: Unique identifier for the chat thread.
62 | - `get_response_input`: Input parameters containing:
63 | - `user_input`: User's message text
64 | - `chat_thread_params`: Optional parameter overrides
65 | - `metadata_filter`: Optional search filter
66 | - `chat_thread_db`: Database instance for chat threads.
67 | - `knowledge_bases`: Dictionary mapping knowledge base IDs to instances.
68 | - `stream`: Whether to stream the response (defaults to False).
69 |
70 | **Returns**:
71 | If `stream=False` (default):
72 | - Formatted interaction containing:
73 | - `user_input`: User message with content and timestamp
74 | - `model_response`: Model response with content and timestamp
75 | - `search_queries`: Generated search queries
76 | - `relevant_segments`: Retrieved relevant segments with file names and types
77 | - `message`: Error message if something went wrong (optional)
78 |
79 | If `stream=True`:
80 | - Generator that yields partial response objects with the same structure as above
81 | - The final response is automatically saved to the chat thread database
82 |
83 | ### Chat Types
84 |
85 | #### ChatThreadParams
86 |
87 | Type definition for chat thread parameters.
88 |
89 | ```python
90 | ChatThreadParams = TypedDict('ChatThreadParams', {
91 | 'kb_ids': List[str], # List of knowledge base IDs to use
92 | 'model': str, # LLM model name (e.g., "gpt-4o-mini")
93 | 'temperature': float, # Temperature for LLM sampling (0.0-1.0)
94 | 'system_message': str, # Custom system message for LLM
95 | 'auto_query_model': str, # Model for generating search queries
96 | 'auto_query_guidance': str, # Custom guidance for query generation
97 | 'target_output_length': str, # Response length ("short", "medium", "long")
98 | 'max_chat_history_tokens': int, # Maximum tokens in chat history
99 | 'thread_id': str, # Unique thread identifier (auto-generated)
100 | 'supp_id': str, # Optional supplementary identifier
101 | })
102 | ```
103 |
104 | #### ChatResponseInput
105 |
106 | Input parameters for getting a chat response.
107 |
108 | ```python
109 | ChatResponseInput = TypedDict('ChatResponseInput', {
110 | 'user_input': str, # User's message text
111 | 'chat_thread_params': Optional[ChatThreadParams], # Optional parameter overrides
112 | 'metadata_filter': Optional[MetadataFilter], # Optional search filter
113 | })
114 | ```
115 |
116 | #### MetadataFilter
117 |
118 | Filter criteria for knowledge base searches.
119 |
120 | ```python
121 | MetadataFilter = TypedDict('MetadataFilter', {
122 | 'field_name': str, # Name of the metadata field to filter on
123 | 'field_value': Any, # Value to match against
124 | 'comparison_type': str, # Type of comparison ("equals", "contains", etc.)
125 | }, total=False)
126 | ```
127 |
128 | You can use these types to ensure type safety when working with the chat functions. For example:
129 |
130 | ```python
131 | # Create chat thread parameters
132 | params: ChatThreadParams = {
133 | "kb_ids": ["kb1"],
134 | "model": "gpt-4o-mini",
135 | "temperature": 0.2,
136 | "system_message": "You are a helpful assistant",
137 | "auto_query_model": "gpt-4o-mini",
138 | "auto_query_guidance": "",
139 | "target_output_length": "medium",
140 | "max_chat_history_tokens": 8000,
141 | "supp_id": ""
142 | }
143 |
144 | # Create chat response input
145 | response_input: ChatResponseInput = {
146 | "user_input": "What is the capital of France?",
147 | "chat_thread_params": None,
148 | "metadata_filter": None
149 | }
150 | ```
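151 | 
152 | ### Example Usage
153 | 
154 | A minimal end-to-end sketch that reuses the `params` and `response_input` dictionaries defined above. It assumes both functions are importable from `dsrag.chat.chat` and that a knowledge base with ID `"kb1"` has already been created and populated:
155 | 
156 | ```python
157 | from dsrag.chat.chat import create_new_chat_thread, get_chat_thread_response
158 | from dsrag.database.chat_thread.sqlite_db import SQLiteChatThreadDB
159 | from dsrag.knowledge_base import KnowledgeBase
160 | 
161 | # Chat thread storage and the knowledge bases to search (kb_id -> instance)
162 | chat_thread_db = SQLiteChatThreadDB()
163 | knowledge_bases = {"kb1": KnowledgeBase(kb_id="kb1")}  # assumes "kb1" already exists
164 | 
165 | # Create a thread, then get a (non-streaming) response
166 | thread_id = create_new_chat_thread(params, chat_thread_db)
167 | response = get_chat_thread_response(
168 |     thread_id=thread_id,
169 |     get_response_input=response_input,
170 |     chat_thread_db=chat_thread_db,
171 |     knowledge_bases=knowledge_bases,
172 | )
173 | 
174 | # The formatted interaction described above
175 | print(response["model_response"]["content"])
176 | ```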
--------------------------------------------------------------------------------
/docs/api/knowledge_base.md:
--------------------------------------------------------------------------------
1 | # KnowledgeBase
2 |
3 | The KnowledgeBase class is the main interface for working with dsRAG. It handles document processing, storage, and retrieval.
4 |
5 | ### Public Methods
6 |
7 | The following methods are part of the public API:
8 |
9 | - `__init__`: Initialize a new KnowledgeBase instance
10 | - `add_document`: Add a single document to the knowledge base
11 | - `add_documents`: Add multiple documents in parallel
12 | - `delete`: Delete the entire knowledge base and all associated data
13 | - `delete_document`: Delete a specific document from the knowledge base
14 | - `query`: Search the knowledge base with one or more queries
15 |
16 | ::: dsrag.knowledge_base.KnowledgeBase
17 | options:
18 | show_root_heading: false
19 | show_root_full_path: false
20 | members:
21 | - __init__
22 | - add_document
23 | - add_documents
24 | - delete
25 | - delete_document
26 | - query
27 |
28 | ## KB Components
29 |
30 | ### Vector Databases
31 |
32 | ::: dsrag.database.vector.VectorDB
33 | options:
34 | show_root_heading: true
35 | show_root_full_path: false
36 |
37 | ### Chunk Databases
38 |
39 | ::: dsrag.database.chunk.ChunkDB
40 | options:
41 | show_root_heading: true
42 | show_root_full_path: false
43 |
44 | ### Embedding Models
45 |
46 | ::: dsrag.embedding.Embedding
47 | options:
48 | show_root_heading: true
49 | show_root_full_path: false
50 |
51 | ### Rerankers
52 |
53 | ::: dsrag.reranker.Reranker
54 | options:
55 | show_root_heading: true
56 | show_root_full_path: false
57 |
58 | ### LLM Providers
59 |
60 | ::: dsrag.llm.LLM
61 | options:
62 | show_root_heading: true
63 | show_root_full_path: false
64 |
65 | ### File Systems
66 |
67 | ::: dsrag.dsparse.file_parsing.file_system.FileSystem
68 | options:
69 | show_root_heading: true
70 | show_root_full_path: false
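71 | 
72 | ## Example
73 | 
74 | A minimal sketch of the public methods listed above. The IDs and file path are placeholders, and the keyword arguments follow the patterns used elsewhere in these docs; see the generated reference above for the full signatures:
75 | 
76 | ```python
77 | from dsrag.knowledge_base import KnowledgeBase
78 | 
79 | kb = KnowledgeBase(kb_id="my_kb")
80 | 
81 | # Add a document from a file (placeholder path)
82 | kb.add_document(doc_id="doc1", file_path="path/to/document.pdf")
83 | 
84 | # Search with one or more queries
85 | results = kb.query(search_queries=["What are the key findings?"])
86 | 
87 | # Remove a single document, or delete the knowledge base entirely
88 | kb.delete_document(doc_id="doc1")
89 | kb.delete()
90 | ```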
--------------------------------------------------------------------------------
/docs/community/contributing.md:
--------------------------------------------------------------------------------
1 | # Contributing to dsRAG
2 |
3 | We welcome contributions from the community! Whether it's fixing bugs, improving documentation, or proposing new features, your contributions make dsRAG better for everyone.
4 |
5 | ## Getting Started
6 |
7 | 1. Fork the repository on GitHub
8 | 2. Clone your fork locally:
9 | ```bash
10 | git clone https://github.com/your-username/dsRAG.git
11 | cd dsRAG
12 | ```
13 | 3. Create a new branch for your changes:
14 | ```bash
15 | git checkout -b feature/your-feature-name
16 | ```
17 |
18 | ## Development Setup
19 |
20 | 1. Create a virtual environment:
21 | ```bash
22 | python -m venv venv
23 | source venv/bin/activate # On Windows: venv\Scripts\activate
24 | ```
25 |
26 | 2. Install development dependencies:
27 | ```bash
28 | pip install -e ".[dev]"
29 | ```
30 |
31 | ## Making Changes
32 |
33 | 1. Make your changes in your feature branch
34 | 2. Write or update tests as needed
35 | 3. Update documentation to reflect your changes
36 | 4. Run the test suite to ensure everything works:
37 | ```bash
38 | python -m unittest discover
39 | ```
40 |
41 | ## Code Style
42 |
43 | We follow standard Python coding conventions:
44 | - Use [PEP 8](https://peps.python.org/pep-0008/) style guide
45 | - Use meaningful variable and function names
46 | - Write docstrings for functions and classes
47 | - Keep functions focused and concise
48 | - Add comments for complex logic
49 |
50 | ## Documentation
51 |
52 | When adding or modifying features:
53 | - Update docstrings for any modified functions/classes
54 | - Update relevant documentation in the `docs/` directory
55 | - Add examples if appropriate
56 | - Ensure documentation builds without errors:
57 | ```bash
58 | mkdocs serve
59 | ```
60 |
61 | ## Testing
62 |
63 | For new features or bug fixes:
64 | - Add appropriate unit tests in the `tests/` directory
65 | - Tests should subclass `unittest.TestCase`
66 | - Follow the existing test file naming pattern: `test_*.py`
67 | - Update existing tests if needed
68 | - Ensure all tests pass locally
69 | - Maintain or improve code coverage
70 |
71 | Example test structure:
72 | ```python
73 | import unittest
74 | from dsrag import YourModule
75 |
76 | class TestYourFeature(unittest.TestCase):
77 | def setUp(self):
78 | # Set up any test fixtures
79 | pass
80 |
81 | def test_your_feature(self):
82 | # Test your new functionality
83 | result = YourModule.your_function()
84 | self.assertEqual(result, expected_value)
85 |
86 | def tearDown(self):
87 | # Clean up after tests
88 | pass
89 | ```
90 |
91 | ## Submitting Changes
92 |
93 | 1. Commit your changes:
94 | ```bash
95 | git add .
96 | git commit -m "Description of your changes"
97 | ```
98 |
99 | 2. Push to your fork:
100 | ```bash
101 | git push origin feature/your-feature-name
102 | ```
103 |
104 | 3. Open a Pull Request:
105 | - Go to the dsRAG repository on GitHub
106 | - Click "Pull Request"
107 | - Select your feature branch
108 | - Describe your changes and their purpose
109 | - Reference any related issues
110 |
111 | ## Pull Request Guidelines
112 |
113 | Your PR should:
114 |
115 | - Focus on a single feature or fix
116 | - Include appropriate tests
117 | - Update relevant documentation
118 | - Follow the code style guidelines
119 | - Include a clear description of the changes
120 | - Reference any related issues
121 |
122 | ## Getting Help
123 |
124 | If you need help with your contribution:
125 |
126 | - Join our [Discord](https://discord.gg/NTUVX9DmQ3)
127 | - Ask questions in the PR
128 | - Tag maintainers if you need specific guidance
129 |
130 | ## Code of Conduct
131 |
132 | - Be respectful and inclusive
133 | - Focus on constructive feedback
134 | - Help others learn and grow
135 | - Follow project maintainers' guidance
136 |
137 | Thank you for contributing to dsRAG!
--------------------------------------------------------------------------------
/docs/community/professional-services.md:
--------------------------------------------------------------------------------
1 | # Professional Services
2 |
3 | ## About Our Team
4 |
5 | The creators of dsRAG, Zach and Nick McCormick, run a specialized applied AI consulting firm. As former startup founders and YC alums, we bring a unique business and product-centric perspective to highly technical engineering projects.
6 |
7 | ## Our Expertise
8 |
9 | We specialize in:
10 |
11 | - Building high-performance RAG-based applications
12 | - Optimizing existing RAG systems
13 | - Integrating advanced retrieval techniques
14 | - Designing scalable AI architectures
15 |
16 | ## Services Offered
17 |
18 | We provide:
19 |
20 | - Advisory services
21 | - Implementation work
22 | - Performance optimization
23 | - Custom feature development
24 | - Technical architecture design
25 |
26 | ## Getting Started
27 |
28 | If you need professional assistance with:
29 |
30 | - Implementing dsRAG in your production environment
31 | - Optimizing your RAG-based application
32 | - Custom feature development
33 | - Technical consultation
34 |
35 | Please fill out our [contact form](https://forms.gle/zbQwDJp7pBQKtqVT8) and we'll be in touch to discuss how we can help with your specific needs.
--------------------------------------------------------------------------------
/docs/community/support.md:
--------------------------------------------------------------------------------
1 | # Community Support
2 |
3 | ## Discord Community
4 |
5 | Join our [Discord](https://discord.gg/NTUVX9DmQ3) to:
6 |
7 | - Ask questions about dsRAG
8 | - Make suggestions for improvements
9 | - Discuss potential contributions
10 | - Connect with other developers
11 | - Share your experiences
12 |
13 | ## Production Use Cases
14 |
15 | If you're using (or planning to use) dsRAG in production:
16 |
17 | 1. Please fill out our [use case form](https://forms.gle/RQ5qFVReonSHDcCu5)
18 |
19 | 2. This helps us:
20 | - Understand how dsRAG is being used
21 | - Prioritize new features
22 | - Improve documentation
23 | - Focus on the most important issues
24 |
25 | 3. In return, you'll receive:
26 | - Direct email contact with the creators
27 | - Priority email support
28 | - Early access to new features
--------------------------------------------------------------------------------
/docs/concepts/citations.md:
--------------------------------------------------------------------------------
1 | # Citations
2 |
3 | dsRAG's citation system ensures that responses are grounded in your knowledge base content and provides transparency about information sources. Citations are included automatically in chat responses.
4 |
5 | ## Overview
6 |
7 | Citations in dsRAG:
8 |
9 | - Track the source of information in responses
10 | - Include document IDs and page numbers
11 | - Provide exact quoted text from sources
12 |
13 | ## Citation Structure
14 |
15 | Each citation includes:
16 |
17 | - `doc_id`: Unique identifier of the source document
18 | - `page_number`: Page where the information was found (if available)
19 | - `cited_text`: Exact text containing the cited information
20 | - `kb_id`: ID of the knowledge base containing the document
21 |
22 | ## Working with Citations
23 |
24 | Citations are automatically included in chat responses:
25 |
26 | ```python
27 | from dsrag.chat.chat import get_chat_thread_response
28 | from dsrag.chat.chat_types import ChatResponseInput
29 | from dsrag.database.chat_thread.sqlite_db import SQLiteChatThreadDB
30 | from dsrag.knowledge_base import KnowledgeBase
31 | # Create the knowledge base instances
32 | knowledge_bases = {
33 | "my_knowledge_base": KnowledgeBase(kb_id="my_knowledge_base")
34 | }
35 |
36 | # Initialize database and get response
37 | chat_thread_db = SQLiteChatThreadDB()
38 | response = get_chat_thread_response(
39 | thread_id=thread_id,
40 | get_response_input=ChatResponseInput(
41 | user_input="What is the system architecture?"
42 | ),
43 | chat_thread_db=chat_thread_db,
44 | knowledge_bases=knowledge_bases
45 | )
46 |
47 | # Access citations
48 | citations = response["model_response"]["citations"]
49 | for citation in citations:
50 | print(f"""
51 | Source: {citation['doc_id']}
52 | Page: {citation['page_number']}
53 | Text: {citation['cited_text']}
54 | Knowledge Base: {citation['kb_id']}
55 | """)
56 | ```
57 |
58 | ## Technical Details
59 |
60 | Citations are managed through the `ResponseWithCitations` model:
61 |
62 | ```python
63 | from dsrag.chat.citations import ResponseWithCitations, Citation
64 |
65 | response = ResponseWithCitations(
66 | response="The system architecture is...",
67 | citations=[
68 | Citation(
69 | doc_id="arch-doc-v1",
70 | page_number=12,
71 | cited_text="The system uses a microservices architecture"
72 | )
73 | ]
74 | )
75 | ```
76 |
77 | This structured format ensures consistent citation handling throughout the system.
--------------------------------------------------------------------------------
/docs/concepts/components.md:
--------------------------------------------------------------------------------
1 | # Components
2 |
3 | There are six key components that define the configuration of a KnowledgeBase. Each component is customizable, with several built-in options available.
4 |
5 | ## VectorDB
6 |
7 | The VectorDB component stores embedding vectors and associated metadata.
8 |
9 | Available options:
10 |
11 | - `BasicVectorDB`
12 | - `WeaviateVectorDB`
13 | - `ChromaDB`
14 | - `QdrantVectorDB`
15 | - `MilvusDB`
16 | - `PineconeDB`
17 |
18 | ## ChunkDB
19 |
20 | The ChunkDB stores the content of text chunks in a nested dictionary format, keyed on `doc_id` and `chunk_index`. This is used by RSE to retrieve the full text associated with specific chunks.
21 |
22 | Available options:
23 |
24 | - `BasicChunkDB`
25 | - `SQLiteDB`
26 |
27 | ## Embedding
28 |
29 | The Embedding component defines the embedding model used for vectorizing text.
30 |
31 | Available options:
32 |
33 | - `OpenAIEmbedding`
34 | - `CohereEmbedding`
35 | - `VoyageAIEmbedding`
36 | - `OllamaEmbedding`
37 |
38 | ## Reranker
39 |
40 | The Reranker component provides more accurate ranking of chunks after vector database search and before RSE. This is optional but highly recommended.
41 |
42 | Available options:
43 |
44 | - `CohereReranker`
45 | - `VoyageReranker`
46 | - `NoReranker`
47 |
48 | ## LLM
49 |
50 | The LLM component is used in AutoContext for:
51 |
52 | - Document title generation
53 | - Document summarization
54 | - Section summarization
55 |
56 | Available options:
57 |
58 | - `OpenAIChatAPI`
59 | - `AnthropicChatAPI`
60 | - `OllamaChatAPI`
61 | - `GeminiAPI`
62 |
63 | ## FileSystem
64 |
65 | The FileSystem component defines where to save PDF images and extracted elements for VLM file parsing.
66 |
67 | Available options:
68 |
69 | - `LocalFileSystem`
70 | - `S3FileSystem`
71 |
72 | #### LocalFileSystem Configuration
73 | Only requires a `base_path` parameter to define where files will be stored on the system.
74 |
75 | #### S3FileSystem Configuration
76 | Requires the following parameters:
77 |
78 | - `base_path`: Used when downloading files from S3
79 | - `bucket_name`: S3 bucket name
80 | - `region_name`: AWS region
81 | - `access_key`: AWS access key
82 | - `access_secret`: AWS secret key
83 |
84 | Note: Even when using S3, files are temporarily stored locally for use in the retrieval system.
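85 | 
86 | ## Putting It Together
87 | 
88 | A sketch of how these components are passed to a `KnowledgeBase`. The constructor argument names follow the examples in the Knowledge Bases and Quick Start pages; the import locations for the embedding and reranker classes are assumed to mirror the Quick Start imports, and the no-argument constructors rely on each class's defaults:
89 | 
90 | ```python
91 | from dsrag.knowledge_base import KnowledgeBase
92 | from dsrag.llm import OpenAIChatAPI
93 | from dsrag.embedding import OpenAIEmbedding
94 | from dsrag.reranker import CohereReranker
95 | 
96 | kb = KnowledgeBase(
97 |     kb_id="my_kb",
98 |     embedding_model=OpenAIEmbedding(),   # Embedding component
99 |     reranker=CohereReranker(),           # Reranker component
100 |     auto_context_model=OpenAIChatAPI(),  # LLM component used by AutoContext
101 | )
102 | ```
103 | 
104 | The VectorDB, ChunkDB, and FileSystem components can be swapped in the same way via their corresponding constructor arguments.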
--------------------------------------------------------------------------------
/docs/concepts/config.md:
--------------------------------------------------------------------------------
1 | # Configuration
2 |
3 | dsRAG uses several configuration dictionaries to organize its many parameters. These configs can be passed to different methods of the KnowledgeBase class.
4 |
5 | ## Document Addition Configs
6 |
7 | The following configuration dictionaries can be passed to the `add_document` method:
8 |
9 | ### AutoContext Config
10 |
11 | ```python
12 | auto_context_config = {
13 | "use_generated_title": bool, # whether to use an LLM-generated title if no title is provided (default: True)
14 | "document_title_guidance": str, # Additional guidance for generating the document title (default: "")
15 | "get_document_summary": bool, # whether to get a document summary (default: True)
16 | "document_summarization_guidance": str, # Additional guidance for summarizing the document (default: "")
17 | "get_section_summaries": bool, # whether to get section summaries (default: False)
18 | "section_summarization_guidance": str # Additional guidance for summarizing the sections (default: "")
19 | }
20 | ```
21 |
22 | ### File Parsing Config
23 |
24 | ```python
25 | file_parsing_config = {
26 | "use_vlm": bool, # whether to use VLM for parsing the file (default: False)
27 | "vlm_config": {
28 | "provider": str, # the VLM provider to use (default: "gemini", "vertex_ai" is also supported)
29 | "model": str, # the VLM model to use (default: "gemini-2.0-flash")
30 | "project_id": str, # GCP project ID (only required for "vertex_ai")
31 | "location": str, # GCP location (only required for "vertex_ai")
32 | "save_path": str, # path to save intermediate files during VLM processing
33 | "exclude_elements": list, # element types to exclude (default: ["Header", "Footer"])
34 | "images_already_exist": bool # whether images are pre-extracted (default: False)
35 | },
36 | "always_save_page_images": bool # save page images even if VLM is not used (default: False)
37 | }
38 | ```
39 |
40 | ### Semantic Sectioning Config
41 |
42 | ```python
43 | semantic_sectioning_config = {
44 | "llm_provider": str, # LLM provider (default: "openai", "anthropic" and "gemini" are also supported)
45 | "model": str, # LLM model to use (default: "gpt-4o-mini")
46 | "use_semantic_sectioning": bool # if False, skip semantic sectioning (default: True)
47 | }
48 | ```
49 |
50 | ### Chunking Config
51 |
52 | ```python
53 | chunking_config = {
54 | "chunk_size": int, # maximum characters per chunk (default: 800)
55 | "min_length_for_chunking": int # minimum text length to allow chunking (default: 2000)
56 | }
57 | ```
58 |
59 | ## Query Config
60 |
61 | The following configuration dictionary can be passed to the `query` method:
62 |
63 | ### RSE Parameters
64 |
65 | ```python
66 | rse_params = {
67 | "max_length": int, # maximum segment length in chunks (default: 15)
68 | "overall_max_length": int, # maximum total length of all segments (default: 30)
69 | "minimum_value": float, # minimum relevance value for segments (default: 0.5)
70 | "irrelevant_chunk_penalty": float, # penalty for irrelevant chunks (0-1) (default: 0.18)
71 | "overall_max_length_extension": int, # length increase per additional query (default: 5)
72 | "decay_rate": float, # rate at which relevance decays (default: 30)
73 | "top_k_for_document_selection": int, # maximum number of documents to consider (default: 10)
74 | "chunk_length_adjustment": bool # whether to scale by chunk length (default: True)
75 | }
76 | ```
77 |
78 | ## Metadata Query Filters
79 |
80 | Some vector databases (currently only ChromaDB) support metadata filtering during queries. This allows for more controlled document selection.
81 |
82 | Example metadata filter format:
83 | ```python
84 | metadata_filter = {
85 | "field": str, # The metadata field to filter by
86 | "operator": str, # One of: 'equals', 'not_equals', 'in', 'not_in',
87 | # 'greater_than', 'less_than', 'greater_than_equals', 'less_than_equals'
88 | "value": str | int | float | list # If list, all items must be same type
89 | }
90 | ```
91 |
92 | Example usage:
93 | ```python
94 | # Filter with "equals" operator
95 | metadata_filter = {
96 | "field": "doc_id",
97 | "operator": "equals",
98 | "value": "test_id_1"
99 | }
100 |
101 | # Filter with "in" operator
102 | metadata_filter = {
103 | "field": "doc_id",
104 | "operator": "in",
105 | "value": ["test_id_1", "test_id_2"]
106 | }
107 | ```
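108 | 
109 | ## Passing the Configs
110 | 
111 | A brief sketch of where each configuration dictionary goes, assuming `kb` is an existing `KnowledgeBase` and the dictionaries are built as shown above (the document ID and file path are placeholders):
112 | 
113 | ```python
114 | # Document addition configs are passed to add_document
115 | kb.add_document(
116 |     doc_id="my_document",
117 |     file_path="path/to/document.pdf",
118 |     auto_context_config=auto_context_config,
119 |     file_parsing_config=file_parsing_config,
120 |     semantic_sectioning_config=semantic_sectioning_config,
121 |     chunking_config=chunking_config,
122 | )
123 | 
124 | # Query-time configs are passed to query
125 | results = kb.query(
126 |     search_queries=["example query"],
127 |     rse_params=rse_params,
128 |     metadata_filter=metadata_filter,
129 | )
130 | ```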
--------------------------------------------------------------------------------
/docs/concepts/knowledge-bases.md:
--------------------------------------------------------------------------------
1 | # Knowledge Bases
2 |
3 | A knowledge base in dsRAG is a searchable collection of documents that can be queried to find relevant information. The `KnowledgeBase` class handles document processing, storage, and retrieval.
4 |
5 | ## Creating a Knowledge Base
6 |
7 | To create a knowledge base:
8 |
9 | ```python
10 | from dsrag.knowledge_base import KnowledgeBase
11 |
12 | # Create a basic knowledge base
13 | kb = KnowledgeBase(
14 | kb_id="my_kb",
15 | title="Product Documentation",
16 | description="Technical documentation for XYZ product"
17 | )
18 |
19 | # Or with custom configuration
20 | kb = KnowledgeBase(
21 | kb_id="my_kb",
22 | storage_directory="path/to/storage", # Where to store KB data
23 | embedding_model=custom_embedder, # Custom embedding model
24 | reranker=custom_reranker, # Custom reranking model
25 | vector_db=custom_vector_db, # Custom vector database
26 | chunk_db=custom_chunk_db # Custom chunk database
27 | )
28 | ```
29 |
30 | ## Adding Documents
31 |
32 | Documents can be added from text or files:
33 |
34 | ```python
35 | # Add from text
36 | kb.add_document(
37 | doc_id="intro-guide",
38 | text="This is the introduction guide...",
39 | document_title="Introduction Guide",
40 | metadata={"type": "guide", "version": "1.0"}
41 | )
42 |
43 | # Add from file
44 | kb.add_document(
45 | doc_id="user-manual",
46 | file_path="path/to/manual.pdf",
47 | metadata={"type": "manual", "department": "engineering"}
48 | )
49 |
50 | # Add with advanced configuration
51 | kb.add_document(
52 | doc_id="technical-spec",
53 | file_path="path/to/spec.pdf",
54 | file_parsing_config={
55 | "use_vlm": True, # Use vision language model for PDFs
56 | "always_save_page_images": True # Save page images for visual content
57 | },
58 | chunking_config={
59 | "chunk_size": 800, # Characters per chunk
60 | "min_length_for_chunking": 2000 # Minimum length to chunk
61 | },
62 | auto_context_config={
63 | "use_generated_title": True, # Generate title if not provided
64 | "get_document_summary": True # Generate document summary
65 | }
66 | )
67 | ```
68 |
69 | ## Querying the Knowledge Base
70 |
71 | Search the knowledge base for relevant information:
72 |
73 | ```python
74 | # Simple query
75 | results = kb.query(
76 | search_queries=["How to configure the system?"]
77 | )
78 |
79 | # Advanced query with filtering and parameters
80 | results = kb.query(
81 | search_queries=[
82 | "System configuration steps",
83 | "Configuration prerequisites"
84 | ],
85 | metadata_filter={
86 | "field": "doc_id",
87 | "operator": "equals",
88 | "value": "user_manual"
89 | },
90 | rse_params="precise", # Use preset RSE parameters
91 | return_mode="text" # Return text content
92 | )
93 |
94 | # Process results
95 | for segment in results:
96 | print(f"""
97 | Document: {segment['doc_id']}
98 | Pages: {segment['segment_page_start']} - {segment['segment_page_end']}
99 | Content: {segment['content']}
100 | Relevance: {segment['score']}
101 | """)
102 | ```
103 |
104 | ## RSE Parameters
105 |
106 | The Relevant Segment Extraction (RSE) system can be tuned using different parameter presets:
107 | - `"balanced"`: Default preset balancing precision and comprehensiveness
108 | - `"precise"`: Favors shorter, more focused segments
109 | - `"comprehensive"`: Returns longer segments with more context
110 |
111 | Or configure custom RSE parameters:
112 | ```python
113 | results = kb.query(
114 | search_queries=["system requirements"],
115 | rse_params={
116 | "max_length": 5, # Max segments length (in number of chunks)
117 | "overall_max_length": 20, # Total length limit across all segments (in number of chunks)
118 | "minimum_value": 0.5, # Minimum relevance score
119 | "irrelevant_chunk_penalty": 0.2 # Penalty for irrelevant chunks in a segment - higher penalty leads to shorter segments
120 | }
121 | )
122 | ```
123 |
124 | ## Metadata Query Filters
125 |
126 | Certain vector DBs support metadata filtering when running a query (currently only ChromaDB). This allows you to have more control over what document(s) get searched. A common use case would be asking questions about a single document in a knowledge base, in which case you would supply the `doc_id` as a metadata filter.
127 |
128 | The metadata filter should be a dictionary with the following structure:
129 |
130 | ```python
131 | metadata_filter = {
132 | "field": "doc_id", # The metadata field to filter on
133 | "operator": "equals", # The comparison operator
134 | "value": "doc123" # The value to compare against
135 | }
136 | ```
137 |
138 | Supported operators:
139 | - `equals`
140 | - `not_equals`
141 | - `in`
142 | - `not_in`
143 | - `greater_than`
144 | - `less_than`
145 | - `greater_than_equals`
146 | - `less_than_equals`
147 |
148 | For operators that take multiple values (`in` and `not_in`), the value should be a list where all items are of the same type (string, integer, or float).
149 |
150 | Example usage:
151 | ```python
152 | # Query a specific document
153 | results = kb.query(
154 | search_queries=["system requirements"],
155 | metadata_filter={
156 | "field": "doc_id",
157 | "operator": "equals",
158 | "value": "technical_spec_v1"
159 | }
160 | )
161 |
162 | # Query documents from multiple departments
163 | results = kb.query(
164 | search_queries=["security protocols"],
165 | metadata_filter={
166 | "field": "department",
167 | "operator": "in",
168 | "value": ["security", "compliance"]
169 | }
170 | )
171 | ```
--------------------------------------------------------------------------------
/docs/concepts/overview.md:
--------------------------------------------------------------------------------
1 | # Architecture Overview
2 |
3 | dsRAG is built around three key methods that improve performance over vanilla RAG systems:
4 |
5 | 1. Semantic sectioning
6 | 2. AutoContext
7 | 3. Relevant Segment Extraction (RSE)
8 |
9 | ## Key Methods
10 |
11 | ### Semantic Sectioning
12 |
13 | Semantic sectioning uses an LLM to break a document into cohesive sections. The process works as follows:
14 |
15 | 1. The document is annotated with line numbers
16 | 2. An LLM identifies the starting and ending lines for each "semantically cohesive section"
17 | 3. Sections typically range from a few paragraphs to a few pages long
18 | 4. Sections are broken into smaller chunks if needed
19 | 5. The LLM generates descriptive titles for each section
20 | 6. Section titles are used in contextual chunk headers created by AutoContext
21 |
22 | This process provides additional context to the ranking models (embeddings and reranker), enabling better retrieval.
23 |
24 | ### AutoContext (Contextual Chunk Headers)
25 |
26 | AutoContext creates contextual chunk headers that contain:
27 |
28 | - Document-level context
29 | - Section-level context
30 |
31 | These headers are prepended to chunks before embedding. Benefits include:
32 |
33 | - More accurate and complete representation of text content and meaning
34 | - Dramatic improvement in retrieval quality
35 | - Reduced rate of irrelevant results
36 | - Reduced LLM misinterpretation in downstream applications
37 |
38 | ### Relevant Segment Extraction (RSE)
39 |
40 | RSE is a query-time post-processing step that:
41 |
42 | 1. Takes clusters of relevant chunks
43 | 2. Intelligently combines them into longer sections (segments)
44 | 3. Provides better context to the LLM than individual chunks
45 |
46 | RSE is particularly effective for:
47 |
48 | - Complex questions where answers span multiple chunks
49 | - Adapting the context length based on query type
50 | - Maintaining coherent context while avoiding irrelevant information
51 |
52 | ## Document Processing Flow
53 |
54 | 1. Documents → VLM file parsing
55 | 2. Semantic sectioning
56 | 3. Chunking
57 | 4. AutoContext
58 | 5. Embedding
59 | 6. Chunk and vector database upsert
60 |
61 | ## Query Processing Flow
62 |
63 | 1. Queries
64 | 2. Vector database search
65 | 3. Reranking
66 | 4. RSE
67 | 5. Results
68 |
69 | For more detailed information about specific components and configuration options, please refer to:
70 | - [Components Documentation](components.md)
71 | - [Configuration Options](config.md)
72 | - [Knowledge Base Details](knowledge-bases.md)
--------------------------------------------------------------------------------
/docs/concepts/vlm.md:
--------------------------------------------------------------------------------
1 | # VLM File Parsing
2 |
3 | dsRAG supports Vision Language Model (VLM) integration for enhanced PDF parsing capabilities. This feature is particularly useful for documents with complex layouts, tables, diagrams, or other visual elements that traditional text extraction might miss.
4 |
5 | ## Overview
6 |
7 | When VLM parsing is enabled:
8 |
9 | 1. The PDF is converted to images (one per page)
10 | 2. Each page image is analyzed by the VLM to identify and extract text and visual elements
11 | 3. The extracted elements are converted to a structured format with page numbers preserved
12 |
13 | ## Supported Providers
14 |
15 | Currently supported VLM providers:
16 |
17 | - Google's Gemini (default)
18 | - Google Cloud Vertex AI
19 |
20 | ## Configuration
21 |
22 | To use VLM for file parsing, configure the `file_parsing_config` when adding documents:
23 |
24 | ```python
25 | file_parsing_config = {
26 | "use_vlm": True, # Enable VLM parsing
27 | "vlm_config": {
28 | # VLM Provider (required)
29 | "provider": "gemini", # or "vertex_ai"
30 |
31 | # Model name (optional - defaults based on provider)
32 | "model": "gemini-2.0-flash", # default for Gemini
33 |
34 | # For Vertex AI, additional required fields:
35 | "project_id": "your-gcp-project-id",
36 | "location": "us-central1",
37 |
38 | # Elements to exclude from parsing (optional)
39 | "exclude_elements": ["Header", "Footer"],
40 |
41 | # Whether images are pre-extracted (optional)
42 | "images_already_exist": False
43 | },
44 |
45 | # Save page images even if VLM unused (optional)
46 | "always_save_page_images": False
47 | }
48 |
49 | # Add document with VLM parsing (must be a PDF file)
50 | kb.add_document(
51 | doc_id="my_document",
52 | file_path="path/to/document.pdf", # Only PDF files are supported with VLM
53 | file_parsing_config=file_parsing_config
54 | )
55 | ```
56 |
57 | ## Element Types
58 |
59 | The VLM is prompted to categorize page content into the following element types:
60 |
61 | Text Elements:
62 |
63 | - NarrativeText: Main text content including paragraphs, lists, and titles
64 | - Header: Page header content (typically at top of page)
65 | - Footnote: References or notes at bottom of content
66 | - Footer: Page footer content (at very bottom of page)
67 |
68 | Visual Elements:
69 |
70 | - Figure: Charts, graphs, and diagrams with associated titles and legends
71 | - Image: Photos, illustrations, and other visual content
72 | - Table: Tabular data arrangements with titles and captions
73 | - Equation: Mathematical formulas and expressions
74 |
75 | By default, Header and Footer elements are excluded from parsing as they rarely contain valuable information and can break the flow between pages. You can modify which elements to exclude using the `exclude_elements` configuration option.
76 |
77 | ## Best Practices
78 |
79 | 1. Use VLM parsing for documents with:
80 | - Complex layouts
81 | - Important visual elements
82 | - Tables that need precise formatting
83 | - Mathematical formulas
84 | - Scanned documents that need OCR
85 |
86 | 2. Consider traditional parsing for:
87 | - Simple text documents
88 | - Documents where visual layout isn't critical
89 | - Large volumes of documents (VLM parsing is slower and more expensive)
90 |
91 | 3. Configure `exclude_elements` to ignore irrelevant elements like headers/footers
--------------------------------------------------------------------------------
/docs/getting-started/installation.md:
--------------------------------------------------------------------------------
1 | # Installation
2 |
3 | ## Installing dsRAG
4 |
5 | Install the package using pip:
6 |
7 | ```bash
8 | pip install dsrag
9 | ```
10 |
11 | If you want to use VLM file parsing, you will need one non-Python dependency: poppler, which is used to convert PDFs to images. On macOS, you can install it with Homebrew (on Debian/Ubuntu, install the `poppler-utils` package instead):
12 |
13 | ```bash
14 | brew install poppler
15 | ```
16 |
17 | ## Third-party dependencies
18 |
19 | To get the most out of dsRAG, you'll need a few third-party API keys set as environment variables. Here are the most important ones:
20 |
21 | - `OPENAI_API_KEY`: Recommended for embeddings, AutoContext, and semantic sectioning
22 | - `CO_API_KEY`: Recommended for reranking
23 | - `GEMINI_API_KEY`: Recommended for VLM file parsing
--------------------------------------------------------------------------------
/docs/getting-started/quickstart.md:
--------------------------------------------------------------------------------
1 | # Quick Start Guide
2 |
3 | This guide will help you get started with dsRAG quickly. We'll cover the basics of creating a knowledge base and querying it.
4 |
5 | ## Setting Up
6 |
7 | First, make sure you have the necessary API keys set as environment variables:
8 | - `OPENAI_API_KEY` for embeddings and AutoContext
9 | - `CO_API_KEY` for reranking with Cohere
10 |
11 | If you don't have both of these available, you can use the OpenAI-only configuration or local-only configuration below.
12 |
13 | ## Creating a Knowledge Base
14 |
15 | Create a new knowledge base and add documents to it:
16 |
17 | ```python
18 | from dsrag.knowledge_base import KnowledgeBase
19 |
20 | # Create a knowledge base
21 | kb = KnowledgeBase(kb_id="my_knowledge_base")
22 |
23 | # Add documents
24 | kb.add_document(
25 | doc_id="user_manual", # Use a meaningful ID if possible
26 | file_path="path/to/your/document.pdf",
27 | document_title="User Manual", # Optional but recommended
28 | metadata={"type": "manual"} # Optional metadata
29 | )
30 | ```
31 |
32 | The KnowledgeBase object persists to disk automatically, so you don't need to explicitly save it.
33 |
34 | ## Querying the Knowledge Base
35 |
36 | Once you have created a knowledge base, you can load it by its `kb_id` and query it:
37 |
38 | ```python
39 | from dsrag.knowledge_base import KnowledgeBase
40 |
41 | # Load the knowledge base
42 | kb = KnowledgeBase("my_knowledge_base")
43 |
44 | # You can query with multiple search queries
45 | search_queries = [
46 | "What are the main topics covered?",
47 | "What are the key findings?"
48 | ]
49 |
50 | # Get results
51 | results = kb.query(search_queries)
52 | for segment in results:
53 | print(segment)
54 | ```
55 |
56 | ## Using OpenAI-Only Configuration
57 |
58 | If you prefer to use only OpenAI services (without Cohere), you can customize the configuration:
59 |
60 | ```python
61 | from dsrag.llm import OpenAIChatAPI
62 | from dsrag.reranker import NoReranker
63 |
64 | # Configure components
65 | llm = OpenAIChatAPI(model='gpt-4o-mini')
66 | reranker = NoReranker()
67 |
68 | # Create knowledge base with custom configuration
69 | kb = KnowledgeBase(
70 | kb_id="my_knowledge_base",
71 | reranker=reranker,
72 | auto_context_model=llm
73 | )
74 |
75 | # Add documents
76 | kb.add_document(
77 | doc_id="user_manual", # Use a meaningful ID
78 | file_path="path/to/your/document.pdf",
79 | document_title="User Manual", # Optional but recommended
80 | metadata={"type": "manual"} # Optional metadata
81 | )
82 | ```
83 |
84 | ## Local-only configuration
85 |
86 | If you don't want to use any third-party services, you can configure dsRAG to run fully locally using Ollama. There are a few limitations:
87 | - You will not be able to use VLM file parsing
88 | - You will not be able to use semantic sectioning
89 | - You will not be able to use a reranker, unless you implement your own custom Reranker class
90 |
91 | ```python
92 | from dsrag.llm import OllamaChatAPI
93 | from dsrag.reranker import NoReranker
94 | from dsrag.embedding import OllamaEmbedding
95 |
96 | llm = OllamaChatAPI(model="llama3.1:8b")
97 | reranker = NoReranker()
98 | embedding = OllamaEmbedding(model="nomic-embed-text")
99 |
100 | kb = KnowledgeBase(
101 | kb_id="my_knowledge_base",
102 | reranker=reranker,
103 | auto_context_model=llm,
104 | embedding_model=embedding
105 | )
106 |
107 | # Disable semantic sectioning
108 | semantic_sectioning_config = {
109 | "use_semantic_sectioning": False,
110 | }
111 |
112 | # Add documents
113 | kb.add_document(
114 | doc_id="user_manual",
115 | file_path="path/to/your/document.pdf",
116 | document_title="User Manual",
117 | semantic_sectioning_config=semantic_sectioning_config
118 | )
119 | ```
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | # Welcome to dsRAG
2 |
3 | dsRAG is a retrieval engine for unstructured data. It is especially good at handling challenging queries over dense text, like financial reports, legal documents, and academic papers. dsRAG achieves substantially higher accuracy than vanilla RAG baselines on complex open-book question answering tasks. On one especially challenging benchmark, [FinanceBench](https://arxiv.org/abs/2311.11944), dsRAG gets accurate answers 96.6% of the time, compared to the vanilla RAG baseline which only gets 32% of questions correct.
4 |
5 | ## Key Methods
6 |
7 | dsRAG uses three key methods to improve performance over vanilla RAG systems:
8 |
9 | ### 1. Semantic Sectioning
10 | Semantic sectioning uses an LLM to break a document into sections. It works by annotating the document with line numbers and then prompting an LLM to identify the starting and ending lines for each "semantically cohesive section." These sections should be anywhere from a few paragraphs to a few pages long. The sections then get broken into smaller chunks if needed. The LLM is also prompted to generate descriptive titles for each section. These section titles get used in the contextual chunk headers created by AutoContext, which provides additional context to the ranking models (embeddings and reranker), enabling better retrieval.
11 |
12 | ### 2. AutoContext
13 | AutoContext creates contextual chunk headers that contain document-level and section-level context, and prepends those chunk headers to the chunks prior to embedding them. This gives the embeddings a much more accurate and complete representation of the content and meaning of the text. In our testing, this feature leads to a dramatic improvement in retrieval quality. In addition to increasing the rate at which the correct information is retrieved, AutoContext also substantially reduces the rate at which irrelevant results show up in the search results. This reduces the rate at which the LLM misinterprets a piece of text in downstream chat and generation applications.
14 |
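AutoContext is driven by the `auto_context_model` set on the knowledge base. A minimal sketch reusing classes shown elsewhere in the docs (the Ollama model is just one example; any supported chat model class can be passed):

```python
from dsrag.llm import OllamaChatAPI
from dsrag.knowledge_base import KnowledgeBase

# The auto_context_model is the LLM used when building contextual chunk headers
llm = OllamaChatAPI(model="llama3.1:8b")
kb = KnowledgeBase(kb_id="my_knowledge_base", auto_context_model=llm)

kb.add_document(
    doc_id="user_manual",
    file_path="path/to/your/document.pdf",
    document_title="User Manual",  # optional but recommended
)
```
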
15 | ### 3. Relevant Segment Extraction (RSE)
16 | Relevant Segment Extraction (RSE) is a query-time post-processing step that takes clusters of relevant chunks and intelligently combines them into longer sections of text that we call segments. These segments provide better context to the LLM than any individual chunk can. For simple factual questions, the answer is usually contained in a single chunk; but for more complex questions, the answer usually spans a longer section of text. The goal of RSE is to intelligently identify the section(s) of text that provide the most relevant information, without being constrained to fixed length chunks.
17 |
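RSE runs as a query-time post-processing step, so there is nothing extra to do at ingestion time. One place it can be tuned is the optional `rse_params` field of `ChatThreadParams` used by the chat interface (a minimal sketch; the individual `rse_params` keys are not listed here):

```python
from dsrag.chat import ChatThreadParams

# Partial chat thread configuration; rse_params tunes Relevant Segment Extraction
chat_thread_params: ChatThreadParams = {
    "kb_ids": ["my_knowledge_base"],
    "rse_params": {},  # placeholder: specific RSE parameters would go here
}
```
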
18 | ## Quick Example
19 |
20 | ```python
21 | from dsrag.create_kb import create_kb_from_file
22 |
23 | # Create a knowledge base from a file
24 | file_path = "path/to/your/document.pdf"
25 | kb_id = "my_knowledge_base"
26 | kb = create_kb_from_file(kb_id, file_path)
27 |
28 | # Query the knowledge base
29 | search_queries = ["What are the main topics covered?"]
30 | results = kb.query(search_queries)
31 | for segment in results:
32 | print(segment)
33 | ```
34 |
35 | ## Evaluation Results
36 |
37 | ### FinanceBench
38 | On [FinanceBench](https://arxiv.org/abs/2311.11944), which uses a corpus of hundreds of 10-Ks and 10-Qs with challenging queries that often require combining multiple pieces of information:
39 |
40 | - Baseline retrieval pipeline: 32% accuracy
41 | - dsRAG (with default parameters and Claude 3.5 Sonnet): 96.6% accuracy
42 |
43 | ### KITE Benchmark
44 | On the [KITE](https://github.com/D-Star-AI/KITE) benchmark, which includes diverse datasets (AI papers, company 10-Ks, company handbooks, and Supreme Court opinions), dsRAG shows significant improvements:
45 |
46 | | Dataset                 | Top-k    | RSE    | CCH+Top-k    | CCH+RSE    |
47 | |-------------------------|----------|--------|--------------|------------|
48 | | AI Papers | 4.5 | 7.9 | 4.7 | 7.9 |
49 | | BVP Cloud | 2.6 | 4.4 | 6.3 | 7.8 |
50 | | Sourcegraph | 5.7 | 6.6 | 5.8 | 9.4 |
51 | | Supreme Court Opinions | 6.1 | 8.0 | 7.4 | 8.5 |
52 | | **Average** | 4.72 | 6.73 | 6.04 | 8.42 |
53 |
54 | ## Getting Started
55 |
56 | Check out our [Quick Start Guide](getting-started/quickstart.md) to begin using dsRAG in your projects.
57 |
58 | ## Community and Support
59 |
60 | - Join our [Discord](https://discord.gg/NTUVX9DmQ3) for community support
61 | - Fill out our [use case form](https://forms.gle/RQ5qFVReonSHDcCu5) if using dsRAG in production
62 | - Need professional help? Contact our [team](https://forms.gle/zbQwDJp7pBQKtqVT8)
--------------------------------------------------------------------------------
/docs/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/D-Star-AI/dsRAG/b326c1d0930c9a80ff486ab0df5e771e788e3234/docs/logo.png
--------------------------------------------------------------------------------
/dsrag/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | # Configure the root dsrag logger with a NullHandler to prevent "No handler found" warnings
4 | # This follows Python best practices for library logging
5 | # Users will need to configure their own handlers if they want to see dsrag logs
6 | logger = logging.getLogger("dsrag")
7 | logger.addHandler(logging.NullHandler())
--------------------------------------------------------------------------------
/dsrag/auto_query.py:
--------------------------------------------------------------------------------
1 | # NOTE: This is a legacy file and is not used in the current implementation. Will be removed in the future.
2 |
3 | import os
4 | from dsrag.utils.imports import instructor
5 | from anthropic import Anthropic
6 | from pydantic import BaseModel
7 | from typing import List
8 |
9 | SYSTEM_MESSAGE = """
10 | You are a query generation system. Please generate one or more search queries (up to a maximum of {max_queries}) based on the provided user input. DO NOT generate the answer, just queries.
11 |
12 | Each of the queries you generate will be used to search a knowledge base for information that can be used to respond to the user input. Make sure each query is specific enough to return relevant information. If multiple pieces of information would be useful, you should generate multiple queries, one for each specific piece of information needed.
13 |
14 | {auto_query_guidance}
15 | """.strip()
16 |
17 |
18 | def get_search_queries(user_input: str, auto_query_guidance: str = "", max_queries: int = 5):
19 | base_url = os.environ.get("DSRAG_ANTHROPIC_BASE_URL", None)
20 | if base_url is not None:
21 | client = instructor.from_anthropic(Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"], base_url=base_url))
22 | else:
23 | client = instructor.from_anthropic(Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"]))
24 |
25 | class Queries(BaseModel):
26 | queries: List[str]
27 |
28 | resp = client.messages.create(
29 | model="claude-3-5-sonnet-20241022",
30 | max_tokens=400,
31 | temperature=0.2,
32 | system=SYSTEM_MESSAGE.format(max_queries=max_queries, auto_query_guidance=auto_query_guidance),
33 | messages=[
34 | {
35 | "role": "user",
36 | "content": user_input
37 | }
38 | ],
39 | response_model=Queries,
40 | )
41 |
42 | return resp.queries[:max_queries]
--------------------------------------------------------------------------------
/dsrag/chat/__init__.py:
--------------------------------------------------------------------------------
1 | from dsrag.chat.chat import create_new_chat_thread, get_chat_thread_response
2 | from dsrag.chat.chat_types import ChatThreadParams, ChatResponseInput
3 | from dsrag.chat.citations import ResponseWithCitations, Citation
4 |
5 | __all__ = [
6 | 'create_new_chat_thread',
7 | 'get_chat_thread_response',
8 | 'ChatThreadParams',
9 | 'ChatResponseInput',
10 | 'ResponseWithCitations',
11 | 'Citation'
12 | ]
--------------------------------------------------------------------------------
/dsrag/chat/auto_query.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pydantic import BaseModel
3 | from typing import List
4 | from dsrag.chat.instructor_get_response import get_response
5 |
6 | SYSTEM_MESSAGE = """
7 | You are a query generation system. Please generate one or more search queries (up to a maximum of {max_queries}) based on the provided user input. DO NOT generate the answer, just queries. You must specify the knowledge base you want to use for each query by providing the knowledge_base_id.
8 |
9 | Each of the queries you generate will be used to search a knowledge base for information that can be used to respond to the user input. Make sure each query is specific enough to return relevant information. If multiple pieces of information would be useful, you should generate multiple queries, one for each specific piece of information needed.
10 |
11 | {auto_query_guidance}
12 |
13 | Here are the knowledge bases you can search:
14 | {knowledge_base_descriptions}
15 | """.strip()
16 |
17 | class Query(BaseModel):
18 | query: str
19 | knowledge_base_id: str
20 |
21 | class Queries(BaseModel):
22 | queries: List[Query]
23 |
24 |
25 | def get_knowledge_base_descriptions_str(kb_info: list[dict]):
26 | kb_descriptions = []
27 | for kb in kb_info:
28 | kb_descriptions.append(f"kb_id: {kb['id']}\ndescription: {kb['title']} - {kb['description']}")
29 | return "\n\n".join(kb_descriptions)
30 |
31 | def validate_queries(queries: List[Query], kb_info: list[dict]) -> List[dict]:
32 | """
33 | Validates and potentially modifies search queries based on knowledge base availability.
34 |
35 | Args:
36 | queries: List of Query objects to validate
37 | kb_info: List of available knowledge base information
38 |
39 | Returns:
40 | List of validated query dictionaries
41 | """
42 | valid_kb_ids = {kb["id"] for kb in kb_info}
43 | validated_queries = []
44 |
45 | for query in queries:
46 | if query.knowledge_base_id in valid_kb_ids:
47 | validated_queries.append({"query": query.query, "kb_id": query.knowledge_base_id})
48 | else:
49 | # If invalid KB ID is found
50 | if len(kb_info) == 1:
51 | # If only one KB exists, use that
52 | validated_queries.append({"query": query.query, "kb_id": kb_info[0]["id"]})
53 | else:
54 | # If multiple KBs exist, create a query for each KB
55 | for kb in kb_info:
56 | validated_queries.append({"query": query.query, "kb_id": kb["id"]})
57 |
58 | return validated_queries
59 |
60 | def get_search_queries(chat_messages: list[dict], kb_info: list[dict], auto_query_guidance: str = "", max_queries: int = 5, auto_query_model: str = "gpt-4o-mini") -> List[dict]:
61 | """
62 | Input:
63 | - chat_messages: list of dictionaries, where each dictionary has the keys "role" and "content". This should include the current user input as the final message.
64 | - kb_info: list of dictionaries, where each dictionary has the keys "id", "title", and "description". This should include information about the available knowledge bases.
65 | - auto_query_guidance: str, optional additional instructions for the auto_query system
66 | - max_queries: int, maximum number of queries to generate
67 | - auto_query_model: str, the model to use for generating queries
68 |
69 | Returns
70 | - queries: list of dictionaries, where each dictionary has the keys "query" and "knowledge_base_id"
71 | """
72 | knowledge_base_descriptions = get_knowledge_base_descriptions_str(kb_info)
73 |
74 | system_message = SYSTEM_MESSAGE.format(
75 | max_queries=max_queries,
76 | auto_query_guidance=auto_query_guidance,
77 | knowledge_base_descriptions=knowledge_base_descriptions
78 | )
79 |
80 | messages = [{"role": "system", "content": system_message}] + chat_messages
81 |
82 | # Try with initial model
83 | try:
84 | queries = get_response(
85 | messages=messages,
86 | model_name=auto_query_model,
87 | response_model=Queries,
88 | max_tokens=600,
89 | temperature=0.0
90 | )
91 | except:
92 | # Fallback to a stronger model
93 | fallback_model = "claude-3-5-sonnet-20241022" if "gpt" in auto_query_model else "gpt-4o"
94 | queries = get_response(
95 | messages=messages,
96 | model_name=fallback_model,
97 | response_model=Queries,
98 | max_tokens=600,
99 | temperature=0.0
100 | )
101 |
102 | # Validate and potentially modify the queries
103 | validated_queries = validate_queries(queries.queries[:max_queries], kb_info)[:max_queries]
104 | return validated_queries
--------------------------------------------------------------------------------
/dsrag/chat/chat_types.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Union, Literal
2 | from typing_extensions import TypedDict
3 | from pydantic import BaseModel
4 |
5 | class ChatThreadParams(TypedDict):
6 | kb_ids: Optional[list[str]]
7 | model: Optional[str]
8 | temperature: Optional[float]
9 | system_message: Optional[str]
10 | auto_query_model: Optional[str]
11 | auto_query_guidance: Optional[str]
12 | rse_params: Optional[dict]
13 | target_output_length: Optional[str]
14 | max_chat_history_tokens: Optional[int]
15 |
16 | class ChatResponseOutput(TypedDict):
17 | response: str
18 | metadata: dict
19 |
20 | class MetadataFilter(TypedDict):
21 | field: str
22 | operator: Literal['equals', 'not_equals', 'in', 'not_in', 'greater_than', 'less_than', 'greater_than_equals', 'less_than_equals']
23 | value: Union[str, int, float, list[str], list[int], list[float]]
24 |
25 | class ChatResponseInput(BaseModel):
26 | user_input: str
27 | chat_thread_params: Optional[ChatThreadParams] = None
28 | metadata_filter: Optional[MetadataFilter] = None
--------------------------------------------------------------------------------
/dsrag/chat/citations.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel, Field
2 | from typing import Optional, List
3 | import instructor
4 |
5 | class Citation(BaseModel):
6 |     source_index: int = Field(..., description="The index of the source where the information used to generate the response was found.")
7 | page_number: Optional[int] = Field(None, description="The page where the information was found. May only be None if page numbers are not provided.")
8 | cited_text: str = Field(..., description="The exact text containing the information used to generate the response.")
9 |
10 | class ResponseWithCitations(BaseModel):
11 | response: str = Field(..., description="The response to the user's question")
12 | citations: List[Citation] = Field(..., description="The citations used to generate the response")
13 |
14 | # Create a partial version of ResponseWithCitations for streaming
15 | PartialResponseWithCitations = instructor.Partial[ResponseWithCitations]
16 |
17 | def format_page_content(page_number: int, content: str) -> str:
18 | """Format a single page's content with page number annotations"""
19 |     return f"<page_{page_number}>\n{content}\n</page_{page_number}>"
20 |
21 | def get_source_text(kb_id: str, doc_id: str, source_index: int, page_start: Optional[int], page_end: Optional[int], file_system) -> Optional[str]:
22 | """
23 | Get the source text for a given document and page range with page number annotations.
24 | Returns None if no page content is available.
25 | """
26 | if page_start is None or page_end is None:
27 | return None
28 |
29 | page_contents = file_system.load_page_content_range(kb_id, doc_id, page_start, page_end)
30 | if not page_contents:
31 | print(f"No page contents found for doc_id: {doc_id}, page_start: {page_start}, page_end: {page_end}")
32 | return None
33 |
34 |     source_text = f"<source_{source_index}>\n"
35 | for i, content in enumerate(page_contents):
36 | page_number = page_start + i
37 | source_text += format_page_content(page_number, content) + "\n"
38 |     source_text += f"</source_{source_index}>"
39 |
40 | return source_text
41 |
42 | def format_sources_for_context(search_results: list[dict], kb_id: str, file_system) -> tuple[str, list[str]]:
43 | """
44 | Format all search results into a context string for the LLM.
45 | Handles both cases with and without page numbers.
46 | """
47 | context_parts = []
48 | all_doc_ids = []
49 |
50 | for result in search_results:
51 | doc_id = result["doc_id"]
52 | source_index = result["source_index"]
53 | page_start = result.get("segment_page_start")
54 | page_end = result.get("segment_page_end")
55 | all_doc_ids.append(doc_id)
56 |
57 | full_page_source_text = None
58 | if page_start is not None and page_end is not None:
59 | # if we have page numbers, then we assume we have the page content - but this is not always the case
60 | full_page_source_text = get_source_text(kb_id, doc_id, source_index, page_start, page_end, file_system)
61 |
62 | if full_page_source_text:
63 | context_parts.append(full_page_source_text)
64 | else:
65 | result_content = result.get("content", "")
66 |             context_parts.append(f"<source_{source_index}>\n{result_content}\n</source_{source_index}>")
67 |
68 | return "\n\n".join(context_parts), all_doc_ids
69 |
70 | def convert_elements_to_page_content(elements: list[dict], kb_id: str, doc_id: str, file_system) -> None:
71 | """
72 | Convert elements to page content and save it using the page content methods.
73 | This should be called when a document is first added to the knowledge base.
74 | Only processes documents where elements have page numbers.
75 | """
76 | # Check if this document has page numbers
77 | if not elements or "page_number" not in elements[0]:
78 | print(f"No page numbers found for document {doc_id}")
79 | return
80 |
81 | # Group elements by page
82 | pages = {}
83 | for element in elements:
84 | page_num = element["page_number"]
85 | if page_num not in pages:
86 | pages[page_num] = []
87 | pages[page_num].append(element["content"])
88 |
89 | # Save each page's content
90 | for page_num, contents in pages.items():
91 | page_content = "\n".join(contents)
92 | file_system.save_page_content(kb_id, doc_id, page_num, page_content)
--------------------------------------------------------------------------------
/dsrag/chat/model_names.py:
--------------------------------------------------------------------------------
1 | # NOTE: this is deprecated but remains for backwards compatibility.
2 | # The new way to specify model names for Chat is {provider}/{model_name} so we don't have to do this lookup (and maintain an up-to-date mapping here).
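# For example (illustrative): "openai/gpt-4o" or "anthropic/claude-3-5-sonnet-20241022"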
3 |
4 | ANTHROPIC_MODEL_NAMES = [
5 | "claude-3-haiku-20240307",
6 | "claude-3-5-haiku-20241022",
7 | "claude-3-5-haiku-latest",
8 | "claude-3-5-sonnet-20240620",
9 | "claude-3-5-sonnet-20241022",
10 | "claude-3-5-sonnet-latest",
11 | "claude-3-7-sonnet-latest",
12 | "claude-3-7-sonnet-20250219",
13 | ]
14 |
15 | GEMINI_MODEL_NAMES = [
16 | "gemini-1.5-flash-002",
17 | "gemini-1.5-pro-002",
18 | "gemini-2.0-flash",
19 | "gemini-2.0-flash-lite",
20 | "gemini-2.0-flash-thinking-exp",
21 | "gemini-2.0-pro-exp",
22 | "gemini-2.5-pro-exp-03-25",
23 | "gemini-2.5-pro-preview-03-25",
24 | ]
25 |
26 | OPENAI_MODEL_NAMES = [
27 | "gpt-4o-mini",
28 | "gpt-4o-mini-2024-07-18",
29 | "gpt-4o",
30 | "gpt-4o-2024-05-13",
31 | "gpt-4o-2024-08-06",
32 | "gpt-4o-2024-11-20",
33 | "o1",
34 |     "o1-2024-12-17",
35 | "o1-preview",
36 | "o1-preview-2024-09-12",
37 | "o3-mini",
38 | "o3-mini-2025-01-31",
39 | "gpt-4.5-preview",
40 | "gpt-4.5-preview-2025-02-27",
41 | "chatgpt-4o-latest",
42 | ]
--------------------------------------------------------------------------------
/dsrag/create_kb.py:
--------------------------------------------------------------------------------
1 | from dsrag.dsparse.file_parsing.non_vlm_file_parsing import extract_text_from_pdf, extract_text_from_docx
2 | from dsrag.knowledge_base import KnowledgeBase
3 | import os
4 | import time
5 |
6 | def create_kb_from_directory(kb_id: str, directory: str, title: str = None, description: str = "", language: str = 'en'):
7 | """
8 | - kb_id is the name of the knowledge base
9 | - directory is the absolute path to the directory containing the documents
10 | - no support for manually defined chunk headers here, because they would have to be defined for each file in the directory
11 |
12 | Supported file types: .docx, .md, .txt, .pdf
13 | """
14 | if not title:
15 | title = kb_id
16 |
17 | # create a new KB
18 | kb = KnowledgeBase(kb_id, title=title, description=description, language=language, exists_ok=False)
19 |
20 | # add documents
21 | for root, dirs, files in os.walk(directory):
22 | for file_name in files:
23 | if file_name.endswith(('.docx', '.md', '.txt', '.pdf')):
24 | try:
25 | file_path = os.path.join(root, file_name)
26 | clean_file_path = file_path.replace(directory, "")
27 |
28 | if file_name.endswith('.docx'):
29 | text = extract_text_from_docx(file_path)
30 | elif file_name.endswith('.pdf'):
31 | text, _ = extract_text_from_pdf(file_path)
32 | elif file_name.endswith('.md') or file_name.endswith('.txt'):
33 | with open(file_path, 'r') as f:
34 | text = f.read()
35 |
36 | kb.add_document(doc_id=clean_file_path, text=text)
37 | time.sleep(1) # pause for 1 second to avoid hitting API rate limits
38 | except:
39 | print (f"Error reading {file_name}")
40 | continue
41 | else:
42 | print (f"Unsupported file type: {file_name}")
43 | continue
44 |
45 | return kb
46 |
47 | def create_kb_from_file(kb_id: str, file_path: str, title: str = None, description: str = "", language: str = 'en'):
48 | """
49 | - kb_id is the name of the knowledge base
50 | - file_path is the absolute path to the file containing the documents
51 |
52 | Supported file types: .docx, .md, .txt, .pdf
53 | """
54 | if not title:
55 | title = kb_id
56 |
57 | # create a new KB
58 | kb = KnowledgeBase(kb_id, title=title, description=description, language=language, exists_ok=False)
59 |
60 | print (f'Creating KB with id {kb_id}...')
61 |
62 | file_name = os.path.basename(file_path)
63 |
64 | # add document
65 | if file_path.endswith(('.docx', '.md', '.txt', '.pdf')):
66 | # define clean file path as just the file name here since we're not using a directory
67 | clean_file_path = file_name
68 |
69 | if file_path.endswith('.docx'):
70 | text = extract_text_from_docx(file_path)
71 | elif file_name.endswith('.pdf'):
72 | text, _ = extract_text_from_pdf(file_path)
73 | elif file_path.endswith('.md') or file_path.endswith('.txt'):
74 | with open(file_path, 'r') as f:
75 | text = f.read()
76 |
77 | kb.add_document(doc_id=clean_file_path, text=text)
78 | else:
79 | print (f"Unsupported file type: {file_name}")
80 | return
81 |
82 | return kb
--------------------------------------------------------------------------------
/dsrag/custom_term_mapping.py:
--------------------------------------------------------------------------------
1 | import re
2 | from typing import List, Dict
3 | from pydantic import BaseModel, Field
4 | from dsrag.chat.instructor_get_response import get_response
5 |
6 | FIND_TARGET_TERMS_PROMPT = """
7 | Your task is to take a list of target terms and return a list of all instances of those terms in the provided text. The instances you return must be direct matches to the text, but can be fuzzy matches to the terms. In other words, you should allow for some variation in spelling and capitalization relative to the terms. For example, if the synonym is "capital of France", and the text has the phrase "Capitol of Frances", you should return the phrase "Capitol of Frances".
8 |
9 | <target_terms>
10 | {target_terms}
11 | </target_terms>
12 |
13 | <text>
14 | {text}
15 | </text>
16 | """.strip()
17 |
18 | class Terms(BaseModel):
19 | terms: List[str] = Field(description="A list of target terms (or variations thereof) that were found in the text")
20 |
21 | def find_target_terms_batch(text: str, target_terms: List[str]) -> List[str]:
22 | """Find fuzzy matches of target terms in a batch of text using LLM."""
23 | target_terms_str = "\n".join(target_terms)
24 | terms = get_response(
25 | prompt=FIND_TARGET_TERMS_PROMPT.format(target_terms=target_terms_str, text=text),
26 | response_model=Terms,
27 | ).terms
28 | return terms
29 |
30 | def find_all_term_variations(chunks: List[str], target_terms: List[str]) -> List[str]:
31 | """Find all variations of target terms across all chunks, processing in batches."""
32 | all_variations = set()
33 | batch_size = 15 # Process 15 chunks at a time
34 |
35 | for i in range(0, len(chunks), batch_size):
36 | batch_chunks = chunks[i:i + batch_size]
37 | batch_text = "\n\n".join(batch_chunks)
38 | variations = find_target_terms_batch(batch_text, target_terms)
39 | all_variations.update(variations)
40 |
41 | return list(all_variations)
42 |
43 | def annotate_chunk(chunk: str, key: str, terms: List[str]) -> str:
44 | """Annotate a single chunk with terms and their mapping."""
45 | # Process instances in reverse order to avoid messing up string indices
46 | for term in terms:
47 | instances = re.finditer(re.escape(term), chunk)
48 | instances = list(instances)
49 | for instance in reversed(instances):
50 | start = instance.start()
51 | end = instance.end()
52 | chunk = chunk[:end] + f" ({key})" + chunk[end:]
53 | return chunk.strip()
54 |
55 | def annotate_chunks(chunks: List[str], custom_term_mapping: Dict[str, List[str]]) -> List[str]:
56 | """
57 | Annotate chunks with custom term mappings.
58 |
59 | Args:
60 | chunks: list[str] - list of all chunks in the document
61 | custom_term_mapping: dict - a dictionary of custom term mapping for the document
62 | - key: str - the term to map to
63 | - value: list[str] - the list of terms to map to the key
64 |
65 | Returns:
66 | list[str] - annotated chunks
67 | """
68 | # First, find all variations for each key's terms
69 | term_variations = {}
70 | for key, target_terms in custom_term_mapping.items():
71 | variations = find_all_term_variations(chunks, target_terms)
72 | term_variations[key] = list(set(variations + target_terms)) # Include original terms
73 | # remove an exact match of the key from the list of terms
74 | term_variations[key] = [term for term in term_variations[key] if term != key]
75 |
76 | # Then annotate each chunk with all terms
77 | annotated_chunks = []
78 | for chunk in chunks:
79 | annotated_chunk = chunk
80 | for key, terms in term_variations.items():
81 | annotated_chunk = annotate_chunk(annotated_chunk, key, terms)
82 | annotated_chunks.append(annotated_chunk)
83 |
84 | return annotated_chunks
--------------------------------------------------------------------------------
/dsrag/database/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/D-Star-AI/dsRAG/b326c1d0930c9a80ff486ab0df5e771e788e3234/dsrag/database/__init__.py
--------------------------------------------------------------------------------
/dsrag/database/chat_thread/basic_db.py:
--------------------------------------------------------------------------------
1 | import json
2 | from .db import ChatThreadDB
3 | import uuid
4 |
5 | class BasicChatThreadDB(ChatThreadDB):
6 | def __init__(self):
7 | try:
8 | with open("chat_thread_db.json", "r") as f:
9 | self.chat_threads = json.load(f)
10 | except FileNotFoundError:
11 | self.chat_threads = {}
12 |
13 | def create_chat_thread(self, chat_thread_params: dict) -> dict:
14 | """
15 | chat_thread_params: dict with the following keys:
16 | - thread_id: str
17 | - kb_ids: list[str]
18 | - model: str
19 | - title: str
20 | - temperature: float
21 | - system_message: str
22 | - auto_query_guidance: str
23 | - target_output_length: str
24 | - max_chat_history_tokens: int
25 |
26 | Returns
27 | chat_thread: dict with the following keys:
28 | - params: dict
29 | - interactions: list[dict]
30 | """
31 |
32 | chat_thread = {
33 | "params": chat_thread_params,
34 | "interactions": []
35 | }
36 |
37 | self.chat_threads[chat_thread_params["thread_id"]] = chat_thread
38 | self.save()
39 | return chat_thread
40 |
41 | def list_chat_threads(self) -> list[dict]:
42 | return [self.chat_threads[thread_id] for thread_id in self.chat_threads]
43 |
44 | def get_chat_thread(self, thread_id: str) -> dict:
45 | return self.chat_threads.get(thread_id)
46 |
47 | def update_chat_thread(self, thread_id: str, chat_thread_params: dict) -> dict:
48 | self.chat_threads[thread_id]["params"] = chat_thread_params
49 | self.save()
50 | return chat_thread_params
51 |
52 | def delete_chat_thread(self, thread_id: str) -> dict:
53 | chat_thread = self.chat_threads.pop(thread_id, None)
54 | self.save()
55 | return chat_thread
56 |
57 | def add_interaction(self, thread_id: str, interaction: dict) -> dict:
58 | """
59 | interaction should be a dict with the following keys:
60 | - user_input: dict
61 | - content: str
62 | - timestamp: str
63 | - model_response: dict
64 | - content: str
65 | - citations: list[dict]
66 | - doc_id: str
67 | - page_numbers: list[int]
68 | - cited_text: str
69 | - timestamp: str
70 | - status: str (pending, streaming, highlighting_citations, finished, failed)
71 | - relevant_segments: list[dict]
72 | - text: str
73 | - doc_id: str
74 | - kb_id: str
75 | - search_queries: list[dict]
76 | - query: str
77 | - kb_id: str
78 | """
79 | message_id = str(uuid.uuid4())
80 | interaction["message_id"] = message_id
81 |
82 | # Set default status if not provided
83 | if "model_response" in interaction and "status" not in interaction["model_response"]:
84 | interaction["model_response"]["status"] = "pending"
85 |
86 | self.chat_threads[thread_id]["interactions"].append(interaction)
87 | self.save()
88 | return interaction
89 |
90 | def update_interaction(self, thread_id: str, message_id: str, interaction_update: dict) -> dict:
91 | """
92 | Updates an existing interaction in a chat thread.
93 | Only updates the fields provided in interaction_update.
94 | """
95 | # Find the interaction with the matching message_id
96 | for i, interaction in enumerate(self.chat_threads[thread_id]["interactions"]):
97 | if interaction.get("message_id") == message_id:
98 | # Update only the fields provided in interaction_update
99 | if "model_response" in interaction_update:
100 | interaction["model_response"].update(interaction_update["model_response"])
101 |
102 | # Ensure status field exists (for backward compatibility)
103 | if "status" not in interaction["model_response"]:
104 | interaction["model_response"]["status"] = "finished"
105 |
106 | # Save the changes
107 | self.save()
108 | return {"message_id": message_id, "updated": True}
109 |
110 | return {"message_id": message_id, "updated": False, "error": "Interaction not found"}
111 |
112 | def save(self):
113 | """
114 | Save the ChatThreadDB to a JSON file.
115 | """
116 | with open("chat_thread_db.json", "w") as f:
117 | json.dump(self.chat_threads, f)
118 |
119 | def load(self):
120 | with open("chat_thread_db.json", "r") as f:
121 | self.chat_threads = json.load(f)
--------------------------------------------------------------------------------
/dsrag/database/chat_thread/db.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | class ChatThreadDB(ABC):
4 | @abstractmethod
5 | def create_chat_thread(self, chat_thread_params: dict) -> dict:
6 | """
7 | Creates a chat thread in the database.
8 | """
9 | pass
10 |
11 | @abstractmethod
12 | def list_chat_threads(self) -> list[dict]:
13 | """
14 | Lists all chat threads in the database.
15 | """
16 | pass
17 |
18 | @abstractmethod
19 | def get_chat_thread(self, thread_id: str) -> dict:
20 | """
21 | Gets a chat thread by ID.
22 | """
23 | pass
24 |
25 | @abstractmethod
26 | def update_chat_thread(self, thread_id: str, chat_thread_params: dict) -> dict:
27 | """
28 | Updates a chat thread by ID.
29 | """
30 | pass
31 |
32 | @abstractmethod
33 | def delete_chat_thread(self, thread_id: str) -> dict:
34 | """
35 | Deletes a chat thread by ID.
36 | """
37 | pass
38 |
39 | @abstractmethod
40 | def add_interaction(self, thread_id: str, interaction: dict) -> dict:
41 | """
42 | Adds an interaction to a chat thread.
43 | """
44 | pass
45 |
46 | @abstractmethod
47 | def update_interaction(self, thread_id: str, message_id: str, interaction_update: dict) -> dict:
48 | """
49 | Updates an existing interaction in a chat thread.
50 | Only updates the fields provided in interaction_update.
51 | """
52 | pass
--------------------------------------------------------------------------------
/dsrag/database/chunk/__init__.py:
--------------------------------------------------------------------------------
1 | from .db import ChunkDB
2 | from .types import FormattedDocument
3 |
4 | # Always import the basic DB as it has no dependencies
5 | from .basic_db import BasicChunkDB
6 |
7 | # Define what's in __all__ for "from dsrag.database.chunk import *"
8 | __all__ = [
9 | "ChunkDB",
10 | "BasicChunkDB",
11 | "FormattedDocument"
12 | ]
13 |
14 | # Lazy load database modules to avoid importing all dependencies at once
15 | def __getattr__(name):
16 | """Lazily import database implementations only when accessed."""
17 | if name == "SQLiteDB":
18 | from .sqlite_db import SQLiteDB
19 | return SQLiteDB
20 | elif name == "PostgresChunkDB":
21 | from .postgres_db import PostgresChunkDB
22 | return PostgresChunkDB
23 | elif name == "DynamoDB":
24 | from .dynamo_db import DynamoDB
25 | return DynamoDB
26 | else:
27 | raise AttributeError(f"module 'dsrag.database.chunk' has no attribute '{name}'")
28 |
--------------------------------------------------------------------------------
/dsrag/database/chunk/db.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Any, Optional
3 |
4 | from dsrag.database.chunk.types import FormattedDocument
5 |
6 |
7 | class ChunkDB(ABC):
8 | subclasses = {}
9 |
10 | def __init_subclass__(cls, **kwargs):
11 | super().__init_subclass__(**kwargs)
12 | cls.subclasses[cls.__name__] = cls
13 |
14 | def to_dict(self):
15 | return {
16 | "subclass_name": self.__class__.__name__,
17 | }
18 |
19 | @classmethod
20 | def from_dict(cls, config) -> "ChunkDB":
21 | subclass_name = config.pop(
22 | "subclass_name", None
23 | ) # Remove subclass_name from config
24 | subclass = cls.subclasses.get(subclass_name)
25 | if subclass:
26 | return subclass(**config) # Pass the modified config without subclass_name
27 | else:
28 | raise ValueError(f"Unknown subclass: {subclass_name}")
29 |
30 | @abstractmethod
31 | def add_document(self, doc_id: str, chunks: dict[int, dict[str, Any]], supp_id: str = "", metadata: dict = {}) -> None:
32 | """
33 | Store all chunks for a given document.
34 | """
35 | pass
36 |
37 | @abstractmethod
38 | def remove_document(self, doc_id: str) -> None:
39 | """
40 | Remove all chunks and metadata associated with a given document ID.
41 | """
42 | pass
43 |
44 | @abstractmethod
45 | def get_chunk_text(self, doc_id: str, chunk_index: int) -> Optional[str]:
46 | """
47 | Retrieve a specific chunk from a given document ID.
48 | """
49 | pass
50 |
51 | @abstractmethod
52 | def get_is_visual(self, doc_id: str, chunk_index: int) -> Optional[bool]:
53 | """
54 | Retrieve the is_visual flag of a specific chunk from a given document ID.
55 | """
56 | pass
57 |
58 | @abstractmethod
59 | def get_chunk_page_numbers(self, doc_id: str, chunk_index: int) -> Optional[tuple[int, int]]:
60 | """
61 | Retrieve the page numbers of a specific chunk from a given document ID.
62 | """
63 | pass
64 |
65 | @abstractmethod
66 | def get_document(self, doc_id: str) -> Optional[FormattedDocument]:
67 | """
68 | Retrieve all chunks from a given document ID.
69 | """
70 | pass
71 |
72 | @abstractmethod
73 | def get_document_title(self, doc_id: str, chunk_index: int) -> Optional[str]:
74 | """
75 | Retrieve the document title of a specific chunk from a given document ID.
76 | """
77 | pass
78 |
79 | @abstractmethod
80 | def get_document_summary(self, doc_id: str, chunk_index: int) -> Optional[str]:
81 | """
82 | Retrieve the document summary of a specific chunk from a given document ID.
83 | """
84 | pass
85 |
86 | @abstractmethod
87 | def get_section_title(self, doc_id: str, chunk_index: int) -> Optional[str]:
88 | """
89 | Retrieve the section title of a specific chunk from a given document ID.
90 | """
91 | pass
92 |
93 | @abstractmethod
94 | def get_section_summary(self, doc_id: str, chunk_index: int) -> Optional[str]:
95 | """
96 | Retrieve the section summary of a specific chunk from a given document ID.
97 | """
98 | pass
99 |
100 | @abstractmethod
101 | def get_all_doc_ids(self, supp_id: Optional[str] = None) -> list[str]:
102 | """
103 | Retrieve all document IDs.
104 | """
105 | pass
106 |
107 | @abstractmethod
108 | def delete(self) -> None:
109 | """
110 | Delete the chunk database.
111 | """
112 | pass
113 |
--------------------------------------------------------------------------------
/dsrag/database/chunk/types.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 | from typing import Optional
3 | from typing_extensions import TypedDict
4 |
5 | class FormattedDocument(TypedDict):
6 | id: str
7 | title: str
8 | content: Optional[str]
9 | summary: Optional[str]
10 | created_on: Optional[datetime]
11 | supp_id: Optional[str]
12 | metadata: Optional[dict]
13 | chunk_count: int
14 |
--------------------------------------------------------------------------------
/dsrag/database/vector/__init__.py:
--------------------------------------------------------------------------------
1 | from .db import VectorDB
2 | from .types import VectorSearchResult, Vector, ChunkMetadata
3 |
4 | # Always import the basic DB as it has no dependencies
5 | from .basic_db import BasicVectorDB
6 |
7 | # Define what's in __all__ for "from dsrag.database.vector import *"
8 | __all__ = [
9 | "VectorDB",
10 | "BasicVectorDB",
11 | "VectorSearchResult",
12 | "Vector",
13 | "ChunkMetadata"
14 | ]
15 |
16 | # Lazy load database modules to avoid importing all dependencies at once
17 | def __getattr__(name):
18 | """Lazily import database implementations only when accessed."""
19 | if name == "ChromaDB":
20 | from .chroma_db import ChromaDB
21 | return ChromaDB
22 | elif name == "WeaviateVectorDB":
23 | from .weaviate_db import WeaviateVectorDB
24 | return WeaviateVectorDB
25 | elif name == "QdrantVectorDB":
26 | from .qdrant_db import QdrantVectorDB
27 | return QdrantVectorDB
28 | elif name == "MilvusDB":
29 | from .milvus_db import MilvusDB
30 | return MilvusDB
31 | elif name == "PostgresVectorDB":
32 | from .postgres_db import PostgresVectorDB
33 | return PostgresVectorDB
34 | elif name == "PineconeDB":
35 | from .pinecone_db import PineconeDB
36 | return PineconeDB
37 | else:
38 | raise AttributeError(f"module 'dsrag.database.vector' has no attribute '{name}'")
39 |
--------------------------------------------------------------------------------
/dsrag/database/vector/basic_db.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | from dsrag.database.vector.db import VectorDB
3 | from typing import Sequence, Optional
4 | from dsrag.database.vector.types import ChunkMetadata, Vector, VectorSearchResult
5 | from sklearn.metrics.pairwise import cosine_similarity
6 | import os
7 | import numpy as np
8 | from dsrag.utils.imports import faiss
9 |
10 |
11 | class BasicVectorDB(VectorDB):
12 | def __init__(
13 | self, kb_id: str, storage_directory: str = "~/dsRAG", use_faiss: bool = False
14 | ) -> None:
15 | self.kb_id = kb_id
16 |         self.storage_directory = os.path.expanduser(storage_directory)
17 | self.use_faiss = use_faiss
18 | self.vector_storage_path = os.path.join(
19 | self.storage_directory, "vector_storage", f"{kb_id}.pkl"
20 | )
21 | self.load()
22 |
23 | def add_vectors(
24 | self, vectors: Sequence[Vector], metadata: Sequence[ChunkMetadata]
25 | ) -> None:
26 | try:
27 | assert len(vectors) == len(metadata)
28 | except AssertionError:
29 | raise ValueError(
30 | "Error in add_vectors: the number of vectors and metadata items must be the same."
31 | )
32 | self.vectors.extend(vectors)
33 | self.metadata.extend(metadata)
34 | self.save()
35 |
36 | def search(self, query_vector, top_k=10, metadata_filter: Optional[dict] = None) -> list[VectorSearchResult]:
37 | if not self.vectors:
38 | return []
39 |
40 | if self.use_faiss:
41 | try:
42 | return self.search_faiss(query_vector, top_k)
43 | except Exception as e:
44 | print(f"Faiss search failed: {e}. Falling back to numpy search.")
45 | return self._fallback_search(query_vector, top_k)
46 | else:
47 | return self._fallback_search(query_vector, top_k)
48 |
49 | def _fallback_search(self, query_vector, top_k=10) -> list[VectorSearchResult]:
50 | """Fallback search method using numpy when faiss is not available."""
51 | similarities = cosine_similarity([query_vector], self.vectors)[0]
52 | indexed_similarities = sorted(
53 | enumerate(similarities), key=lambda x: x[1], reverse=True
54 | )
55 | results: list[VectorSearchResult] = []
56 | for i, similarity in indexed_similarities[:top_k]:
57 | result = VectorSearchResult(
58 | doc_id=None,
59 | vector=None,
60 | metadata=self.metadata[i],
61 | similarity=similarity,
62 | )
63 | results.append(result)
64 | return results
65 |
66 | def search_faiss(self, query_vector, top_k=10) -> list[VectorSearchResult]:
67 | # Limit top_k to the number of vectors we have - Faiss doesn't automatically handle this
68 | top_k = min(top_k, len(self.vectors))
69 |
70 | # faiss expects 2D arrays of vectors
71 | vectors_array = (
72 | np.array(self.vectors).astype("float32").reshape(len(self.vectors), -1)
73 | )
74 | query_vector_array = np.array(query_vector).astype("float32").reshape(1, -1)
75 |
76 | try:
77 | # Access nested modules step by step
78 | contrib = faiss.contrib
79 | exhaustive_search = contrib.exhaustive_search
80 | knn_func = exhaustive_search.knn
81 |
82 | _, I = knn_func(query_vector_array, vectors_array, top_k)
83 | except AttributeError:
84 | # If the nested module structure doesn't work, try direct import as a fallback
85 | try:
86 | from faiss.contrib.exhaustive_search import knn as direct_knn
87 | _, I = direct_knn(query_vector_array, vectors_array, top_k)
88 | except (ImportError, AttributeError):
89 | # If both approaches fail, raise an error
90 | raise ImportError(
91 | "Could not import faiss.contrib.exhaustive_search.knn. "
92 | "Please ensure faiss is installed correctly."
93 | )
94 |
95 | # I is a list of indices in the corpus_vectors array
96 | results: list[VectorSearchResult] = []
97 | for i in I[0]:
98 | result = VectorSearchResult(
99 | doc_id=None,
100 | vector=None,
101 | metadata=self.metadata[i],
102 | similarity=cosine_similarity([query_vector], [self.vectors[i]])[0][0],
103 | )
104 | results.append(result)
105 | return results
106 |
107 | def remove_document(self, doc_id):
108 | i = 0
109 | while i < len(self.metadata):
110 | if self.metadata[i]["doc_id"] == doc_id:
111 | del self.vectors[i]
112 | del self.metadata[i]
113 | else:
114 | i += 1
115 | self.save()
116 |
117 | def save(self):
118 | os.makedirs(
119 | os.path.dirname(self.vector_storage_path), exist_ok=True
120 | ) # Ensure the directory exists
121 | with open(self.vector_storage_path, "wb") as f:
122 | pickle.dump((self.vectors, self.metadata), f)
123 |
124 | def load(self):
125 | if os.path.exists(self.vector_storage_path):
126 | with open(self.vector_storage_path, "rb") as f:
127 | self.vectors, self.metadata = pickle.load(f)
128 | else:
129 | self.vectors = []
130 | self.metadata = []
131 |
132 | def delete(self):
133 | if os.path.exists(self.vector_storage_path):
134 | os.remove(self.vector_storage_path)
135 |
136 | def to_dict(self):
137 | return {
138 | **super().to_dict(),
139 | "kb_id": self.kb_id,
140 | "storage_directory": self.storage_directory,
141 | "use_faiss": self.use_faiss,
142 | }
143 |
--------------------------------------------------------------------------------
/dsrag/database/vector/chroma_db.py:
--------------------------------------------------------------------------------
1 | from dsrag.database.vector.db import VectorDB
2 | from dsrag.database.vector.types import VectorSearchResult, MetadataFilter
3 | from typing import Optional
4 | import os
5 | import numpy as np
6 | from dsrag.utils.imports import LazyLoader
7 |
8 | # Lazy load chromadb
9 | chromadb = LazyLoader("chromadb")
10 |
11 |
12 | def format_metadata_filter(metadata_filter: MetadataFilter) -> dict:
13 | """
14 | Format the metadata filter to be used in the ChromaDB query method.
15 |
16 | Args:
17 | metadata_filter (dict): The metadata filter.
18 |
19 | Returns:
20 | dict: The formatted metadata filter.
21 | """
22 |
23 | field = metadata_filter["field"]
24 | operator = metadata_filter["operator"]
25 | value = metadata_filter["value"]
26 |
27 | operator_mapping = {
28 | "equals": "$eq",
29 | "not_equals": "$ne",
30 | "in": "$in",
31 | "not_in": "$nin",
32 | "greater_than": "$gt",
33 | "less_than": "$lt",
34 | "greater_than_equals": "$gte",
35 | "less_than_equals": "$lte",
36 | }
37 |
38 | formatted_operator = operator_mapping[operator]
39 | formatted_metadata_filter = {field: {formatted_operator: value}}
40 |
41 | return formatted_metadata_filter
42 |
43 |
44 | class ChromaDB(VectorDB):
45 |
46 | def __init__(self, kb_id: str, storage_directory: str = "~/dsRAG"):
47 | self.kb_id = kb_id
48 | self.storage_directory = os.path.expanduser(storage_directory)
49 | self.vector_storage_path = os.path.join(
50 | self.storage_directory, "vector_storage"
51 | )
52 | self.client = chromadb.PersistentClient(path=self.vector_storage_path)
53 | self.collection = self.client.get_or_create_collection(
54 | kb_id, metadata={"hnsw:space": "cosine"}
55 | )
56 |
57 | def get_num_vectors(self):
58 | return self.collection.count()
59 |
60 | def add_vectors(self, vectors: list, metadata: list):
61 |
62 | # Convert NumPy arrays to lists
63 | vectors_as_lists = [vector.tolist() if isinstance(vector, np.ndarray) else vector for vector in vectors]
64 | try:
65 | assert len(vectors_as_lists) == len(metadata)
66 | except AssertionError:
67 | raise ValueError(
68 | "Error in add_vectors: the number of vectors and metadata items must be the same."
69 | )
70 |
71 | # Create the ids from the metadata, defined as {metadata["doc_id"]}_{metadata["chunk_index"]}
72 | ids = [f"{meta['doc_id']}_{meta['chunk_index']}" for meta in metadata]
73 |
74 | self.collection.add(embeddings=vectors_as_lists, metadatas=metadata, ids=ids)
75 |
76 | def search(self, query_vector, top_k=10, metadata_filter: Optional[MetadataFilter] = None) -> list[VectorSearchResult]:
77 |
78 | num_vectors = self.get_num_vectors()
79 | if num_vectors == 0:
80 | return []
81 | # raise ValueError('No vectors stored in the database.')
82 |
83 | if metadata_filter:
84 | formatted_metadata_filter = format_metadata_filter(metadata_filter)
85 |
86 | if isinstance(query_vector, np.ndarray):
87 | query_vector = query_vector.tolist()
88 | query_results = self.collection.query(
89 | query_embeddings=[query_vector],
90 | n_results=top_k,
91 | include=["distances", "metadatas"],
92 | where=formatted_metadata_filter if metadata_filter else None
93 | )
94 |
95 | metadata = query_results["metadatas"][0]
96 | distances = query_results["distances"][0]
97 |
98 | results: list[VectorSearchResult] = []
99 | for _, (distance, metadata) in enumerate(zip(distances, metadata)):
100 | results.append(
101 | VectorSearchResult(
102 | doc_id=metadata["doc_id"],
103 | vector=None,
104 | metadata=metadata,
105 | similarity=1 - distance,
106 | )
107 | )
108 |
109 | results = sorted(results, key=lambda x: x["similarity"], reverse=True)
110 |
111 | return results
112 |
113 | def remove_document(self, doc_id: str):
114 | data = self.collection.get(where={"doc_id": doc_id})
115 | doc_ids = data["ids"]
116 | self.collection.delete(ids=doc_ids)
117 |
118 | def delete(self):
119 | self.client.delete_collection(name=self.kb_id)
120 |
121 | def to_dict(self):
122 | return {
123 | **super().to_dict(),
124 | "kb_id": self.kb_id,
125 | }
126 |
--------------------------------------------------------------------------------
/dsrag/database/vector/db.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Sequence, Optional
3 | from dsrag.database.vector.types import ChunkMetadata, Vector, VectorSearchResult
4 |
5 |
6 | class VectorDB(ABC):
7 | subclasses = {}
8 |
9 | def __init_subclass__(cls, **kwargs):
10 | super().__init_subclass__(**kwargs)
11 | cls.subclasses[cls.__name__] = cls
12 |
13 | def to_dict(self):
14 | return {
15 | "subclass_name": self.__class__.__name__,
16 | }
17 |
18 | @classmethod
19 | def from_dict(cls, config):
20 | subclass_name = config.pop(
21 | "subclass_name", None
22 | ) # Remove subclass_name from config
23 | subclass = cls.subclasses.get(subclass_name)
24 | if subclass:
25 | return subclass(**config) # Pass the modified config without subclass_name
26 | else:
27 | raise ValueError(f"Unknown subclass: {subclass_name}")
28 |
29 | @abstractmethod
30 | def add_vectors(
31 | self, vectors: Sequence[Vector], metadata: Sequence[ChunkMetadata]
32 | ) -> None:
33 | """
34 | Store a list of vectors with associated metadata.
35 | """
36 | pass
37 |
38 | @abstractmethod
39 | def remove_document(self, doc_id) -> None:
40 | """
41 | Remove all vectors and metadata associated with a given document ID.
42 | """
43 | pass
44 |
45 | @abstractmethod
46 | def search(self, query_vector, top_k: int=10, metadata_filter: Optional[dict] = None) -> list[VectorSearchResult]:
47 | """
48 | Retrieve the top-k closest vectors to a given query vector.
49 | - needs to return results as list of dictionaries in this format:
50 | {
51 | 'metadata':
52 | {
53 | 'doc_id': doc_id,
54 | 'chunk_index': chunk_index,
55 | 'chunk_header': chunk_header,
56 | 'chunk_text': chunk_text
57 | },
58 | 'similarity': similarity,
59 | }
60 | """
61 | pass
62 |
63 | @abstractmethod
64 | def delete(self) -> None:
65 | """
66 | Delete the vector database.
67 | """
68 | pass
69 |
--------------------------------------------------------------------------------
/dsrag/database/vector/milvus_db.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Optional, Sequence
3 |
4 | from dsrag.database.vector import VectorSearchResult
5 | from dsrag.database.vector.db import VectorDB
6 | from dsrag.database.vector.types import MetadataFilter, Vector, ChunkMetadata
7 | from dsrag.utils.imports import LazyLoader
8 |
9 | # Lazy load pymilvus
10 | pymilvus = LazyLoader("pymilvus")
11 |
12 | def _convert_metadata_to_expr(metadata_filter: MetadataFilter) -> str:
13 | """
14 | Convert the metadata filter to the expression used in the Milvus query method.
15 |
16 | Args:
17 | metadata_filter (dict): The metadata filter.
18 |
19 | Returns:
20 | str: The formatted expression.
21 | """
22 |
23 | if not metadata_filter:
24 | return ""
25 | field = metadata_filter["field"]
26 | operator = metadata_filter["operator"]
27 | value = metadata_filter["value"]
28 |
29 | operator_mapping = {
30 | "equals": "==",
31 | "not_equals": "!=",
32 | "in": "in",
33 | "not_in": "not in",
34 | "greater_than": ">",
35 | "less_than": "<",
36 | "greater_than_equals": ">=",
37 | "less_than_equals": "<=",
38 | }
39 |
40 | formatted_operator = operator_mapping[operator]
41 | if isinstance(value, str):
42 | value = f'"{value}"'
43 | formatted_metadata_filter = f"metadata['{field}'] {formatted_operator} {value}"
44 |
45 | return formatted_metadata_filter
46 |
47 |
48 | class MilvusDB(VectorDB):
49 | def __init__(self, kb_id: str, storage_directory: str = '~/dsRAG',
50 | dimension: int = 768):
51 | """
52 | Initialize a MilvusDB instance.
53 |
54 | Args:
55 | kb_id (str): The knowledge base ID.
56 | storage_directory (str, optional): The storage directory. Defaults to '~/dsRAG'.
57 | dimension (int, optional): The dimension of the vectors. Defaults to 768.
58 | """
59 | self.kb_id = kb_id
60 | self.storage_directory = os.path.expanduser(storage_directory)
61 | self.collection_name = kb_id
62 | self.dimension = dimension
63 |
64 | # Initialize Milvus client
65 | self.client = pymilvus.MilvusClient(uri='http://localhost:19530')
66 |
67 | # Create collection if it doesn't exist
68 | try:
69 | if not self.client.has_collection(self.collection_name):
70 | self._create_collection(self.collection_name, dimension)
71 | except Exception as e:
72 | print(f"Error initializing Milvus: {e}")
73 | raise
74 |
75 | def _create_collection(self, collection_name: str, dimension: int = 768):
76 | """
77 | Create a collection in Milvus.
78 |
79 | Args:
80 | collection_name (str): The name of the collection.
81 | dimension (int, optional): The dimension of the vectors. Defaults to 768.
82 | """
83 | schema = {
84 | "id": {"data_type": pymilvus.DataType.VARCHAR, "is_primary": True, "max_length": 100},
85 | "doc_id": {"data_type": pymilvus.DataType.VARCHAR, "max_length": 100},
86 | "chunk_index": {"data_type": pymilvus.DataType.INT64},
87 | "text": {"data_type": pymilvus.DataType.VARCHAR, "max_length": 65535},
88 | "embedding": {"data_type": pymilvus.DataType.FLOAT_VECTOR, "dim": dimension}
89 | }
90 |
91 | self.client.create_collection(
92 | collection_name=collection_name,
93 | schema=schema,
94 | index_params={
95 | "field_name": "embedding",
96 | "index_type": "HNSW",
97 | "metric_type": "COSINE",
98 | "params": {"M": 8, "efConstruction": 64}
99 | }
100 | )
101 |
102 | def add_vectors(self, vectors: Sequence[Vector], metadata: Sequence[ChunkMetadata]):
103 | try:
104 | assert len(vectors) == len(metadata)
105 | except AssertionError:
106 | raise ValueError(
107 | 'Error in add_vectors: the number of vectors and metadata items must be the same.')
108 |
109 | # Create the ids from the metadata, defined as {metadata["doc_id"]}_{metadata["chunk_index"]}
110 | ids = [f"{meta['doc_id']}_{meta['chunk_index']}" for meta in metadata]
111 |
112 | data = []
113 | for i, vector in enumerate(vectors):
114 | data.append({
115 | 'doc_id': ids[i],
116 | 'vector': vector,
117 | 'metadata': metadata[i]
118 | })
119 |
120 | self.client.upsert(
121 | collection_name=self.kb_id,
122 | data=data
123 | )
124 |
125 | def search(self, query_vector, top_k: int=10, metadata_filter: Optional[dict] = None) -> list[VectorSearchResult]:
126 | query_results = self.client.search(
127 | collection_name=self.kb_id,
128 | data=[query_vector],
129 | filter=_convert_metadata_to_expr(metadata_filter),
130 | limit=top_k,
131 | output_fields=["*"]
132 | )[0]
133 | results: list[VectorSearchResult] = []
134 | for res in query_results:
135 | results.append(
136 | VectorSearchResult(
137 | doc_id=res['entity']['doc_id'],
138 | metadata=res['entity']['metadata'],
139 | similarity=res['distance'],
140 | vector=res['entity']['vector'],
141 | )
142 | )
143 |
144 | results = sorted(results, key=lambda x: x['similarity'], reverse=True)
145 |
146 | return results
147 |
148 | def remove_document(self, doc_id):
149 | self.client.delete(
150 | collection_name=self.kb_id,
151 | filter=f'metadata["doc_id"] == "{doc_id}"'
152 | )
153 |
154 | def get_num_vectors(self):
155 | res = self.client.query(
156 | collection_name=self.kb_id,
157 | output_fields=["count(*)"]
158 | )
159 | return res[0]["count(*)"]
160 |
161 | def delete(self):
162 | self.client.drop_collection(collection_name=self.kb_id)
163 |
164 | def to_dict(self):
165 | return {
166 | **super().to_dict(),
167 | 'kb_id': self.kb_id,
168 | }
169 |
--------------------------------------------------------------------------------
/dsrag/database/vector/types.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Sequence, Union
2 | from typing_extensions import TypedDict
3 |
4 |
5 | class ChunkMetadata(TypedDict):
6 | doc_id: str
7 | chunk_text: str
8 | chunk_index: int
9 | chunk_header: str
10 |
11 |
12 | Vector = Union[Sequence[float], Sequence[int]]
13 |
14 |
15 | class VectorSearchResult(TypedDict):
16 | doc_id: Optional[str]
17 | vector: Optional[Vector]
18 | metadata: ChunkMetadata
19 | similarity: float
20 |
21 | class MetadataFilter(TypedDict):
22 | field: str
23 | operator: str # Can be one of the following: 'equals', 'not_equals', 'in', 'not_in', 'greater_than', 'less_than', 'greater_than_equals', 'less_than_equals'
24 | value: Union[str, int, float, list[str], list[int], list[float]]
--------------------------------------------------------------------------------
/dsrag/dsparse/MANIFEST.in:
--------------------------------------------------------------------------------
1 | exclude tests/*
--------------------------------------------------------------------------------
/dsrag/dsparse/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | # Configure the dsparse logger with a NullHandler to prevent "No handler found" warnings
4 | logger = logging.getLogger("dsrag.dsparse")
5 | logger.addHandler(logging.NullHandler())
--------------------------------------------------------------------------------
/dsrag/dsparse/file_parsing/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/D-Star-AI/dsRAG/b326c1d0930c9a80ff486ab0df5e771e788e3234/dsrag/dsparse/file_parsing/__init__.py
--------------------------------------------------------------------------------
/dsrag/dsparse/file_parsing/element_types.py:
--------------------------------------------------------------------------------
1 | from ..models.types import ElementType
2 |
3 |
4 | def get_visual_elements_as_str(elements: list[ElementType]) -> str:
5 | visual_elements = [element["name"] for element in elements if element["is_visual"]]
6 | last_element = visual_elements[-1]
7 | if len(visual_elements) > 1:
8 | visual_elements[-1] = f"and {last_element}"
9 | return ", ".join(visual_elements)
10 |
11 | def get_non_visual_elements_as_str(elements: list[ElementType]) -> str:
12 | non_visual_elements = [element["name"] for element in elements if not element["is_visual"]]
13 | last_element = non_visual_elements[-1]
14 | if len(non_visual_elements) > 1:
15 | non_visual_elements[-1] = f"and {last_element}"
16 | return ", ".join(non_visual_elements)
17 |
18 | def get_num_visual_elements(elements: list[ElementType]) -> int:
19 | return len([element for element in elements if element["is_visual"]])
20 |
21 | def get_num_non_visual_elements(elements: list[ElementType]) -> int:
22 | return len([element for element in elements if not element["is_visual"]])
23 |
24 | def get_element_description_block(elements: list[ElementType]) -> str:
25 | element_blocks = []
26 | for element in elements:
27 | element_block = ELEMENT_PROMPT.format(
28 | element_name=element["name"],
29 | element_instructions=element["instructions"],
30 | )
31 | element_blocks.append(element_block)
32 | return "\n".join(element_blocks)
33 |
34 | ELEMENT_PROMPT = """
35 | - {element_name}
36 | - {element_instructions}
37 | """.strip()
38 |
39 | default_element_types = [
40 | {
41 | "name": "NarrativeText",
42 | "instructions": "This is the main text content of the page, including paragraphs, lists, titles, and any other text content that is not part of one of the other more specialized element types. Not all pages have narrative text, but most do. Be sure to use Markdown formatting for the text content. This includes using tags like # for headers, * for lists, etc. Make sure your header tags are properly nested and that your lists are properly formatted.",
43 | "is_visual": False,
44 | },
45 | {
46 | "name": "Figure",
47 | "instructions": "This covers charts, graphs, diagrams, etc. Associated titles, legends, axis titles, etc. should be considered to be part of the figure.",
48 | "is_visual": True,
49 | },
50 | {
51 | "name": "Image",
52 | "instructions": "This is any visual content on the page that doesn't fit into any of the other categories. This could include photos, illustrations, etc. Any title or captions associated with the image should be considered part of the image.",
53 | "is_visual": True,
54 | },
55 | {
56 | "name": "Table",
57 | "instructions": "This covers any tabular data arrangement on the page, including simple and complex tables. Any titles, captions, or notes associated with the table should be considered part of the table element.",
58 | "is_visual": True,
59 | },
60 | {
61 | "name": "Equation",
62 | "instructions": "This covers mathematical equations, formulas, and expressions that appear on the page. Associated equation numbers or labels should be considered part of the equation.",
63 | "is_visual": True,
64 | },
65 | {
66 | "name": "Header",
67 |         "instructions": "This is the header of the page, which would be located at the very top of the page and may include things like page numbers. The text in a Header element is usually in a very small font size. This is NOT the same thing as a Markdown header or heading. You should never use more than one header element per page. Not all pages have a header. Note that headers are not the same as titles or subtitles within the main text content of the page. Those should be included in NarrativeText elements.",
68 | "is_visual": False,
69 | },
70 | {
71 | "name": "Footnote",
72 | "instructions": "Footnotes should always be included as a separate element from the main text content as they aren't part of the main linear reading flow of the page. Not all pages have footnotes.",
73 | "is_visual": False,
74 | },
75 | {
76 | "name": "Footer",
77 |         "instructions": "This is the footer of the page, which would be located at the very bottom of the page. You should never use more than one footer element per page. Not all pages have a footer, but when they do it is always the very last element on the page.",
78 | "is_visual": False,
79 | }
80 | ]
81 |
82 |
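
A short sketch (not part of the repository) of how the helpers above turn the default element types into prompt fragments:

from dsrag.dsparse.file_parsing.element_types import (
    default_element_types,
    get_element_description_block,
    get_num_visual_elements,
    get_visual_elements_as_str,
)

# With the defaults, the visual elements are Figure, Image, Table, and Equation
print(get_visual_elements_as_str(default_element_types))  # "Figure, Image, Table, and Equation"
print(get_num_visual_elements(default_element_types))     # 4

# One "- name / - instructions" block per element, ready to embed in a VLM prompt
print(get_element_description_block(default_element_types))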
--------------------------------------------------------------------------------
/dsrag/dsparse/file_parsing/non_vlm_file_parsing.py:
--------------------------------------------------------------------------------
1 | import PyPDF2
2 | import docx2txt
3 | from typing import Optional
4 | def extract_text_from_pdf(file_path: str) -> tuple[str, list]:
5 | with open(file_path, 'rb') as file:
6 | # Create a PDF reader object
7 | pdf_reader = PyPDF2.PdfReader(file)
8 |
9 | # Initialize an empty string to store the extracted text
10 | extracted_text = ""
11 |
12 | # Loop through all the pages in the PDF file
13 | pages = []
14 | for page_num in range(len(pdf_reader.pages)):
15 | # Extract the text from the current page
16 | page_text = pdf_reader.pages[page_num].extract_text()
17 | pages.append(page_text)
18 |
19 | # Add the extracted text to the final text
20 | extracted_text += page_text
21 |
22 | return extracted_text, pages
23 |
24 | def extract_text_from_docx(file_path: str) -> str:
25 | return docx2txt.process(file_path)
26 |
27 | def parse_file_no_vlm(file_path: str) -> tuple[str, Optional[list]]:
28 | pdf_pages = None
29 | if file_path.endswith(".pdf"):
30 | text, pdf_pages = extract_text_from_pdf(file_path)
31 | elif file_path.endswith(".docx"):
32 | text = extract_text_from_docx(file_path)
33 | elif file_path.endswith(".txt") or file_path.endswith(".md"):
34 | with open(file_path, "r") as file:
35 | text = file.read()
36 | else:
37 | raise ValueError(
38 | "Unsupported file format. Only .txt, .md, .pdf and .docx files are supported."
39 | )
40 |
41 | return text, pdf_pages
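
A quick usage sketch (not part of the repository; assumes it is run from the repo root so the test PDF path resolves): parse_file_no_vlm returns the full text plus a per-page list for PDFs, or None for formats without page structure.

from dsrag.dsparse.file_parsing.non_vlm_file_parsing import parse_file_no_vlm

# .pdf, .docx, .txt, and .md are supported
text, pdf_pages = parse_file_no_vlm("tests/data/levels_of_agi.pdf")

print(len(text))
if pdf_pages is not None:
    print(f"{len(pdf_pages)} pages extracted")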
--------------------------------------------------------------------------------
/dsrag/dsparse/file_parsing/vlm.py:
--------------------------------------------------------------------------------
1 | import PIL.Image
2 | import os
3 | import io
4 | from ..utils.imports import genai, vertexai
5 |
6 | def make_llm_call_gemini(image_path: str, system_message: str, model: str = "gemini-2.0-flash", response_schema: dict = None, max_tokens: int = 4000, temperature: float = 0.5) -> str:
7 | genai.configure(api_key=os.environ["GEMINI_API_KEY"])
8 | generation_config = {
9 | "temperature": temperature,
10 | "response_mime_type": "application/json",
11 | "max_output_tokens": max_tokens
12 | }
13 | if response_schema is not None:
14 | generation_config["response_schema"] = response_schema
15 |
16 | model = genai.GenerativeModel(model)
17 | try:
18 | image = PIL.Image.open(image_path)
19 | # Compress image if needed
20 | compressed_image = compress_image(image)
21 | response = model.generate_content(
22 | [
23 | compressed_image,
24 | system_message
25 | ],
26 | generation_config=generation_config
27 | )
28 | return response.text
29 | finally:
30 | # Ensure image is closed even if an error occurs
31 | if 'image' in locals():
32 | image.close()
33 |
34 | def make_llm_call_vertex(image_path: str, system_message: str, model: str, project_id: str, location: str, response_schema: dict = None, max_tokens: int = 4000, temperature: float = 0.5) -> str:
35 | """
36 |     This function calls Gemini through the Vertex AI API (as opposed to the standard Gemini API used by make_llm_call_gemini) with an image and a system message, and returns the response text.
37 | """
38 | vertexai.init(project=project_id, location=location)
39 | model = vertexai.generative_models.GenerativeModel(model)
40 |
41 | if response_schema is not None:
42 | generation_config = vertexai.generative_models.GenerationConfig(temperature=temperature, max_output_tokens=max_tokens, response_mime_type="application/json", response_schema=response_schema)
43 | else:
44 | generation_config = vertexai.generative_models.GenerationConfig(temperature=temperature, max_output_tokens=max_tokens)
45 |
46 | response = model.generate_content(
47 | [
48 | vertexai.generative_models.Part.from_image(vertexai.generative_models.Image.load_from_file(image_path)),
49 | system_message,
50 | ],
51 | generation_config=generation_config,
52 | )
53 | return response.text
54 |
55 | def compress_image(image: PIL.Image.Image, max_size_bytes: int = 1097152, quality: int = 85) -> PIL.Image.Image:
56 | """
57 |     Compress an image if it exceeds the maximum file size, reducing JPEG quality first and then dimensions, while maintaining aspect ratio.
58 |
59 | Args:
60 | image: PIL Image object
61 | max_size_bytes: Maximum file size in bytes (default ~1MB)
62 | quality: Initial JPEG quality (0-100)
63 |
64 | Returns:
65 |         The compressed PIL Image object
66 | """
67 | output = io.BytesIO()
68 |
69 | # Initial compression
70 | image.save(output, format='JPEG', quality=quality)
71 |
72 | # Reduce quality if file is too large
73 | while output.tell() > max_size_bytes and quality > 10:
74 | output = io.BytesIO()
75 | quality -= 5
76 | image.save(output, format='JPEG', quality=quality)
77 |
78 | # If reducing quality didn't work, reduce dimensions
79 | if output.tell() > max_size_bytes:
80 | while output.tell() > max_size_bytes:
81 | width, height = image.size
82 | image = image.resize((int(width*0.9), int(height*0.9)), PIL.Image.Resampling.LANCZOS)
83 | output = io.BytesIO()
84 | image.save(output, format='JPEG', quality=quality)
85 |
86 | # Convert back to PIL Image
87 | output.seek(0)
88 | compressed_image = PIL.Image.open(output)
89 | compressed_image.load() # This is important to ensure the BytesIO can be closed
90 | return compressed_image
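
A rough usage sketch (not part of the repository; the page image path is hypothetical): compress_image can be used on its own to keep page images under the ~1MB limit, while make_llm_call_gemini additionally requires GEMINI_API_KEY to be set.

import PIL.Image
from dsrag.dsparse.file_parsing.vlm import compress_image, make_llm_call_gemini

page_image_path = "page_images/page_1.jpg"  # hypothetical page image from VLM file parsing

with PIL.Image.open(page_image_path) as image:
    compressed = compress_image(image, max_size_bytes=1097152, quality=85)
    compressed.save("page_1_compressed.jpg", format="JPEG")

# With GEMINI_API_KEY set, a call against the same image might look like:
# text = make_llm_call_gemini(
#     image_path=page_image_path,
#     system_message="Describe the elements on this page as JSON.",
#     model="gemini-2.0-flash",
# )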
--------------------------------------------------------------------------------
/dsrag/dsparse/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/D-Star-AI/dsRAG/b326c1d0930c9a80ff486ab0df5e771e788e3234/dsrag/dsparse/models/__init__.py
--------------------------------------------------------------------------------
/dsrag/dsparse/models/types.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, TypedDict
2 | from ..file_parsing.file_system import FileSystem
3 |
4 | class ElementType(TypedDict):
5 | name: str
6 | instructions: str
7 | is_visual: bool
8 |
9 | class Element(TypedDict):
10 | type: str
11 | content: str
12 | page_number: Optional[int]
13 |
14 | class Line(TypedDict):
15 | content: str
16 | element_type: str
17 | page_number: Optional[int]
18 | is_visual: Optional[bool]
19 |
20 | class Section(TypedDict):
21 | title: str
22 | start: int
23 | end: int
24 | content: str
25 |
26 | class Chunk(TypedDict):
27 | line_start: int
28 | line_end: int
29 | content: str
30 | page_start: int
31 | page_end: int
32 | section_index: int
33 | is_visual: bool
34 |
35 | class VLMConfig(TypedDict):
36 | provider: Optional[str]
37 | model: Optional[str]
38 | project_id: Optional[str]
39 | location: Optional[str]
40 | exclude_elements: Optional[list[str]]
41 | element_types: Optional[list[ElementType]]
42 |     max_workers: Optional[int]
43 |
44 | class SemanticSectioningConfig(TypedDict):
45 | use_semantic_sectioning: Optional[bool]
46 | llm_provider: Optional[str]
47 | model: Optional[str]
48 | language: Optional[str]
49 |
50 | class ChunkingConfig(TypedDict):
51 | chunk_size: Optional[int]
52 | min_length_for_chunking: Optional[int]
53 |
54 | class FileParsingConfig(TypedDict):
55 | use_vlm: Optional[bool]
56 | vlm_config: Optional[VLMConfig]
57 | always_save_page_images: Optional[bool]
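
As a quick illustration (not part of the repository; values are made up): a Section spans a range of document lines, and a Chunk is a smaller piece that points back to its Section via section_index.

# Illustrative dictionaries matching the Section and Chunk TypedDicts above
section = {
    "title": "Introduction",
    "start": 0,
    "end": 41,
    "content": "Shared vocabularies for AI capabilities...",
}

chunk = {
    "line_start": 0,
    "line_end": 11,
    "content": "Shared vocabularies for AI capabilities...",
    "page_start": 1,
    "page_end": 1,
    "section_index": 0,
    "is_visual": False,
}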
--------------------------------------------------------------------------------
/dsrag/dsparse/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=61.0.0", "wheel"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "dsparse"
7 | version = "0.0.1"
8 | description = "Multi-modal file parsing and chunking"
9 | readme = "README.md"
10 | authors = [{ name = "Zach McCormick", email = "zach@d-star.ai" }, { name = "Nick McCormick", email = "nick@d-star.ai" }]
11 | license = { file = "LICENSE" }
12 | classifiers=[
13 | "Development Status :: 4 - Beta",
14 | "Intended Audience :: Developers",
15 | "Intended Audience :: Science/Research",
16 | "License :: OSI Approved :: MIT License",
17 | "Operating System :: OS Independent",
18 | "Programming Language :: Python :: 3.9",
19 | "Programming Language :: Python :: 3.10",
20 | "Programming Language :: Python :: 3.11",
21 | "Programming Language :: Python :: 3.12",
22 | "Topic :: Database",
23 | "Topic :: Software Development",
24 | "Topic :: Software Development :: Libraries",
25 | "Topic :: Software Development :: Libraries :: Application Frameworks",
26 | "Topic :: Software Development :: Libraries :: Python Modules"
27 | ]
28 | requires-python = ">=3.9"
29 | dynamic = ["dependencies"]
30 |
31 | [tool.setuptools.dynamic]
32 | dependencies = {file = ["requirements.txt"]}
33 |
34 | [project.optional-dependencies]
35 | s3 = ["boto3>=1.34.142", "botocore>=1.34.142"]
36 | vertexai = ["vertexai>=1.70.0"]
37 | instructor = ["instructor>=0.4.0"]
38 | openai = ["openai>=1.0.0"]
39 | anthropic = ["anthropic>=0.5.0"]
40 | google-generativeai = ["google-generativeai>=0.3.0"]
41 |
42 | # Convenience groups
43 | all-models = [
44 | "instructor>=0.4.0",
45 | "openai>=1.0.0",
46 | "anthropic>=0.5.0",
47 | "google-generativeai>=0.3.0",
48 | "vertexai>=1.70.0"
49 | ]
50 |
51 | # Complete installation with all optional dependencies
52 | all = [
53 | "boto3>=1.34.142",
54 | "botocore>=1.34.142",
55 | "instructor>=0.4.0",
56 | "openai>=1.0.0",
57 | "anthropic>=0.5.0",
58 | "google-generativeai>=0.3.0",
59 | "vertexai>=1.70.0"
60 | ]
61 |
62 | [project.urls]
63 | Homepage = "https://github.com/D-Star-AI/dsRAG/blob/main/dsrag/dsparse/README.md"
64 | Documentation = "https://github.com/D-Star-AI/dsRAG/blob/main/dsrag/dsparse/README.md"
65 | Contact = "https://github.com/D-Star-AI/dsRAG/blob/main/dsrag/dsparse/README.md"
66 |
67 | [tool.setuptools.packages.find]
68 | where = ["."]
69 | include = ["dsparse"]
--------------------------------------------------------------------------------
/dsrag/dsparse/requirements.in:
--------------------------------------------------------------------------------
1 | # dsparse specific dependencies
2 | anthropic>=0.37.1
3 | docx2txt>=0.8
4 | google-generativeai>=0.8.3
5 | instructor>=1.7.0 # Required for streaming functionality
6 | langchain-text-splitters>=0.3.0
7 | langchain-core>=0.3.0
8 | openai>=1.52.2
9 | pdf2image>=1.16.0
10 | pillow>=10.0.0
11 | pydantic>=2.8.2
12 | PyPDF2>=3.0.1
13 | # tenacity and jiter are unpinned to allow instructor to select compatible versions
--------------------------------------------------------------------------------
/dsrag/dsparse/requirements.txt:
--------------------------------------------------------------------------------
1 | #
2 | # This file is autogenerated by pip-compile with Python 3.12
3 | # by the following command:
4 | #
5 | # pip-compile requirements.in
6 | #
7 | aiohappyeyeballs==2.4.3
8 | # via aiohttp
9 | aiohttp==3.10.10
10 | # via instructor
11 | aiosignal==1.3.1
12 | # via aiohttp
13 | annotated-types==0.7.0
14 | # via pydantic
15 | anthropic==0.37.1
16 | # via -r requirements.in
17 | anyio==4.6.2.post1
18 | # via
19 | # anthropic
20 | # httpx
21 | # openai
22 | attrs==24.2.0
23 | # via aiohttp
24 | cachetools==5.5.0
25 | # via google-auth
26 | certifi==2024.8.30
27 | # via
28 | # httpcore
29 | # httpx
30 | # requests
31 | charset-normalizer==3.4.0
32 | # via requests
33 | click==8.1.7
34 | # via typer
35 | distro==1.9.0
36 | # via
37 | # anthropic
38 | # openai
39 | docstring-parser==0.16
40 | # via instructor
41 | docx2txt==0.8
42 | # via -r requirements.in
43 | filelock==3.16.1
44 | # via huggingface-hub
45 | frozenlist==1.5.0
46 | # via
47 | # aiohttp
48 | # aiosignal
49 | fsspec==2024.10.0
50 | # via huggingface-hub
51 | google-ai-generativelanguage==0.6.10
52 | # via google-generativeai
53 | google-api-core[grpc]==2.21.0
54 | # via
55 | # google-ai-generativelanguage
56 | # google-api-python-client
57 | # google-generativeai
58 | google-api-python-client==2.149.0
59 | # via google-generativeai
60 | google-auth==2.35.0
61 | # via
62 | # google-ai-generativelanguage
63 | # google-api-core
64 | # google-api-python-client
65 | # google-auth-httplib2
66 | # google-generativeai
67 | google-auth-httplib2==0.2.0
68 | # via google-api-python-client
69 | google-generativeai==0.8.3
70 | # via -r requirements.in
71 | googleapis-common-protos==1.65.0
72 | # via
73 | # google-api-core
74 | # grpcio-status
75 | grpcio==1.67.0
76 | # via
77 | # google-api-core
78 | # grpcio-status
79 | grpcio-status==1.67.0
80 | # via google-api-core
81 | h11==0.14.0
82 | # via httpcore
83 | httpcore==1.0.6
84 | # via httpx
85 | httplib2==0.22.0
86 | # via
87 | # google-api-python-client
88 | # google-auth-httplib2
89 | httpx==0.27.2
90 | # via
91 | # anthropic
92 | # langsmith
93 | # openai
94 | huggingface-hub==0.26.1
95 | # via tokenizers
96 | idna==3.10
97 | # via
98 | # anyio
99 | # httpx
100 | # requests
101 | # yarl
102 | instructor==1.7.3
103 | # via -r requirements.in
104 | jinja2==3.1.6
105 | # via instructor
106 | jiter==0.8.2
107 | # via
108 | # anthropic
109 | # instructor
110 | # openai
111 | jsonpatch==1.33
112 | # via langchain-core
113 | jsonpointer==3.0.0
114 | # via jsonpatch
115 | langchain-core==0.3.42
116 | # via
117 | # -r requirements.in
118 | # langchain-text-splitters
119 | langchain-text-splitters==0.3.6
120 | # via -r requirements.in
121 | langsmith==0.1.137
122 | # via langchain-core
123 | markdown-it-py==3.0.0
124 | # via rich
125 | markupsafe==3.0.2
126 | # via jinja2
127 | mdurl==0.1.2
128 | # via markdown-it-py
129 | multidict==6.1.0
130 | # via
131 | # aiohttp
132 | # yarl
133 | openai==1.52.2
134 | # via
135 | # -r requirements.in
136 | # instructor
137 | orjson==3.10.10
138 | # via langsmith
139 | packaging==24.1
140 | # via
141 | # huggingface-hub
142 | # langchain-core
143 | pdf2image==1.16.0
144 | # via -r requirements.in
145 | pillow==10.0.0
146 | # via
147 | # -r requirements.in
148 | # pdf2image
149 | propcache==0.2.0
150 | # via yarl
151 | proto-plus==1.25.0
152 | # via
153 | # google-ai-generativelanguage
154 | # google-api-core
155 | protobuf==5.28.3
156 | # via
157 | # google-ai-generativelanguage
158 | # google-api-core
159 | # google-generativeai
160 | # googleapis-common-protos
161 | # grpcio-status
162 | # proto-plus
163 | pyasn1==0.6.1
164 | # via
165 | # pyasn1-modules
166 | # rsa
167 | pyasn1-modules==0.4.1
168 | # via google-auth
169 | pydantic==2.8.2
170 | # via
171 | # -r requirements.in
172 | # anthropic
173 | # google-generativeai
174 | # instructor
175 | # langchain-core
176 | # langsmith
177 | # openai
178 | pydantic-core==2.20.1
179 | # via
180 | # instructor
181 | # pydantic
182 | pygments==2.18.0
183 | # via rich
184 | pyparsing==3.2.0
185 | # via httplib2
186 | pypdf2==3.0.1
187 | # via -r requirements.in
188 | pyyaml==6.0.2
189 | # via
190 | # huggingface-hub
191 | # langchain-core
192 | requests==2.32.3
193 | # via
194 | # google-api-core
195 | # huggingface-hub
196 | # instructor
197 | # langsmith
198 | # requests-toolbelt
199 | requests-toolbelt==1.0.0
200 | # via langsmith
201 | rich==13.9.3
202 | # via
203 | # instructor
204 | # typer
205 | rsa==4.9
206 | # via google-auth
207 | shellingham==1.5.4
208 | # via typer
209 | sniffio==1.3.1
210 | # via
211 | # anthropic
212 | # anyio
213 | # httpx
214 | # openai
215 | tenacity==9.0.0
216 | # via
217 | # instructor
218 | # langchain-core
219 | tokenizers==0.20.1
220 | # via anthropic
221 | tqdm==4.66.5
222 | # via
223 | # google-generativeai
224 | # huggingface-hub
225 | # openai
226 | typer==0.12.5
227 | # via instructor
228 | typing-extensions==4.12.2
229 | # via
230 | # anthropic
231 | # google-generativeai
232 | # huggingface-hub
233 | # langchain-core
234 | # openai
235 | # pydantic
236 | # pydantic-core
237 | # typer
238 | uritemplate==4.1.1
239 | # via google-api-python-client
240 | urllib3==2.2.3
241 | # via requests
242 | yarl==1.16.0
243 | # via aiohttp
244 |
--------------------------------------------------------------------------------
/dsrag/dsparse/sectioning_and_chunking/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/D-Star-AI/dsRAG/b326c1d0930c9a80ff486ab0df5e771e788e3234/dsrag/dsparse/sectioning_and_chunking/__init__.py
--------------------------------------------------------------------------------
/dsrag/dsparse/tests/integration/test_vlm_file_parsing.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import unittest
4 | import shutil
5 |
6 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../..')))
7 | sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
8 | from dsparse.main import parse_and_chunk
9 | from dsparse.models.types import Chunk, Section
10 | from dsparse.file_parsing.file_system import LocalFileSystem
11 |
12 |
13 | class TestVLMFileParsing(unittest.TestCase):
14 |
15 | @classmethod
16 | def setUpClass(self):
17 | self.save_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../data/dsparse_output'))
18 | self.file_system = LocalFileSystem(base_path=self.save_path)
19 | self.kb_id = "test_kb"
20 | self.doc_id = "test_doc"
21 | self.test_data_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../../tests/data/mck_energy_first_5_pages.pdf'))
22 | try:
23 | shutil.rmtree(self.save_path)
24 | except:
25 | pass
26 |
27 | def test__parse_and_chunk_vlm(self):
28 |
29 | vlm_config = {
30 | "provider": "gemini",
31 | "model": "gemini-2.0-flash",
32 | }
33 | semantic_sectioning_config = {
34 | "llm_provider": "openai",
35 | "model": "gpt-4o-mini",
36 | "language": "en",
37 | }
38 | file_parsing_config = {
39 | "use_vlm": True,
40 | "vlm_config": vlm_config,
41 | "always_save_page_images": True
42 | }
43 | sections, chunks = parse_and_chunk(
44 | kb_id=self.kb_id,
45 | doc_id=self.doc_id,
46 | file_path=self.test_data_path,
47 | file_parsing_config=file_parsing_config,
48 | semantic_sectioning_config=semantic_sectioning_config,
49 | chunking_config={},
50 | file_system=self.file_system,
51 | )
52 |
53 | # Make sure the sections and chunks were created, and are the correct types
54 | self.assertTrue(len(sections) > 0)
55 | self.assertTrue(len(chunks) > 0)
56 | self.assertEqual(type(sections), list)
57 | self.assertEqual(type(chunks), list)
58 |
59 | for key, expected_type in Section.__annotations__.items():
60 | self.assertIsInstance(sections[0][key], expected_type)
61 |
62 | for key, expected_type in Chunk.__annotations__.items():
63 | self.assertIsInstance(chunks[0][key], expected_type)
64 |
65 | self.assertTrue(len(sections[0]["title"]) > 0)
66 | self.assertTrue(len(sections[0]["content"]) > 0)
67 |
68 | # Make sure the elements.json file was created
69 | self.assertTrue(os.path.exists(os.path.join(self.save_path, self.kb_id, self.doc_id, "elements.json")))
70 |
71 | # Delete the save path
72 | try:
73 | shutil.rmtree(self.save_path)
74 | except:
75 | pass
76 |
77 |
78 | def test_non_vlm_file_parsing(self):
79 |
80 | semantic_sectioning_config = {
81 | "llm_provider": "openai",
82 | "model": "gpt-4o-mini",
83 | "language": "en",
84 | }
85 | file_parsing_config = {
86 | "use_vlm": False,
87 | "always_save_page_images": True,
88 | }
89 | sections, chunks = parse_and_chunk(
90 | kb_id=self.kb_id,
91 | doc_id=self.doc_id,
92 | file_path=self.test_data_path,
93 | file_parsing_config=file_parsing_config,
94 | semantic_sectioning_config=semantic_sectioning_config,
95 | chunking_config={},
96 | file_system=self.file_system,
97 | )
98 |
99 | # Make sure the sections and chunks were created, and are the correct types
100 | self.assertTrue(len(sections) > 0)
101 | self.assertTrue(len(chunks) > 0)
102 | self.assertEqual(type(sections), list)
103 | self.assertEqual(type(chunks), list)
104 |
105 | for key, expected_type in Section.__annotations__.items():
106 | self.assertIsInstance(sections[0][key], expected_type)
107 |
108 | for key, expected_type in Chunk.__annotations__.items():
109 | self.assertIsInstance(chunks[0][key], expected_type)
110 |
111 | self.assertTrue(len(sections[0]["title"]) > 0)
112 | self.assertTrue(len(sections[0]["content"]) > 0)
113 |
114 |
115 | @classmethod
116 | def tearDownClass(self):
117 | # Delete the save path
118 | try:
119 | shutil.rmtree(self.save_path)
120 | except:
121 | pass
122 |
123 | if __name__ == "__main__":
124 | unittest.main()
--------------------------------------------------------------------------------
/dsrag/dsparse/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Utility functions for dsparse.
3 | """
--------------------------------------------------------------------------------
/dsrag/dsparse/utils/imports.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities for lazy imports of optional dependencies.
3 | """
4 | import importlib
5 |
6 | class LazyLoader:
7 | """
8 | Lazily import a module only when its attributes are accessed.
9 |
10 | This allows optional dependencies to be imported only when actually used,
11 | rather than at module import time.
12 |
13 | Usage:
14 | # Instead of: import instructor
15 | instructor = LazyLoader("instructor")
16 |
17 | # Then use instructor as normal - it will only be imported when accessed
18 | # If the module is not installed, a helpful error message is shown
19 | """
20 | def __init__(self, module_name, package_name=None):
21 | """
22 | Initialize a lazy loader for a module.
23 |
24 | Args:
25 | module_name: The name of the module to import
26 | package_name: Optional package name for pip install instructions
27 | (defaults to module_name if not provided)
28 | """
29 | self._module_name = module_name
30 | self._package_name = package_name or module_name
31 | self._module = None
32 |
33 | def __getattr__(self, name):
34 | """Called when an attribute is accessed."""
35 | if self._module is None:
36 | try:
37 | # Use importlib.import_module instead of __import__ for better handling of nested modules
38 | self._module = importlib.import_module(self._module_name)
39 | except ImportError:
40 | raise ImportError(
41 | f"The '{self._module_name}' module is required but not installed. "
42 | f"Please install it with: pip install {self._package_name} "
43 | f"or pip install dsparse[{self._package_name}]"
44 | )
45 |
46 | # Try to get the attribute from the module
47 | try:
48 | return getattr(self._module, name)
49 | except AttributeError:
50 | # If the attribute is not found, it might be a nested module
51 | try:
52 | # Try to import the nested module
53 | nested_module = importlib.import_module(f"{self._module_name}.{name}")
54 | # Cache it on the module for future access
55 | setattr(self._module, name, nested_module)
56 | return nested_module
57 | except ImportError:
58 | # If that fails, re-raise the original AttributeError
59 | raise AttributeError(
60 | f"Module '{self._module_name}' has no attribute '{name}'"
61 | )
62 |
63 | # Create lazy loaders for dependencies used in dsparse
64 | instructor = LazyLoader("instructor")
65 | openai = LazyLoader("openai")
66 | anthropic = LazyLoader("anthropic")
67 | genai = LazyLoader("google.generativeai", "google-generativeai")
68 | vertexai = LazyLoader("vertexai")
69 | boto3 = LazyLoader("boto3")
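
A short sketch (not part of the repository) of how these lazy loaders behave: nothing is imported until an attribute is accessed, and a missing package produces the install hint instead of a bare ImportError.

from dsrag.dsparse.utils.imports import genai

try:
    # First attribute access triggers the real import of google.generativeai
    model = genai.GenerativeModel("gemini-2.0-flash")
except ImportError as e:
    print(e)  # e.g. "... pip install google-generativeai ..."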
--------------------------------------------------------------------------------
/dsrag/metadata.py:
--------------------------------------------------------------------------------
1 | # Metadata storage handling
2 | import os
3 | from decimal import Decimal
4 | import json
5 | from typing import Any
6 | from abc import ABC, abstractmethod
7 | from dsrag.utils.imports import boto3
8 |
9 | class MetadataStorage(ABC):
10 |
11 | def __init__(self) -> None:
12 | pass
13 |
14 |     @abstractmethod
15 |     def load(self, kb_id: str) -> dict:
16 |         pass
17 | 
18 |     @abstractmethod
19 |     def save(self, full_data: dict, kb_id: str) -> None:
20 |         pass
21 | 
22 |     @abstractmethod
23 |     def delete(self, kb_id: str) -> None:
24 |         pass
25 |
26 | class LocalMetadataStorage(MetadataStorage):
27 |
28 | def __init__(self, storage_directory: str) -> None:
29 | super().__init__()
30 | self.storage_directory = storage_directory
31 |
32 | def get_metadata_path(self, kb_id: str) -> str:
33 | return os.path.join(self.storage_directory, "metadata", f"{kb_id}.json")
34 |
35 | def kb_exists(self, kb_id: str) -> bool:
36 | metadata_path = self.get_metadata_path(kb_id)
37 | return os.path.exists(metadata_path)
38 |
39 | def load(self, kb_id: str) -> dict:
40 | metadata_path = self.get_metadata_path(kb_id)
41 | with open(metadata_path, "r") as f:
42 | data = json.load(f)
43 | return data
44 |
45 | def save(self, full_data: dict, kb_id: str):
46 | metadata_path = self.get_metadata_path(kb_id)
47 | metadata_dir = os.path.join(self.storage_directory, "metadata")
48 | if not os.path.exists(metadata_dir):
49 | os.makedirs(metadata_dir)
50 |
51 | with open(metadata_path, "w") as f:
52 | json.dump(full_data, f, indent=4)
53 |
54 | def delete(self, kb_id: str):
55 | metadata_path = self.get_metadata_path(kb_id)
56 | os.remove(metadata_path)
57 |
58 |
59 |
60 | def convert_numbers_to_decimal(obj: Any) -> Any:
61 | """
62 | Recursively traverse the object, converting all integers and floats to Decimals
63 | """
64 | if isinstance(obj, dict):
65 | return {k: convert_numbers_to_decimal(v) for k, v in obj.items()}
66 | elif isinstance(obj, list):
67 | return [convert_numbers_to_decimal(item) for item in obj]
68 | elif isinstance(obj, bool):
69 | return obj # Leave booleans as is
70 | elif isinstance(obj, (int, float)):
71 | return Decimal(str(obj))
72 | else:
73 | return obj
74 |
75 |
76 | def convert_decimal_to_numbers(obj: Any) -> Any:
77 | """
78 | Recursively traverse the object, converting all Decimals to integers or floats
79 | """
80 | if isinstance(obj, dict):
81 | return {k: convert_decimal_to_numbers(v) for k, v in obj.items()}
82 | elif isinstance(obj, list):
83 | return [convert_decimal_to_numbers(item) for item in obj]
84 | elif isinstance(obj, Decimal):
85 | if obj == Decimal('1') or obj == Decimal('0'):
86 | # Convert to bool
87 | return bool(obj)
88 | elif obj % 1 == 0:
89 | # Convert to int
90 | return int(obj)
91 | else:
92 | # Convert to float
93 | return float(obj)
94 | else:
95 | return obj
96 |
97 |
98 | class DynamoDBMetadataStorage(MetadataStorage):
99 |
100 | def __init__(self, table_name: str, region_name: str, access_key: str, secret_key: str) -> None:
101 | self.table_name = table_name
102 | self.region_name = region_name
103 | self.access_key = access_key
104 | self.secret_key = secret_key
105 |
106 | def create_dynamo_client(self):
107 | dynamodb_client = boto3.resource(
108 | 'dynamodb',
109 | region_name=self.region_name,
110 | aws_access_key_id=self.access_key,
111 | aws_secret_access_key=self.secret_key,
112 | )
113 | return dynamodb_client
114 |
115 | def create_table(self):
116 | dynamodb_client = self.create_dynamo_client()
117 | try:
118 | dynamodb_client.create_table(
119 |                 TableName=self.table_name,
120 | KeySchema=[
121 | {
122 | 'AttributeName': 'kb_id',
123 | 'KeyType': 'HASH' # Partition key
124 | }
125 | ],
126 | AttributeDefinitions=[
127 | {
128 | 'AttributeName': 'kb_id',
129 | 'AttributeType': 'S' # 'S' for String
130 | }
131 | ],
132 | BillingMode='PAY_PER_REQUEST'
133 | )
134 | except Exception as e:
135 | # Probably the table already exists
136 | print (e)
137 |
138 | def kb_exists(self, kb_id: str) -> bool:
139 | dynamodb_client = self.create_dynamo_client()
140 | table = dynamodb_client.Table(self.table_name)
141 | response = table.get_item(Key={'kb_id': kb_id})
142 | return 'Item' in response
143 |
144 | def load(self, kb_id: str) -> dict:
145 | # Read the data from the dynamo table
146 | dynamodb_client = self.create_dynamo_client()
147 | table = dynamodb_client.Table(self.table_name)
148 | response = table.get_item(Key={'kb_id': kb_id})
149 | data = response.get('Item', {}).get('metadata', {})
150 | kb_metadata = {
151 | key: value for key, value in data.items() if key != "components"
152 | }
153 | components = data.get("components", {})
154 | full_data = {**kb_metadata, "components": components}
155 | converted_data = convert_decimal_to_numbers(full_data)
156 | return converted_data
157 |
158 | def save(self, full_data: dict, kb_id: str) -> None:
159 |
160 | # Check if any of the items are a float or int, and convert them to Decimal
161 | converted_data = convert_numbers_to_decimal(full_data)
162 |
163 | # Upload this data to the dynamo table, where the kb_id is the primary key
164 | dynamodb_client = self.create_dynamo_client()
165 | table = dynamodb_client.Table(self.table_name)
166 | table.put_item(Item={'kb_id': kb_id, 'metadata': converted_data})
167 |
168 | def delete(self, kb_id: str) -> None:
169 | dynamodb_client = self.create_dynamo_client()
170 | table = dynamodb_client.Table(self.table_name)
171 | table.delete_item(Key={'kb_id': kb_id})
172 |
173 |
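
A minimal sketch (not part of the repository; kb_id and metadata values are made up) of the local backend and the Decimal helpers used by the DynamoDB backend:

import tempfile
from dsrag.metadata import (
    LocalMetadataStorage,
    convert_decimal_to_numbers,
    convert_numbers_to_decimal,
)

storage = LocalMetadataStorage(storage_directory=tempfile.mkdtemp())
kb_metadata = {"title": "Example KB", "chunk_size": 800, "components": {}}

storage.save(kb_metadata, kb_id="example_kb")
assert storage.load(kb_id="example_kb") == kb_metadata
storage.delete(kb_id="example_kb")

# DynamoDB cannot store Python floats directly, so numbers are round-tripped through Decimal
as_decimals = convert_numbers_to_decimal({"temperature": 0.2, "max_tokens": 1000})
restored = convert_decimal_to_numbers(as_decimals)  # {"temperature": 0.2, "max_tokens": 1000}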
--------------------------------------------------------------------------------
/dsrag/reranker.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | import os
3 | from dsrag.utils.imports import cohere, voyageai
4 | from scipy.stats import beta
5 |
6 |
7 | class Reranker(ABC):
8 | subclasses = {}
9 |
10 | def __init_subclass__(cls, **kwargs):
11 | super().__init_subclass__(**kwargs)
12 | cls.subclasses[cls.__name__] = cls
13 |
14 | def to_dict(self):
15 | return {
16 | 'subclass_name': self.__class__.__name__,
17 | }
18 |
19 | @classmethod
20 | def from_dict(cls, config):
21 | subclass_name = config.pop('subclass_name', None) # Remove subclass_name from config
22 | subclass = cls.subclasses.get(subclass_name)
23 | if subclass:
24 | return subclass(**config) # Pass the modified config without subclass_name
25 | else:
26 | raise ValueError(f"Unknown subclass: {subclass_name}")
27 |
28 | @abstractmethod
29 | def rerank_search_results(self, query: str, search_results: list) -> list:
30 | pass
31 |
32 | class CohereReranker(Reranker):
33 | def __init__(self, model: str = "rerank-english-v3.0"):
34 | self.model = model
35 | cohere_api_key = os.environ['CO_API_KEY']
36 | base_url = os.environ.get("DSRAG_COHERE_BASE_URL", None)
37 | if base_url is not None:
38 |             self.client = cohere.Client(api_key=cohere_api_key, base_url=base_url)
39 | else:
40 | self.client = cohere.Client(api_key=cohere_api_key)
41 |
42 | def transform(self, x):
43 | """
44 | transformation function to map the absolute relevance value to a value that is more uniformly distributed between 0 and 1
45 | - this is critical for the new version of RSE to work properly, because it utilizes the absolute relevance values to calculate the similarity scores
46 | """
47 | a, b = 0.4, 0.4 # These can be adjusted to change the distribution shape
48 | return beta.cdf(x, a, b)
49 |
50 | def rerank_search_results(self, query: str, search_results: list) -> list:
51 | """
52 | Use Cohere Rerank API to rerank the search results
53 | """
54 | documents = []
55 | for result in search_results:
56 | documents.append(f"{result['metadata']['chunk_header']}\n\n{result['metadata']['chunk_text']}")
57 |
58 | reranked_results = self.client.rerank(model=self.model, query=query, documents=documents)
59 | results = reranked_results.results
60 | reranked_indices = [result.index for result in results]
61 | reranked_similarity_scores = [result.relevance_score for result in results]
62 | reranked_search_results = [search_results[i] for i in reranked_indices]
63 | for i, result in enumerate(reranked_search_results):
64 | result['similarity'] = self.transform(reranked_similarity_scores[i])
65 | return reranked_search_results
66 |
67 | def to_dict(self):
68 | base_dict = super().to_dict()
69 | base_dict.update({
70 | 'model': self.model
71 | })
72 | return base_dict
73 |
74 | class VoyageReranker(Reranker):
75 | def __init__(self, model: str = "rerank-2"):
76 | self.model = model
77 | voyage_api_key = os.environ['VOYAGE_API_KEY']
78 | self.client = voyageai.Client(api_key=voyage_api_key)
79 |
80 | def transform(self, x):
81 | """
82 | transformation function to map the absolute relevance value to a value that is more uniformly distributed between 0 and 1
83 | - this is critical for the new version of RSE to work properly, because it utilizes the absolute relevance values to calculate the similarity scores
84 | """
85 | a, b = 0.5, 1.8 # These can be adjusted to change the distribution shape
86 | return beta.cdf(x, a, b)
87 |
88 | def rerank_search_results(self, query: str, search_results: list) -> list:
89 | """
90 | Use Voyage Rerank API to rerank the search results
91 | """
92 | documents = []
93 | for result in search_results:
94 | documents.append(f"{result['metadata']['chunk_header']}\n\n{result['metadata']['chunk_text']}")
95 |
96 | reranked_results = self.client.rerank(model=self.model, query=query, documents=documents)
97 | results = reranked_results.results
98 | reranked_indices = [result.index for result in results]
99 | reranked_similarity_scores = [result.relevance_score for result in results]
100 | reranked_search_results = [search_results[i] for i in reranked_indices]
101 | for i, result in enumerate(reranked_search_results):
102 | result['similarity'] = self.transform(reranked_similarity_scores[i])
103 | return reranked_search_results
104 |
105 | def to_dict(self):
106 | base_dict = super().to_dict()
107 | base_dict.update({
108 | 'model': self.model
109 | })
110 | return base_dict
111 |
112 | class NoReranker(Reranker):
113 | def __init__(self, ignore_absolute_relevance: bool = False):
114 | """
115 | - ignore_absolute_relevance: if True, the reranker will override the absolute relevance values and assign a default similarity score to each chunk. This is useful when using an embedding model where the absolute relevance values are not reliable or meaningful.
116 | """
117 | self.ignore_absolute_relevance = ignore_absolute_relevance
118 |
119 | def rerank_search_results(self, query: str, search_results: list) -> list:
120 | if self.ignore_absolute_relevance:
121 | for result in search_results:
122 | result['similarity'] = 0.8 # default similarity score (represents a moderately relevant chunk)
123 | return search_results
124 |
125 | def to_dict(self):
126 | base_dict = super().to_dict()
127 | base_dict.update({
128 | 'ignore_absolute_relevance': self.ignore_absolute_relevance,
129 | })
130 | return base_dict
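
A usage sketch (not part of the repository; the search results are fabricated stand-ins for a vector DB query): rerankers read the chunk_header and chunk_text metadata fields and overwrite each result's similarity score. CohereReranker and VoyageReranker need CO_API_KEY / VOYAGE_API_KEY; NoReranker works offline.

from dsrag.reranker import NoReranker

search_results = [
    {"metadata": {"chunk_header": "Document: Levels of AGI | Section: Methodology",
                  "chunk_text": "We propose a leveled taxonomy of AGI capabilities..."},
     "similarity": 0.42},
    {"metadata": {"chunk_header": "Document: Levels of AGI | Section: References",
                  "chunk_text": "Bibliography entries..."},
     "similarity": 0.31},
]

reranker = NoReranker(ignore_absolute_relevance=True)
reranked = reranker.rerank_search_results("How are levels of AGI determined?", search_results)
print([r["similarity"] for r in reranked])  # every result gets the default score of 0.8

# Swap in CohereReranker() or VoyageReranker() (with the corresponding API key set)
# to get genuine relevance-based reordering.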
--------------------------------------------------------------------------------
/dsrag/utils/imports.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities for lazy imports of optional dependencies.
3 | """
4 | import importlib
5 |
6 | class LazyLoader:
7 | """
8 | Lazily import a module only when its attributes are accessed.
9 |
10 | This allows optional dependencies to be imported only when actually used,
11 | rather than at module import time.
12 |
13 | Usage:
14 | # Instead of: import chromadb
15 | chromadb = LazyLoader("chromadb")
16 |
17 | # Then use chromadb as normal - it will only be imported when accessed
18 | # If the module is not installed, a helpful error message is shown
19 | """
20 | def __init__(self, module_name, package_name=None):
21 | """
22 | Initialize a lazy loader for a module.
23 |
24 | Args:
25 | module_name: The name of the module to import
26 | package_name: Optional package name for pip install instructions
27 | (defaults to module_name if not provided)
28 | """
29 | self._module_name = module_name
30 | self._package_name = package_name or module_name
31 | self._module = None
32 |
33 | def __getattr__(self, name):
34 | """Called when an attribute is accessed."""
35 | if self._module is None:
36 | try:
37 | # Use importlib.import_module instead of __import__ for better handling of nested modules
38 | self._module = importlib.import_module(self._module_name)
39 | except ImportError:
40 | raise ImportError(
41 | f"The '{self._module_name}' module is required but not installed. "
42 | f"Please install it with: pip install {self._package_name} "
43 | f"or pip install dsrag[{self._package_name}]"
44 | )
45 |
46 | # Try to get the attribute from the module
47 | try:
48 | return getattr(self._module, name)
49 | except AttributeError:
50 | # If the attribute is not found, it might be a nested module
51 | try:
52 | # Try to import the nested module
53 | nested_module = importlib.import_module(f"{self._module_name}.{name}")
54 | # Cache it on the module for future access
55 | setattr(self._module, name, nested_module)
56 | return nested_module
57 | except ImportError:
58 | # If that fails, re-raise the original AttributeError
59 | raise AttributeError(
60 | f"Module '{self._module_name}' has no attribute '{name}'"
61 | )
62 |
63 | # Create lazy loaders for commonly used optional dependencies
64 | instructor = LazyLoader("instructor")
65 | openai = LazyLoader("openai")
66 | cohere = LazyLoader("cohere")
67 | voyageai = LazyLoader("voyageai")
68 | ollama = LazyLoader("ollama")
69 | anthropic = LazyLoader("anthropic")
70 | genai = LazyLoader("google.generativeai", "google-generativeai")
71 | genai_new = LazyLoader("google.genai", "google-genai")
72 | boto3 = LazyLoader("boto3")
73 | faiss = LazyLoader("faiss", "faiss-cpu")
74 | psycopg2 = LazyLoader("psycopg2", "psycopg2-binary")
75 | pgvector = LazyLoader("pgvector")
--------------------------------------------------------------------------------
/eval/financebench/dsrag_financebench_create_kb.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | # add ../../ to sys.path
5 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
6 |
7 | # import the necessary modules
8 | from dsrag.knowledge_base import KnowledgeBase
9 | from dsrag.embedding import CohereEmbedding
10 |
11 | document_directory = "" # path to the directory containing the documents; download links for them are listed in the `tests/data/financebench_sample_150.csv` file
12 | kb_id = "finance_bench"
13 |
14 | """
15 | Note: to run this, you'll need to first download all of the documents (~80 of them, mostly 10-Ks and 10-Qs) and convert them to txt files. You'll want to make sure they all have consistent and meaningful file names, such as costco_2023_10k.txt, to match the performance we quote.
16 | """
17 |
18 | # create a new KnowledgeBase object
19 | embedding_model = CohereEmbedding()
20 | kb = KnowledgeBase(kb_id, embedding_model=embedding_model, exists_ok=False)
21 |
22 | for file_name in os.listdir(document_directory):
23 | with open(os.path.join(document_directory, file_name), "r") as f:
24 | text = f.read()
25 | kb.add_document(doc_id=file_name, text=text)
26 | print (f"Added {file_name} to the knowledge base.")
--------------------------------------------------------------------------------
/eval/financebench/dsrag_financebench_eval.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import os
3 | import sys
4 |
5 | # add ../../ to sys.path
6 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
7 |
8 | # import the necessary modules
9 | from dsrag.knowledge_base import KnowledgeBase
10 | from dsrag.auto_query import get_search_queries
11 | from dsrag.llm import OpenAIChatAPI
12 |
13 | AUTO_QUERY_GUIDANCE = """
14 | The knowledge base contains SEC filings for publicly traded companies, like 10-Ks, 10-Qs, and 8-Ks. Keep this in mind when generating search queries. The things you search for should be things that are likely to be found in these documents.
15 |
16 | When deciding what to search for, first consider the pieces of information that will be needed to answer the question. Then, consider what to search for to find those pieces of information. For example, if the question asks what the change in revenue was from 2019 to 2020, you would want to search for the 2019 and 2020 revenue numbers in two separate search queries, since those are the two separate pieces of information needed. You should also think about where you are most likely to find the information you're looking for. If you're looking for assets and liabilities, you may want to search for the balance sheet, for example.
17 | """.strip()
18 |
19 | RESPONSE_SYSTEM_MESSAGE = """
20 | You are a response generation system. Please generate a response to the user input based on the provided context. Your response should be as concise as possible while still fully answering the user's question.
21 |
22 | CONTEXT
23 | {context}
24 | """.strip()
25 |
26 | def get_response(question: str, context: str):
27 | client = OpenAIChatAPI(model="gpt-4-turbo", temperature=0.0)
28 | chat_messages = [{"role": "system", "content": RESPONSE_SYSTEM_MESSAGE.format(context=context)}, {"role": "user", "content": question}]
29 | return client.make_llm_call(chat_messages)
30 |
31 | # Get the absolute path of the script file
32 | script_dir = os.path.dirname(os.path.abspath(__file__))
33 |
34 | # Construct the absolute path to the dataset file
35 | dataset_file_path = os.path.join(script_dir, "../../tests/data/financebench_sample_150.csv")
36 |
37 | # Read in the data
38 | df = pd.read_csv(dataset_file_path)
39 |
40 | # get the questions and answers from the question column - turn into lists
41 | questions = df.question.tolist()
42 | answers = df.answer.tolist()
43 |
44 | # load the knowledge base
45 | kb = KnowledgeBase("finance_bench")
46 |
47 | # set parameters for relevant segment extraction - slightly longer context than default
48 | rse_params = {
49 | 'max_length': 12,
50 | 'overall_max_length': 30,
51 | 'overall_max_length_extension': 6,
52 | }
53 |
54 | # adjust range if you only want to run a subset of the questions (there are 150 total)
55 | for i in range(150):
56 | print (f"Question {i+1}")
57 | question = questions[i]
58 | answer = answers[i]
59 | search_queries = get_search_queries(question, max_queries=6, auto_query_guidance=AUTO_QUERY_GUIDANCE)
60 |     relevant_segments = kb.query(search_queries, rse_params=rse_params)
61 | context = "\n\n".join([segment['text'] for segment in relevant_segments])
62 | response = get_response(question, context)
63 | print (f"\nQuestion: {question}")
64 | print (f"\nSearch queries: {search_queries}")
65 | print (f"\nModel response: {response}")
66 | print (f"\nGround truth answer: {answer}")
67 | print ("\n---\n")
--------------------------------------------------------------------------------
/eval/rse_eval.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 |
4 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
5 |
6 | from dsrag.create_kb import create_kb_from_file
7 | from dsrag.knowledge_base import KnowledgeBase
8 | from dsrag.reranker import NoReranker, CohereReranker, VoyageReranker
9 | from dsrag.embedding import OpenAIEmbedding, CohereEmbedding
10 | from dsrag.dsparse.file_parsing.non_vlm_file_parsing import extract_text_from_pdf
11 |
12 |
13 | def test_create_kb_from_file():
14 | cleanup() # delete the KnowledgeBase object if it exists so we can start fresh
15 |
16 | # Get the absolute path of the script file
17 | script_dir = os.path.dirname(os.path.abspath(__file__))
18 |
19 | # Construct the absolute path to the dataset file
20 | file_path = os.path.join(script_dir, "../tests/data/levels_of_agi.pdf")
21 |
22 | kb_id = "levels_of_agi"
23 | #kb = create_kb_from_file(kb_id, file_path)
24 | kb = KnowledgeBase(kb_id=kb_id, exists_ok=False)
25 |     text, _ = extract_text_from_pdf(file_path)  # extract_text_from_pdf returns (text, pages)
26 |
27 | auto_context_config = {
28 | 'use_generated_title': True,
29 | 'get_document_summary': True,
30 | 'get_section_summaries': True,
31 | }
32 |
33 | kb.add_document(file_path, text, auto_context_config=auto_context_config)
34 |
35 | def cleanup():
36 | kb = KnowledgeBase(kb_id="levels_of_agi", exists_ok=True)
37 | kb.delete()
38 |
39 | """
40 | This script is for doing qualitative evaluation of RSE
41 | """
42 |
43 |
44 | if __name__ == "__main__":
45 | # create the KnowledgeBase
46 | try:
47 | test_create_kb_from_file()
48 | except ValueError as e:
49 | print(e)
50 |
51 | rse_params = {
52 | 'max_length': 20,
53 | 'overall_max_length': 30,
54 | 'minimum_value': 0.5,
55 | 'irrelevant_chunk_penalty': 0.2,
56 | 'overall_max_length_extension': 5,
57 | 'decay_rate': 30,
58 | 'top_k_for_document_selection': 7,
59 | }
60 |
61 | #reranker = NoReranker(ignore_absolute_relevance=True)
62 | reranker = CohereReranker()
63 | #reranker = VoyageReranker()
64 |
65 | # load the KnowledgeBase and query it
66 | kb = KnowledgeBase(kb_id="levels_of_agi", exists_ok=True, reranker=reranker)
67 | #print (kb.chunk_db.get_all_doc_ids())
68 |
69 | #search_queries = ["What are the levels of AGI?"]
70 | #search_queries = ["Who is the president of the United States?"]
71 | #search_queries = ["AI"]
72 | #search_queries = ["What is the difference between AGI and ASI?"]
73 | #search_queries = ["How does autonomy factor into AGI?"]
74 | #search_queries = ["Self-driving cars"]
75 | search_queries = ["Methodology for determining levels of AGI"]
76 | #search_queries = ["Principles for defining levels of AGI"]
77 | #search_queries = ["What is Autonomy Level 3"]
78 | #search_queries = ["Use of existing AI benchmarks like Big-bench and HELM"]
79 | #search_queries = ["Introduction", "Conclusion"]
80 | #search_queries = ["References"]
81 |
82 | relevant_segments = kb.query(search_queries=search_queries, rse_params=rse_params)
83 |
84 | print ()
85 | for segment in relevant_segments:
86 | #print (len(segment["text"]))
87 | #print (segment["score"])
88 | print (segment["text"])
89 | print ("---\n")
--------------------------------------------------------------------------------
/eval/vlm_eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from typing import Dict, List
4 | import PIL.Image
5 | import google.generativeai as genai
6 |
7 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
8 | from dsrag.knowledge_base import KnowledgeBase
9 |
10 |
11 | def create_kb_and_add_document(kb_id: str, file_path: str, document_title: str):
12 | kb = KnowledgeBase(kb_id=kb_id, exists_ok=True)
13 | kb.delete()
14 |
15 | kb = KnowledgeBase(kb_id=kb_id, exists_ok=False)
16 | kb.add_document(
17 | doc_id="test_doc_1",
18 | file_path=file_path,
19 | document_title=document_title,
20 | file_parsing_config={
21 | "use_vlm": True,
22 | "vlm_config": {
23 | "provider": "gemini",
24 | "model": "gemini-1.5-flash-002",
25 | }
26 | }
27 | )
28 |
29 | return kb
30 |
31 | def get_response(user_input: str, search_results: List[Dict], model_name: str = "gemini-1.5-pro-002"):
32 | """
33 | Given a user input and potentially multimodal search results, generate a response.
34 |
35 | We'll use the Gemini API since we already have code for it.
36 | """
37 | genai.configure(api_key=os.environ["GEMINI_API_KEY"])
38 | generation_config = {
39 | "temperature": 0.0,
40 | "max_output_tokens": 2000
41 | }
42 |
43 | model = genai.GenerativeModel(model_name)
44 |
45 | all_image_paths = []
46 | for search_result in search_results:
47 | if type(search_result["content"]) == list:
48 | all_image_paths += search_result["content"]
49 |
50 | images = [PIL.Image.open(image_path) for image_path in all_image_paths]
51 | system_message = "Please cite the page number of the document used to answer the question."
52 | generation_input = images + [user_input, system_message] # user input is the last element
53 | response = model.generate_content(
54 | contents=generation_input,
55 | generation_config=generation_config
56 | )
57 | [image.close() for image in images]
58 | return response.text
59 |
60 |
61 | kb_id = "state_of_ai_flash"
62 |
63 | """
64 | # create or load KB
65 | file_path = "/Users/zach/Code/test_docs/state_of_ai_report_2024.pdf"
66 | document_title = "State of AI Report 2024"
67 | kb = create_kb_and_add_document(kb_id=kb_id, file_path=file_path, document_title=document_title)
68 |
69 | """
70 |
71 | kb = KnowledgeBase(kb_id=kb_id, exists_ok=True)
72 |
73 | #query = "What is the McKinsey Energy Report about?"
74 | #query = "How does the oil and gas industry compare to other industries in terms of its value prop to employees?"
75 |
76 | query = "Who is Nathan Benaich?" # page 2
77 | #query = "Which country had the most AI publications in 2024?" # page 84
78 | #query = "Did frontier labs increase or decrease publications in 2024? By how much?" # page 84
79 | #query = "Who has the most H100s? How many do they have?" # page 92
80 | #query = "How many H100s does Lambda have?" # page 92
81 |
82 | rse_params = {
83 | "minimum_value": 0.5,
84 | "irrelevant_chunk_penalty": 0.2,
85 | }
86 |
87 | search_results = kb.query(search_queries=[query], rse_params=rse_params, return_mode="page_images")
88 | print (search_results)
89 |
90 | response = get_response(user_input=query, search_results=search_results)
91 | print(response)
--------------------------------------------------------------------------------
/examples/example_kb_data/chunk_storage/levels_of_agi.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/D-Star-AI/dsRAG/b326c1d0930c9a80ff486ab0df5e771e788e3234/examples/example_kb_data/chunk_storage/levels_of_agi.pkl
--------------------------------------------------------------------------------
/examples/example_kb_data/chunk_storage/nike_10k.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/D-Star-AI/dsRAG/b326c1d0930c9a80ff486ab0df5e771e788e3234/examples/example_kb_data/chunk_storage/nike_10k.pkl
--------------------------------------------------------------------------------
/examples/example_kb_data/metadata/levels_of_agi.json:
--------------------------------------------------------------------------------
1 | {
2 | "title": "",
3 | "description": "",
4 | "language": "en",
5 | "chunk_size": 800,
6 | "components": {
7 | "embedding_model": {
8 | "subclass_name": "OpenAIEmbedding",
9 | "dimension": 768,
10 | "model": "text-embedding-3-small"
11 | },
12 | "reranker": {
13 | "subclass_name": "CohereReranker",
14 | "model": "rerank-english-v3.0"
15 | },
16 | "auto_context_model": {
17 | "subclass_name": "AnthropicChatAPI",
18 | "model": "claude-3-haiku-20240307",
19 | "temperature": 0.2,
20 | "max_tokens": 1000
21 | },
22 | "vector_db": {
23 | "subclass_name": "BasicVectorDB",
24 | "kb_id": "levels_of_agi",
25 | "storage_directory": "example_kb_data",
26 | "use_faiss": true
27 | },
28 | "chunk_db": {
29 | "subclass_name": "BasicChunkDB",
30 | "kb_id": "levels_of_agi",
31 | "storage_directory": "example_kb_data"
32 | }
33 | }
34 | }
--------------------------------------------------------------------------------
/examples/example_kb_data/metadata/nike_10k.json:
--------------------------------------------------------------------------------
1 | {
2 | "title": "",
3 | "description": "",
4 | "language": "en",
5 | "chunk_size": 800,
6 | "components": {
7 | "embedding_model": {
8 | "subclass_name": "OpenAIEmbedding",
9 | "dimension": 768,
10 | "model": "text-embedding-3-small"
11 | },
12 | "reranker": {
13 | "subclass_name": "CohereReranker",
14 | "model": "rerank-english-v3.0"
15 | },
16 | "auto_context_model": {
17 | "subclass_name": "AnthropicChatAPI",
18 | "model": "claude-3-haiku-20240307",
19 | "temperature": 0.2,
20 | "max_tokens": 1000
21 | },
22 | "vector_db": {
23 | "subclass_name": "BasicVectorDB",
24 | "kb_id": "nike_10k",
25 | "storage_directory": "example_kb_data",
26 | "use_faiss": true
27 | },
28 | "chunk_db": {
29 | "subclass_name": "BasicChunkDB",
30 | "kb_id": "nike_10k",
31 | "storage_directory": "example_kb_data"
32 | },
33 | "file_system": {
34 | "subclass_name": "LocalFileSystem",
35 | "base_path": "example_kb_data"
36 | }
37 | }
38 | }
--------------------------------------------------------------------------------
/examples/example_kb_data/vector_storage/levels_of_agi.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/D-Star-AI/dsRAG/b326c1d0930c9a80ff486ab0df5e771e788e3234/examples/example_kb_data/vector_storage/levels_of_agi.pkl
--------------------------------------------------------------------------------
/examples/example_kb_data/vector_storage/nike_10k.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/D-Star-AI/dsRAG/b326c1d0930c9a80ff486ab0df5e771e788e3234/examples/example_kb_data/vector_storage/nike_10k.pkl
--------------------------------------------------------------------------------
/integrations/langchain_retriever.py:
--------------------------------------------------------------------------------
1 | from typing import List, Dict, Union
2 |
3 | from langchain_core.documents import Document
4 | from langchain_core.retrievers import BaseRetriever
5 | from langchain_core.callbacks import CallbackManagerForRetrieverRun
6 |
7 | from dsrag.knowledge_base import KnowledgeBase
8 |
9 | class DsRAGLangchainRetriever(BaseRetriever):
10 | """
11 | A retriever for the DsRAG project that uses the KnowledgeBase class to retrieve documents.
12 | """
13 |
14 | kb_id: str = ""
15 | rse_params: Union[Dict, str] = "balanced"
16 |
17 |     def __init__(self, kb_id: str, rse_params: Union[Dict, str] = "balanced"):
18 |         """
19 |         Initializes a DsRAGLangchainRetriever instance.
20 | 
21 |         Args:
22 |             kb_id: A KnowledgeBase id to use for retrieving documents.
23 |             rse_params: The RSE parameters to use for querying.
24 |         """
25 |         super().__init__()
26 |         self.kb_id = kb_id
27 |         self.rse_params = rse_params
28 | 
29 | 
30 | 
31 | 
32 | def _get_relevant_documents(self, query: str, *, run_manager: CallbackManagerForRetrieverRun) -> List[Document]:
33 |
34 | kb = KnowledgeBase(kb_id=self.kb_id, exists_ok=True)
35 |
36 | segment_info = kb.query([query], rse_params=self.rse_params)
37 | # Create a Document object for each segment
38 | documents = []
39 | for segment in segment_info:
40 | document = Document(
41 | page_content=segment["text"],
42 | metadata={"doc_id": segment["doc_id"], "chunk_start": segment["chunk_start"], "chunk_end": segment["chunk_end"]},
43 | )
44 | documents.append(document)
45 |
46 | return documents
47 |
--------------------------------------------------------------------------------
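A short usage sketch for the retriever above, mirroring the integration test further down and assuming a knowledge base with kb_id "levels_of_agi" already exists:

from integrations.langchain_retriever import DsRAGLangchainRetriever

# Wrap an existing dsRAG knowledge base as a LangChain retriever and run a query.
retriever = DsRAGLangchainRetriever(kb_id="levels_of_agi", rse_params="balanced")
documents = retriever.invoke("What are the levels of AGI?")
for doc in documents:
    print(doc.metadata["doc_id"], doc.page_content[:100])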
/mkdocs.yml:
--------------------------------------------------------------------------------
1 | site_name: dsRAG
2 | site_description: Documentation for the dsRAG library
3 | site_url: https://d-star-ai.github.io/dsRAG/
4 | theme:
5 | name: material
6 | logo: logo.png
7 | favicon: logo.png
8 | features:
9 | - navigation.tabs
10 | - navigation.sections
11 | - toc.integrate
12 | - search.suggest
13 | - search.highlight
14 | - content.tabs.link
15 |     - content.code.annotate
16 | - content.code.copy
17 | language: en
18 | palette:
19 | - scheme: default
20 | toggle:
21 | icon: material/toggle-switch-off-outline
22 | name: Switch to dark mode
23 | primary: teal
24 | accent: purple
25 | - scheme: slate
26 | toggle:
27 | icon: material/toggle-switch
28 | name: Switch to light mode
29 | primary: teal
30 | accent: lime
31 |
32 | nav:
33 | - Home: index.md
34 | - Getting Started:
35 | - Installation: getting-started/installation.md
36 | - Quick Start: getting-started/quickstart.md
37 | - Concepts:
38 | - Overview: concepts/overview.md
39 | - Knowledge Bases: concepts/knowledge-bases.md
40 | - Components: concepts/components.md
41 | - Configuration: concepts/config.md
42 | - Chat: concepts/chat.md
43 | - Citations: concepts/citations.md
44 | - VLM File Parsing: concepts/vlm.md
45 | - Logging: concepts/logging.md
46 | - API:
47 | - KnowledgeBase: api/knowledge_base.md
48 | - Chat: api/chat.md
49 | - Community:
50 | - Contributing: community/contributing.md
51 | - Support: community/support.md
52 | - Professional Services: community/professional-services.md
53 |
54 | markdown_extensions:
55 | - pymdownx.highlight:
56 | anchor_linenums: true
57 | - pymdownx.inlinehilite
58 | - pymdownx.snippets
59 | - admonition
60 | - pymdownx.arithmatex:
61 | generic: true
62 | - footnotes
63 | - pymdownx.details
64 | - pymdownx.superfences
65 | - pymdownx.mark
66 | - attr_list
67 | - pymdownx.emoji:
68 | emoji_index: !!python/name:material.extensions.emoji.twemoji
69 | emoji_generator: !!python/name:material.extensions.emoji.to_svg
70 |
71 | plugins:
72 | - search
73 | - mkdocstrings:
74 | handlers:
75 | python:
76 | paths: [dsrag]
77 | options:
78 | docstring_style: google
79 | show_source: true
80 | show_root_heading: true
81 | show_root_full_path: true
82 | show_object_full_path: false
83 | show_category_heading: true
84 | show_if_no_docstring: false
85 | show_signature_annotations: true
86 | separate_signature: true
87 | line_length: 80
88 | merge_init_into_class: true
89 | docstring_section_style: spacy
90 |
91 | copyright: |
92 | © 2025 dsRAG Contributors
--------------------------------------------------------------------------------
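Note: the mkdocstrings plugin configured above (python handler, paths: [dsrag], google-style docstrings) generates the API reference pages directly from the docstrings in the dsrag package. To preview the site locally you would typically install mkdocs-material and mkdocstrings[python] and run mkdocs serve from the repository root.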
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=61.0.0", "wheel"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "dsrag"
7 | version = "0.5.3"
8 | description = "State-of-the-art RAG pipeline from D-Star AI"
9 | readme = "README.md"
10 | authors = [{ name = "Zach McCormick", email = "zach@d-star.ai" }, { name = "Nick McCormick", email = "nick@d-star.ai" }]
11 | license = "MIT"
12 | classifiers = [
13 | "Development Status :: 4 - Beta",
14 | "Intended Audience :: Developers",
15 | "Intended Audience :: Science/Research",
16 | "Operating System :: OS Independent",
17 | "Programming Language :: Python :: 3.9",
18 | "Programming Language :: Python :: 3.10",
19 | "Programming Language :: Python :: 3.11",
20 | "Programming Language :: Python :: 3.12",
21 | "Topic :: Database",
22 | "Topic :: Software Development",
23 | "Topic :: Software Development :: Libraries",
24 | "Topic :: Software Development :: Libraries :: Application Frameworks",
25 | "Topic :: Software Development :: Libraries :: Python Modules"
26 | ]
27 | requires-python = ">=3.9"
28 | # Core dependencies are included directly (matching requirements.in)
29 | dependencies = [
30 | "langchain-text-splitters>=0.3.0",
31 | "langchain-core>=0.3.0",
32 | "pydantic>=2.8.2",
33 | "numpy>=1.20.0",
34 | "pandas>=2.0.0",
35 | "tiktoken>=0.5.0",
36 | "tqdm>=4.65.0",
37 | "requests>=2.30.0",
38 | "python-dotenv>=1.0.0",
39 | "typing-extensions>=4.5.0",
40 | "instructor>=1.7.0",
41 | "scipy>=1.10.0",
42 | "scikit-learn>=1.2.0",
43 | "pillow>=10.0.0",
44 | "pdf2image>=1.16.0",
45 | "docx2txt>=0.8",
46 | "PyPDF2>=3.0.1"
47 | ]
48 |
49 | [project.urls]
50 | Homepage = "https://github.com/D-Star-AI/dsRAG"
51 | Documentation = "https://github.com/D-Star-AI/dsRAG"
52 | Contact = "https://github.com/D-Star-AI/dsRAG"
53 |
54 | [project.optional-dependencies]
55 | # Database optional dependencies
56 | faiss = ["faiss-cpu>=1.8.0"]
57 | chroma = ["chromadb>=0.5.5"]
58 | weaviate = ["weaviate-client>=4.6.0"]
59 | qdrant = ["qdrant-client>=1.8.0"]
60 | milvus = ["pymilvus>=2.3.5"]
61 | pinecone = ["pinecone>=3.0.0"]
62 | postgres = ["psycopg2-binary>=2.9.0", "pgvector>=0.2.0"]
63 | boto3 = ["boto3>=1.28.0"]
64 |
65 | # LLM/embedding/reranker optional dependencies
66 | openai = ["openai>=1.52.2"]
67 | cohere = ["cohere>=4.0.0"]
68 | voyageai = ["voyageai>=0.1.0"]
69 | ollama = ["ollama>=0.1.0"]
70 | anthropic = ["anthropic>=0.37.1"]
71 | google-generativeai = ["google-generativeai>=0.8.3"]
72 |
73 | # Convenience groups
74 | all-dbs = [
75 | "dsrag[faiss,chroma,weaviate,qdrant,milvus,pinecone,postgres,boto3]"
76 | ]
77 |
78 | all-models = [
79 | "dsrag[openai,cohere,voyageai,ollama,anthropic,google-generativeai]"
80 | ]
81 |
82 | # Complete installation with all optional dependencies
83 | all = [
84 | "dsrag[all-dbs,all-models]"
85 | ]
86 |
87 | [tool.setuptools.packages.find]
88 | where = ["."]
89 | include = ["dsrag", "dsrag.*"]
90 | exclude = ["dsrag.dsparse.tests", "dsrag.dsparse.tests.*", "dsrag.dsparse.dist", "dsrag.dsparse.dist.*"]
--------------------------------------------------------------------------------
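The optional-dependency groups above map directly onto pip extras: for example, pip install "dsrag[faiss,openai,cohere]" pulls in only the FAISS vector store and the OpenAI/Cohere clients, while pip install "dsrag[all]" installs every database and model integration listed.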
/requirements.in:
--------------------------------------------------------------------------------
1 | # Core dependencies for dsRAG
2 | # Generate requirements.txt with: pip-compile requirements.in
3 |
4 | # Core functionality - only include direct dependencies here, not transitive ones
5 | langchain-text-splitters>=0.3.0
6 | langchain-core>=0.3.0 # Updated to newer version for tenacity compatibility
7 | pydantic>=2.8.2 # Version compatible with all dependencies
8 | numpy>=1.20.0
9 | pandas>=2.0.0
10 | tiktoken>=0.5.0
11 | tqdm>=4.65.0
12 | requests>=2.30.0
13 | python-dotenv>=1.0.0
14 | typing-extensions>=4.5.0
15 | instructor>=1.7.0 # Required for streaming functionality
16 | # tenacity and jiter are unpinned to allow instructor to select compatible versions
17 | scipy>=1.10.0
18 | scikit-learn>=1.2.0
19 | pillow>=10.0.0
20 | pdf2image>=1.16.0
21 | docx2txt>=0.8
22 | PyPDF2>=3.0.1
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | #
2 | # This file is autogenerated by pip-compile with Python 3.12
3 | # by the following command:
4 | #
5 | # pip-compile requirements.in
6 | #
7 | aiohappyeyeballs==2.4.6
8 | # via aiohttp
9 | aiohttp==3.11.13
10 | # via instructor
11 | aiosignal==1.3.1
12 | # via aiohttp
13 | annotated-types==0.7.0
14 | # via pydantic
15 | anyio==4.8.0
16 | # via
17 | # httpx
18 | # openai
19 | attrs==23.2.0
20 | # via aiohttp
21 | certifi==2025.1.31
22 | # via
23 | # httpcore
24 | # httpx
25 | # requests
26 | charset-normalizer==3.4.1
27 | # via requests
28 | click==8.1.7
29 | # via typer
30 | distro==1.9.0
31 | # via openai
32 | docstring-parser==0.16
33 | # via instructor
34 | docx2txt==0.8
35 | # via -r requirements.in
36 | frozenlist==1.5.0
37 | # via
38 | # aiohttp
39 | # aiosignal
40 | h11==0.14.0
41 | # via httpcore
42 | httpcore==1.0.7
43 | # via httpx
44 | httpx==0.27.0
45 | # via
46 | # langsmith
47 | # openai
48 | idna==3.10
49 | # via
50 | # anyio
51 | # httpx
52 | # requests
53 | # yarl
54 | instructor==1.7.3
55 | # via -r requirements.in
56 | jinja2==3.1.6
57 | # via instructor
58 | jiter==0.8.2
59 | # via
60 | # instructor
61 | # openai
62 | joblib==1.4.2
63 | # via scikit-learn
64 | jsonpatch==1.33
65 | # via langchain-core
66 | jsonpointer==3.0.0
67 | # via jsonpatch
68 | langchain-core==0.3.42
69 | # via
70 | # -r requirements.in
71 | # langchain-text-splitters
72 | langchain-text-splitters==0.3.6
73 | # via -r requirements.in
74 | langsmith==0.1.147
75 | # via langchain-core
76 | markdown-it-py==3.0.0
77 | # via rich
78 | markupsafe==3.0.2
79 | # via jinja2
80 | mdurl==0.1.2
81 | # via markdown-it-py
82 | multidict==6.1.0
83 | # via
84 | # aiohttp
85 | # yarl
86 | numpy==1.26.4
87 | # via
88 | # -r requirements.in
89 | # pandas
90 | # scikit-learn
91 | # scipy
92 | openai==1.64.0
93 | # via instructor
94 | orjson==3.10.15
95 | # via langsmith
96 | packaging==24.1
97 | # via langchain-core
98 | pandas==2.2.3
99 | # via -r requirements.in
100 | pdf2image==1.16.0
101 | # via -r requirements.in
102 | pillow==10.0.0
103 | # via
104 | # -r requirements.in
105 | # pdf2image
106 | propcache==0.3.0
107 | # via
108 | # aiohttp
109 | # yarl
110 | pydantic==2.8.2
111 | # via
112 | # -r requirements.in
113 | # instructor
114 | # langchain-core
115 | # langsmith
116 | # openai
117 | pydantic-core==2.20.1
118 | # via
119 | # instructor
120 | # pydantic
121 | pygments==2.18.0
122 | # via rich
123 | pypdf2==3.0.1
124 | # via -r requirements.in
125 | python-dateutil==2.9.0.post0
126 | # via pandas
127 | python-dotenv==1.0.1
128 | # via -r requirements.in
129 | pytz==2025.1
130 | # via pandas
131 | pyyaml==6.0.1
132 | # via langchain-core
133 | regex==2024.5.15
134 | # via tiktoken
135 | requests==2.32.3
136 | # via
137 | # -r requirements.in
138 | # instructor
139 | # langsmith
140 | # requests-toolbelt
141 | # tiktoken
142 | requests-toolbelt==1.0.0
143 | # via langsmith
144 | rich==13.7.1
145 | # via
146 | # instructor
147 | # typer
148 | scikit-learn==1.6.1
149 | # via -r requirements.in
150 | scipy==1.15.2
151 | # via
152 | # -r requirements.in
153 | # scikit-learn
154 | shellingham==1.5.4
155 | # via typer
156 | six==1.16.0
157 | # via python-dateutil
158 | sniffio==1.3.1
159 | # via
160 | # anyio
161 | # httpx
162 | # openai
163 | tenacity==9.0.0
164 | # via
165 | # instructor
166 | # langchain-core
167 | threadpoolctl==3.5.0
168 | # via scikit-learn
169 | tiktoken==0.9.0
170 | # via -r requirements.in
171 | tqdm==4.67.1
172 | # via
173 | # -r requirements.in
174 | # openai
175 | typer==0.12.3
176 | # via instructor
177 | typing-extensions==4.12.2
178 | # via
179 | # -r requirements.in
180 | # anyio
181 | # langchain-core
182 | # openai
183 | # pydantic
184 | # pydantic-core
185 | # typer
186 | tzdata==2024.1
187 | # via pandas
188 | urllib3==2.3.0
189 | # via requests
190 | yarl==1.18.3
191 | # via aiohttp
192 |
--------------------------------------------------------------------------------
/tests/data/levels_of_agi.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/D-Star-AI/dsRAG/b326c1d0930c9a80ff486ab0df5e771e788e3234/tests/data/levels_of_agi.pdf
--------------------------------------------------------------------------------
/tests/data/mck_energy_first_5_pages.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/D-Star-AI/dsRAG/b326c1d0930c9a80ff486ab0df5e771e788e3234/tests/data/mck_energy_first_5_pages.pdf
--------------------------------------------------------------------------------
/tests/data/page_7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/D-Star-AI/dsRAG/b326c1d0930c9a80ff486ab0df5e771e788e3234/tests/data/page_7.png
--------------------------------------------------------------------------------
/tests/integration/test_chunk_and_segment_headers.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import unittest
4 |
5 | # add ../../ to sys.path
6 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
7 |
8 | from dsrag.dsparse.file_parsing.non_vlm_file_parsing import extract_text_from_pdf
9 | from dsrag.knowledge_base import KnowledgeBase
10 |
11 |
12 | class TestChunkAndSegmentHeaders(unittest.TestCase):
13 | def test_all_context_included(self):
14 | self.cleanup() # delete the KnowledgeBase object if it exists so we can start fresh
15 |
16 | # Get the absolute path of the script file
17 | script_dir = os.path.dirname(os.path.abspath(__file__))
18 |
19 | # Construct the absolute path to the dataset file
20 | file_path = os.path.join(script_dir, "../data/levels_of_agi.pdf")
21 |
22 | # extract text from the pdf
23 | text, _ = extract_text_from_pdf(file_path)
24 |
25 | kb_id = "levels_of_agi"
26 | kb = KnowledgeBase(kb_id=kb_id, exists_ok=False)
27 |
28 | # set configuration for auto context and semantic sectioning and add the document to the KnowledgeBase
29 | auto_context_config = {
30 | 'use_generated_title': True,
31 | 'get_document_summary': True,
32 | 'get_section_summaries': True,
33 | }
34 | semantic_sectioning_config = {
35 | 'use_semantic_sectioning': True,
36 | }
37 | kb.add_document(doc_id="doc_1", text=text, auto_context_config=auto_context_config, semantic_sectioning_config=semantic_sectioning_config)
38 |
39 | # verify that the chunk headers are as expected by running searches and looking at the chunk_header in the results
40 | search_query = "What are the Levels of AGI and why were they defined this way?"
41 | search_results = kb._search(search_query, top_k=1)
42 | chunk_header = search_results[0]['metadata']['chunk_header']
43 | print (f"\nCHUNK HEADER\n{chunk_header}")
44 |
45 | # verify that the segment headers are as expected
46 | segment_header = kb._get_segment_header(doc_id="doc_1", chunk_index=0)
47 |
48 | assert "Section context" in chunk_header
49 | assert "AGI" in segment_header
50 |
51 | # delete the KnowledgeBase object
52 | kb.delete()
53 |
54 | def test_no_context_included(self):
55 | self.cleanup() # delete the KnowledgeBase object if it exists so we can start fresh
56 |
57 | # Get the absolute path of the script file
58 | script_dir = os.path.dirname(os.path.abspath(__file__))
59 |
60 | # Construct the absolute path to the dataset file
61 | file_path = os.path.join(script_dir, "../data/levels_of_agi.pdf")
62 |
63 | # extract text from the pdf
64 | text, _ = extract_text_from_pdf(file_path)
65 |
66 | kb_id = "levels_of_agi"
67 | kb = KnowledgeBase(kb_id=kb_id, exists_ok=False)
68 |
69 | # set configuration for auto context and semantic sectioning and add the document to the KnowledgeBase
70 | auto_context_config = {
71 | 'use_generated_title': False,
72 | 'get_document_summary': False,
73 | 'get_section_summaries': False,
74 | }
75 | semantic_sectioning_config = {
76 | 'use_semantic_sectioning': False,
77 | }
78 | kb.add_document(doc_id="doc_1", text=text, auto_context_config=auto_context_config, semantic_sectioning_config=semantic_sectioning_config)
79 |
80 | # verify that the chunk headers are as expected by running searches and looking at the chunk_header in the results
81 | search_query = "What are the Levels of AGI and why were they defined this way?"
82 | search_results = kb._search(search_query, top_k=1)
83 | chunk_header = search_results[0]['metadata']['chunk_header']
84 | print (f"\nCHUNK HEADER\n{chunk_header}")
85 |
86 | # verify that the segment headers are as expected
87 | segment_header = kb._get_segment_header(doc_id="doc_1", chunk_index=0)
88 |
89 | assert "Section context" not in chunk_header
90 | assert "AGI" not in segment_header
91 |
92 | # delete the KnowledgeBase object
93 | kb.delete()
94 |
95 |
96 | def cleanup(self):
97 | kb = KnowledgeBase(kb_id="levels_of_agi", exists_ok=True)
98 | kb.delete()
99 |
100 | if __name__ == "__main__":
101 | unittest.main()
--------------------------------------------------------------------------------
/tests/integration/test_create_kb_and_query.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import unittest
4 |
5 | # add ../../ to sys.path
6 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
7 |
8 | from dsrag.create_kb import create_kb_from_file
9 | from dsrag.knowledge_base import KnowledgeBase
10 |
11 |
12 | class TestCreateKB(unittest.TestCase):
13 | def test_create_kb_from_file(self):
14 | self.cleanup() # delete the KnowledgeBase object if it exists so we can start fresh
15 |
16 | # Get the absolute path of the script file
17 | script_dir = os.path.dirname(os.path.abspath(__file__))
18 |
19 | # Construct the absolute path to the dataset file
20 | file_path = os.path.join(script_dir, "../data/levels_of_agi.pdf")
21 |
22 | kb_id = "levels_of_agi"
23 | kb = create_kb_from_file(kb_id, file_path)
24 |
25 | # verify that the document is in the chunk db
26 | doc_ids = kb.chunk_db.get_all_doc_ids()
27 | self.assertEqual(len(doc_ids), 1)
28 |
29 | # assert that the document title and summary are present and reasonable
30 | document_title = kb.chunk_db.get_document_title(doc_id="levels_of_agi.pdf", chunk_index=0)
31 | self.assertIn("agi", document_title.lower())
32 | document_summary = kb.chunk_db.get_document_summary(doc_id="levels_of_agi.pdf", chunk_index=0)
33 | self.assertIn("agi", document_summary.lower())
34 |
35 | # run a query and verify results are returned
36 | search_queries = ["What are the levels of AGI?", "What is the highest level of AGI?"]
37 | segment_info = kb.query(search_queries)
38 | self.assertGreater(len(segment_info[0]), 0)
39 |
40 | # delete the KnowledgeBase object
41 | kb.delete()
42 |
43 | def cleanup(self):
44 | kb = KnowledgeBase(kb_id="levels_of_agi", exists_ok=True)
45 | kb.delete()
46 |
47 | if __name__ == "__main__":
48 | unittest.main()
--------------------------------------------------------------------------------
/tests/integration/test_create_kb_and_query_chromadb.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import unittest
4 |
5 | # add ../../ to sys.path
6 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
7 |
8 | from dsrag.knowledge_base import KnowledgeBase
9 | from dsrag.database.vector.chroma_db import ChromaDB
10 |
11 | @unittest.skipIf(os.environ.get('GITHUB_ACTIONS') == 'true', "ChromaDB is not available on GitHub Actions")
12 | class TestCreateKB(unittest.TestCase):
13 | def test__001_create_kb_and_query(self):
14 | self.cleanup() # delete the KnowledgeBase object if it exists so we can start fresh
15 |
16 | # Get the absolute path of the script file
17 | script_dir = os.path.dirname(os.path.abspath(__file__))
18 |
19 | # Construct the absolute path to the dataset file
20 | file_path = os.path.join(script_dir, "../data/levels_of_agi.pdf")
21 |
22 | kb_id = "levels_of_agi"
23 |
24 | vector_db = ChromaDB(kb_id=kb_id)
25 | kb = KnowledgeBase(kb_id=kb_id, vector_db=vector_db, exists_ok=False)
26 | kb.reranker.model = "rerank-english-v3.0"
27 | kb.add_document(
28 | doc_id="levels_of_agi.pdf",
29 | document_title="Levels of AGI",
30 | file_path=file_path,
31 | semantic_sectioning_config={"use_semantic_sectioning": False},
32 | auto_context_config={
33 | "use_generated_title": False,
34 | "get_document_summary": False
35 | }
36 | )
37 |
38 | # verify that the document is in the chunk db
39 | self.assertEqual(len(kb.chunk_db.get_all_doc_ids()), 1)
40 |
41 | # run a query and verify results are returned
42 | search_queries = [
43 | "What are the levels of AGI?",
44 | "What is the highest level of AGI?",
45 | ]
46 | segment_info = kb.query(search_queries)
47 | self.assertGreater(len(segment_info[0]), 0)
48 |
49 | # Assert that the chunk page start and end are correct
50 | self.assertEqual(segment_info[0]["chunk_page_start"], 5)
51 | self.assertEqual(segment_info[0]["chunk_page_end"], 8)
52 |
53 | def test__002_test_query_with_metadata_filter(self):
54 |
55 | kb = KnowledgeBase(kb_id="levels_of_agi", exists_ok=True)
56 | document_text = "AGI has many levels. The highest level of AGI is level 5."
57 | kb.add_document(
58 | doc_id="test_doc",
59 | document_title="AGI Test Document",
60 | text=document_text,
61 | semantic_sectioning_config={"use_semantic_sectioning": False},
62 | auto_context_config={
63 | "use_generated_title": False, # No need to generate a title for this document for testing purposes
64 | "get_document_summary": False
65 | }
66 | )
67 |
68 | search_queries = [
69 | "What is the highest level of AGI?",
70 | ]
71 | metadata_filter = {"field": "doc_id", "operator": "equals", "value": "test_doc"}
72 | segment_info = kb.query(search_queries, metadata_filter=metadata_filter, rse_params={"minimum_value": 0})
73 | self.assertEqual(len(segment_info), 1)
74 | self.assertEqual(segment_info[0]["doc_id"], "test_doc")
75 | self.assertEqual(segment_info[0]["chunk_page_start"], None)
76 | self.assertEqual(segment_info[0]["chunk_page_end"], None)
77 |
78 |
79 | def test__003_remove_document(self):
80 | kb = KnowledgeBase(kb_id="levels_of_agi", exists_ok=True)
81 | kb.delete_document(doc_id="levels_of_agi.pdf")
82 | kb.delete_document(doc_id="test_doc")
83 | self.assertEqual(len(kb.chunk_db.get_all_doc_ids()), 0)
84 |
85 | # delete the KnowledgeBase object
86 | kb.delete()
87 |
88 | def test__004_query_empty_kb(self):
89 | kb = KnowledgeBase(kb_id="levels_of_agi", exists_ok=True)
90 | results = kb.query(["What are the levels of AGI?"])
91 | self.assertEqual(len(results), 0)
92 | kb.delete()
93 |
94 | def cleanup(self):
95 | kb = KnowledgeBase(kb_id="levels_of_agi", exists_ok=True)
96 | kb.delete()
97 |
98 |
99 | if __name__ == "__main__":
100 | unittest.main()
101 |
--------------------------------------------------------------------------------
/tests/integration/test_create_kb_and_query_weaviate.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import unittest
4 |
5 | # add ../../ to sys.path
6 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
7 |
8 | from dsrag.knowledge_base import KnowledgeBase
9 | from dsrag.database.vector import WeaviateVectorDB
10 |
11 |
12 | @unittest.skipIf(os.environ.get('GITHUB_ACTIONS') == 'true', "Weaviate is not available on GitHub Actions")
13 | class TestCreateKB(unittest.TestCase):
14 | def test_create_kb_and_query(self):
15 | self.cleanup() # delete the KnowledgeBase object if it exists so we can start fresh
16 |
17 | # Get the absolute path of the script file
18 | script_dir = os.path.dirname(os.path.abspath(__file__))
19 |
20 | # Construct the absolute path to the dataset file
21 | file_path = os.path.join(script_dir, "../data/levels_of_agi.pdf")
22 |
23 | kb_id = "levels_of_agi"
24 |
25 | vector_db = WeaviateVectorDB(kb_id=kb_id, use_embedded_weaviate=True)
26 | kb = KnowledgeBase(kb_id=kb_id, vector_db=vector_db, exists_ok=False)
27 | kb.reranker.model = "rerank-english-v3.0"
28 | kb.add_document(
29 | doc_id="levels_of_agi.pdf",
30 | document_title="Levels of AGI",
31 | file_path=file_path,
32 | semantic_sectioning_config={"use_semantic_sectioning": False},
33 | auto_context_config={
34 | "use_generated_title": False,
35 | "get_document_summary": False
36 | }
37 | )
38 |
39 | # verify that the document is in the chunk db
40 | self.assertEqual(len(kb.chunk_db.get_all_doc_ids()), 1)
41 |
42 | # run a query and verify results are returned
43 | search_queries = [
44 | "What are the levels of AGI?",
45 | "What is the highest level of AGI?",
46 | ]
47 | segment_info = kb.query(search_queries)
48 | self.assertGreater(len(segment_info[0]), 0)
49 | # Assert that the chunk page start and end are correct
50 | self.assertEqual(segment_info[0]["chunk_page_start"], 5)
51 | self.assertEqual(segment_info[0]["chunk_page_end"], 8)
52 |
53 | # delete the KnowledgeBase object
54 | kb.delete()
55 |
56 | def cleanup(self):
57 | kb = KnowledgeBase(kb_id="levels_of_agi", exists_ok=True)
58 | kb.delete()
59 |
60 |
61 | if __name__ == "__main__":
62 | unittest.main()
63 |
--------------------------------------------------------------------------------
/tests/integration/test_custom_embedding_subclass.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | from openai import OpenAI
4 | import unittest
5 |
6 | # add ../../ to sys.path
7 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
8 |
9 | from dsrag.embedding import Embedding
10 | from dsrag.knowledge_base import KnowledgeBase
11 |
12 | class CustomEmbedding(Embedding):
13 | def __init__(self, model: str = "text-embedding-3-small", dimension: int = 768):
14 | super().__init__(dimension)
15 | self.model = model
16 | self.client = OpenAI()
17 |
18 | def get_embeddings(self, text, input_type=None):
19 | response = self.client.embeddings.create(input=text, model=self.model, dimensions=int(self.dimension))
20 | embeddings = [embedding_item.embedding for embedding_item in response.data]
21 | return embeddings[0] if isinstance(text, str) else embeddings
22 |
23 | def to_dict(self):
24 | base_dict = super().to_dict()
25 | base_dict.update({
26 | 'model': self.model
27 | })
28 | return base_dict
29 |
30 |
31 | class TestCustomEmbeddingSubclass(unittest.TestCase):
32 | def cleanup(self):
33 | kb = KnowledgeBase(kb_id="test_kb", exists_ok=True)
34 | kb.delete()
35 |
36 | def test_custom_embedding_subclass(self):
37 | # cleanup the knowledge base
38 | self.cleanup()
39 |
40 | # initialize a KnowledgeBase object with the custom embedding
41 | kb_id = "test_kb"
42 | kb = KnowledgeBase(kb_id, embedding_model=CustomEmbedding(model="text-embedding-3-large", dimension=1024), exists_ok=False)
43 |
44 | # load the knowledge base
45 | kb = KnowledgeBase(kb_id, exists_ok=True)
46 |
47 | # verify that the embedding model is set correctly
48 | self.assertEqual(kb.embedding_model.model, "text-embedding-3-large")
49 | self.assertEqual(kb.embedding_model.dimension, 1024)
50 |
51 | # delete the knowledge base
52 | self.cleanup()
53 |
54 | if __name__ == "__main__":
55 | unittest.main()
--------------------------------------------------------------------------------
/tests/integration/test_custom_langchain_retriever.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import unittest
4 |
5 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
6 |
7 | from dsrag.create_kb import create_kb_from_file
8 | from dsrag.knowledge_base import KnowledgeBase
9 | from integrations.langchain_retriever import DsRAGLangchainRetriever
10 | from langchain_core.documents import Document
11 |
12 |
13 | class TestDsRAGLangchainRetriever(unittest.TestCase):
14 |
15 | def setUp(self):
16 |
17 | """ We have to create a knowledge base from a file before we can run the retriever tests."""
18 |
19 | # Get the absolute path of the script file
20 | script_dir = os.path.dirname(os.path.abspath(__file__))
21 |
22 | # Construct the absolute path to the dataset file
23 | file_path = os.path.join(script_dir, "../data/levels_of_agi.pdf")
24 |
25 | kb_id = "levels_of_agi"
26 | create_kb_from_file(kb_id, file_path)
27 | self.kb_id = kb_id
28 |
29 | def test_create_kb_and_query(self):
30 |
31 | retriever = DsRAGLangchainRetriever(self.kb_id)
32 | documents = retriever.invoke("What are the levels of AGI?")
33 |
34 | # Make sure results are returned
35 | self.assertTrue(len(documents) > 0)
36 |
37 | # Make sure the results are Document objects, and that they have page content and metadata
38 | for doc in documents:
39 | assert isinstance(doc, Document)
40 | assert doc.page_content
41 | assert doc.metadata
42 |
43 | def tearDown(self):
44 | # Delete the KnowledgeBase object after the test runs
45 | kb = KnowledgeBase(kb_id="levels_of_agi", exists_ok=True)
46 | kb.delete()
47 |
48 |
49 | if __name__ == "__main__":
50 | unittest.main()
--------------------------------------------------------------------------------
/tests/integration/test_save_and_load.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import unittest
4 | import shutil
5 | from unittest.mock import patch
6 |
7 | # add ../../ to sys.path
8 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
9 |
10 | from dsrag.knowledge_base import KnowledgeBase
11 | from dsrag.llm import OpenAIChatAPI
12 | from dsrag.embedding import CohereEmbedding
13 | from dsrag.database.vector import BasicVectorDB # Import necessary DBs for override
14 | from dsrag.database.chunk import BasicChunkDB
15 | from dsrag.dsparse.file_parsing.file_system import LocalFileSystem
16 |
17 | class TestSaveAndLoad(unittest.TestCase):
18 | def cleanup(self):
19 | # Use a specific kb_id for testing
20 | kb = KnowledgeBase(kb_id="test_kb_override", exists_ok=True)
21 | # Try deleting, ignore if it doesn't exist
22 | try:
23 | kb.delete()
24 | except FileNotFoundError:
25 | pass
26 | # Ensure the base KB used in other tests is also cleaned up
27 | kb_main = KnowledgeBase(kb_id="test_kb", exists_ok=True)
28 | try:
29 | kb_main.delete()
30 | except FileNotFoundError:
31 | pass
32 |
33 | def test_save_and_load(self):
34 | self.cleanup()
35 |
36 | # initialize a KnowledgeBase object
37 | auto_context_model = OpenAIChatAPI(model="gpt-4o-mini")
38 | embedding_model = CohereEmbedding(model="embed-english-v3.0")
39 | kb_id = "test_kb"
40 | kb = KnowledgeBase(kb_id=kb_id, auto_context_model=auto_context_model, embedding_model=embedding_model, exists_ok=False)
41 |
42 | # load the KnowledgeBase object
43 | kb1 = KnowledgeBase(kb_id=kb_id)
44 |
45 | # verify that the KnowledgeBase object has the right parameters
46 | self.assertEqual(kb1.auto_context_model.model, "gpt-4o-mini")
47 | self.assertEqual(kb1.embedding_model.model, "embed-english-v3.0")
48 |
49 | # delete the KnowledgeBase object
50 | kb1.delete()
51 |
52 | # Check if the underlying storage path exists after deletion
53 | chunk_db_path = os.path.join(os.path.expanduser("~/dsRAG"), "chunk_data", kb_id)
54 | self.assertFalse(os.path.exists(chunk_db_path), f"Chunk DB path {chunk_db_path} should not exist after delete.")
55 |
56 | def test_load_override_warning(self):
57 | """Test that warnings are logged when overriding components during load."""
58 | self.cleanup()
59 | kb_id = "test_kb_override"
60 |
61 | # 1. Create initial KB to save config
62 | kb_initial = KnowledgeBase(kb_id=kb_id, exists_ok=False)
63 |
64 | # Mock DB/FS instances for overriding
65 | mock_vector_db = BasicVectorDB(kb_id=kb_id, storage_directory="~/dsRAG_temp_v")
66 | mock_chunk_db = BasicChunkDB(kb_id=kb_id, storage_directory="~/dsRAG_temp_c")
67 | mock_file_system = LocalFileSystem(base_path="~/dsRAG_temp_f")
68 |
69 | # 2. Patch logging.warning
70 | with patch('logging.warning') as mock_warning:
71 | # 3. Load KB again with overrides
72 | kb_loaded = KnowledgeBase(
73 | kb_id=kb_id,
74 | vector_db=mock_vector_db,
75 | chunk_db=mock_chunk_db,
76 | file_system=mock_file_system,
77 | exists_ok=True # This ensures _load is called
78 | )
79 |
80 | # 4. Assertions
81 | self.assertEqual(mock_warning.call_count, 3, "Expected 3 warning calls (vector_db, chunk_db, file_system)")
82 |
83 | expected_calls = [
84 | f"Overriding stored vector_db for KB '{kb_id}' during load.",
85 | f"Overriding stored chunk_db for KB '{kb_id}' during load.",
86 | f"Overriding stored file_system for KB '{kb_id}' during load.",
87 | ]
88 |
89 | call_args_list = [call.args[0] for call in mock_warning.call_args_list]
90 | call_kwargs_list = [call.kwargs for call in mock_warning.call_args_list]
91 |
92 | # Check if all expected messages are present in the actual calls
93 | for expected_msg in expected_calls:
94 | self.assertTrue(any(expected_msg in str(arg) for arg in call_args_list), f"Expected warning message '{expected_msg}' not found.")
95 |
96 | # Check the 'extra' parameter in all calls
97 | expected_extra = {'kb_id': kb_id}
98 | for kwargs in call_kwargs_list:
99 | self.assertIn('extra', kwargs, "Warning call missing 'extra' parameter.")
100 | self.assertEqual(kwargs['extra'], expected_extra, f"Warning call 'extra' parameter mismatch. Expected {expected_extra}, got {kwargs['extra']}")
101 |
102 |
103 | # 5. Cleanup
104 | kb_loaded.delete()
105 | # Explicitly try to remove temp dirs if they exist (though kb.delete should handle them)
106 | temp_dirs = ["~/dsRAG_temp_v", "~/dsRAG_temp_c", "~/dsRAG_temp_f"]
107 | for temp_dir in temp_dirs:
108 | expanded_path = os.path.expanduser(temp_dir)
109 | if os.path.exists(expanded_path):
110 | try:
111 | shutil.rmtree(expanded_path) # Use shutil.rmtree
112 | # print(f"Cleaned up temp dir: {expanded_path}") # Optional: uncomment for debug
113 | except OSError as e:
114 | print(f"Error removing temp dir {expanded_path}: {e}") # Log error if removal fails
115 |
116 |
117 | if __name__ == "__main__":
118 | unittest.main()
--------------------------------------------------------------------------------
/tests/unit/test_chunking.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import unittest
4 |
5 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
6 | from dsrag.dsparse.sectioning_and_chunking.chunking import chunk_document, chunk_sub_section, find_lines_in_range
7 |
8 | class TestChunking(unittest.TestCase):
9 | def setUp(self):
10 | self.sections = [
11 | {
12 | 'title': 'Section 1',
13 | 'start': 0,
14 | 'end': 3,
15 | 'content': 'This is the first section of the document.'
16 | },
17 | {
18 | 'title': 'Section 2',
19 | 'start': 4,
20 | 'end': 7,
21 | 'content': 'This is the second section of the document.'
22 | }
23 | ]
24 |
25 | self.document_lines = [
26 | {
27 | 'content': 'This is the first line of the document. And here is another sentence.',
28 | 'element_type': 'NarrativeText',
29 | 'page_number': 1,
30 | 'is_visual': False
31 | },
32 | {
33 | 'content': 'This is the second line of the document.',
34 | 'element_type': 'NarrativeText',
35 | 'page_number': 1,
36 | 'is_visual': False
37 | },
38 | {
39 | 'content': 'This is the third line of the document........',
40 | 'element_type': 'NarrativeText',
41 | 'page_number': 1,
42 | 'is_visual': False
43 | },
44 | {
45 | 'content': 'This is the fourth line of the document.',
46 | 'element_type': 'NarrativeText',
47 | 'page_number': 1,
48 | 'is_visual': False
49 | },
50 | {
51 | 'content': 'This is the fifth line of the document. With another sentence.',
52 | 'element_type': 'NarrativeText',
53 | 'page_number': 1,
54 | 'is_visual': False
55 | },
56 | {
57 | 'content': 'This is the sixth line of the document.',
58 | 'element_type': 'NarrativeText',
59 | 'page_number': 1,
60 | 'is_visual': False
61 | },
62 | {
63 | 'content': 'This is the seventh line of the document.',
64 | 'element_type': 'NarrativeText',
65 | 'page_number': 1,
66 | 'is_visual': False
67 | },
68 | {
69 | 'content': 'This is the eighth line of the document. And here is another sentence that is a bit longer',
70 | 'element_type': 'NarrativeText',
71 | 'page_number': 1,
72 | 'is_visual': False
73 | }
74 | ]
75 |
76 | def test__chunk_document(self):
77 | chunk_size = 90
78 | min_length_for_chunking = 120
79 | chunks = chunk_document(self.sections, self.document_lines, chunk_size, min_length_for_chunking)
80 | assert len(chunks[-1]["content"]) > 45
81 |
82 | def test__chunk_sub_section(self):
83 | chunk_size = 90
84 | chunks_text, chunk_line_indices = chunk_sub_section(4, 7, self.document_lines, chunk_size)
85 | assert len(chunks_text) == 3
86 | assert chunk_line_indices == [(4, 4), (5, 6), (7, 7)]
87 |
88 | def test__find_lines_in_range(self):
89 | # Test find_lines_in_range
90 | # (line_idx, start_char, end_char)
91 | line_char_ranges = [
92 | (0, 0, 49),
93 | (1, 50, 99),
94 | (2, 100, 149),
95 | (3, 150, 199),
96 | (4, 200, 249),
97 | (5, 250, 299),
98 | (6, 300, 349),
99 | (7, 350, 399)
100 | ]
101 |
102 | chunk_start = 50
103 | chunk_end = 82
104 | line_start = 0
105 | line_end = 0
106 | chunk_line_start, chunk_line_end = find_lines_in_range(chunk_start, chunk_end, line_char_ranges, line_start, line_end)
107 | assert chunk_line_start == 1
108 | assert chunk_line_end == 1
109 |
110 | chunk_start = 50
111 | chunk_end = 150
112 | line_start = 0
113 | line_end = 0
114 | chunk_line_start, chunk_line_end = find_lines_in_range(chunk_start, chunk_end, line_char_ranges, line_start, line_end)
115 | assert chunk_line_start == 1
116 | assert chunk_line_end == 3
117 |
118 | if __name__ == "__main__":
119 | unittest.main()
--------------------------------------------------------------------------------
/tests/unit/test_embedding.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import unittest
4 |
5 | sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
6 |
7 | from dsrag.embedding import (
8 | OpenAIEmbedding,
9 | CohereEmbedding,
10 | OllamaEmbedding,
11 | Embedding,
12 | )
13 |
14 |
15 | class TestEmbedding(unittest.TestCase):
16 | def test__get_embeddings_openai(self):
17 | input_text = "Hello, world!"
18 | model = "text-embedding-3-small"
19 | dimension = 768
20 | embedding_provider = OpenAIEmbedding(model, dimension)
21 | embedding = embedding_provider.get_embeddings(input_text)
22 | self.assertEqual(len(embedding), dimension)
23 |
24 | def test__get_embeddings_cohere(self):
25 | input_text = "Hello, world!"
26 | model = "embed-english-v3.0"
27 | embedding_provider = CohereEmbedding(model)
28 | embedding = embedding_provider.get_embeddings(input_text, input_type="query")
29 | self.assertEqual(len(embedding), 1024)
30 |
31 | def test__get_embeddings_ollama(self):
32 | input_text = "Hello, world!"
33 | model = "llama3"
34 | dimension = 4096
35 | try:
36 | embedding_provider = OllamaEmbedding(model, dimension)
37 | embedding = embedding_provider.get_embeddings(input_text)
38 | except Exception as e:
39 | print (e)
40 | if e.__class__.__name__ == "ConnectError":
41 | print("Connection failed")
42 | return
43 | self.assertEqual(len(embedding), dimension)
44 |
45 | def test__get_embeddings_openai_with_list(self):
46 | input_texts = ["Hello, world!", "Goodbye, world!"]
47 | model = "text-embedding-3-small"
48 | dimension = 768
49 | embedding_provider = OpenAIEmbedding(model, dimension)
50 | embeddings = embedding_provider.get_embeddings(input_texts)
51 | self.assertEqual(len(embeddings), 2)
52 | self.assertTrue(all(len(embed) == dimension for embed in embeddings))
53 |
54 | def test__get_embeddings_cohere_with_list(self):
55 | input_texts = ["Hello, world!", "Goodbye, world!"]
56 | model = "embed-english-v3.0"
57 | embedding_provider = CohereEmbedding(model)
58 | embeddings = embedding_provider.get_embeddings(input_texts, input_type="query")
59 | assert len(embeddings) == 2
60 | assert all(len(embed) == 1024 for embed in embeddings)
61 | self.assertEqual(len(embeddings), 2)
62 | self.assertTrue(all(len(embed) == 1024 for embed in embeddings))
63 |
64 | def test__get_embeddings_ollama_with_list_llama2(self):
65 | input_texts = ["Hello, world!", "Goodbye, world!"]
66 | model = "llama2"
67 | dimension = 4096
68 | try:
69 | embedding_provider = OllamaEmbedding(model, dimension)
70 | embeddings = embedding_provider.get_embeddings(input_texts)
71 | except Exception as e:
72 | print (e)
73 | if e.__class__.__name__ == "ConnectError":
74 | print ("Connection failed")
75 | return
76 | self.assertEqual(len(embeddings), 2)
77 | self.assertTrue(all(len(embed) == dimension for embed in embeddings))
78 |
79 | def test__get_embeddings_ollama_with_list_minilm(self):
80 | input_texts = ["Hello, world!", "Goodbye, world!"]
81 | model = "all-minilm"
82 | dimension = 384
83 | try:
84 | embedding_provider = OllamaEmbedding(model, dimension)
85 | embeddings = embedding_provider.get_embeddings(input_texts)
86 | except Exception as e:
87 | print (e)
88 | if e.__class__.__name__ == "ConnectError":
89 | print ("Connection failed")
90 | return
91 | self.assertEqual(len(embeddings), 2)
92 | self.assertTrue(all(len(embed) == dimension for embed in embeddings))
93 |
94 | def test__get_embeddings_ollama_with_list_nomic(self):
95 | input_texts = ["Hello, world!", "Goodbye, world!"]
96 | model = "nomic-embed-text"
97 | dimension = 768
98 | try:
99 | embedding_provider = OllamaEmbedding(model, dimension)
100 | embeddings = embedding_provider.get_embeddings(input_texts)
101 | except Exception as e:
102 | print (e)
103 | if e.__class__.__name__ == "ConnectError":
104 | print ("Connection failed")
105 | return
106 | self.assertEqual(len(embeddings), 2)
107 | self.assertTrue(all(len(embed) == dimension for embed in embeddings))
108 |
109 | def test__initialize_from_config(self):
110 | config = {
111 | 'subclass_name': 'OpenAIEmbedding',
112 | 'model': 'text-embedding-3-small',
113 | 'dimension': 1024
114 | }
115 | embedding_instance = Embedding.from_dict(config)
116 | self.assertIsInstance(embedding_instance, OpenAIEmbedding)
117 | self.assertEqual(embedding_instance.model, 'text-embedding-3-small')
118 | self.assertEqual(embedding_instance.dimension, 1024)
119 |
120 |
121 | if __name__ == "__main__":
122 | unittest.main()
--------------------------------------------------------------------------------
/tests/unit/test_file_system.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import unittest
4 | import dotenv
5 |
6 | dotenv.load_dotenv()
7 |
8 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
9 | from dsrag.dsparse.file_parsing.file_system import LocalFileSystem, S3FileSystem
10 | from dsrag.knowledge_base import KnowledgeBase
11 |
12 |
13 | class TestFileSystem(unittest.TestCase):
14 |
15 | @classmethod
16 | def setUpClass(self):
17 | self.base_path = "~/dsrag_test_file_system"
18 | self.local_file_system = LocalFileSystem(base_path=self.base_path)
19 | self.s3_file_system = S3FileSystem(
20 | base_path=self.base_path,
21 | bucket_name=os.environ["AWS_S3_BUCKET_NAME"],
22 | region_name=os.environ["AWS_S3_REGION"],
23 | access_key=os.environ["AWS_S3_ACCESS_KEY"],
24 | secret_key=os.environ["AWS_S3_SECRET_KEY"]
25 | )
26 | self.kb_id = "test_file_system_kb"
27 | self.kb = KnowledgeBase(kb_id=self.kb_id, file_system=self.local_file_system)
28 |
29 | def test__001_load_kb(self):
30 | kb = KnowledgeBase(kb_id=self.kb_id)
31 | # Assert the class
32 | assert isinstance(kb.file_system, LocalFileSystem)
33 |
34 | def test__002_overwrite_file_system(self):
35 | kb = KnowledgeBase(kb_id=self.kb_id, file_system=self.s3_file_system)
36 |         # Assert the class. Passing a file system at load time overrides the stored one
37 | assert isinstance(kb.file_system, S3FileSystem)
38 |
39 | @classmethod
40 | def tearDownClass(self):
41 | try:
42 | os.system(f"rm -rf {self.base_path}")
43 | except:
44 | pass
45 |
46 |
47 | if __name__ == "__main__":
48 | unittest.main()
--------------------------------------------------------------------------------
/tests/unit/test_llm_class.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import unittest
4 |
5 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
6 | from dsrag.llm import OpenAIChatAPI, AnthropicChatAPI, OllamaAPI, LLM, GeminiAPI
7 |
8 |
9 | class TestLLM(unittest.TestCase):
10 | def test__openai_chat_api(self):
11 | chat_api = OpenAIChatAPI()
12 | chat_messages = [
13 | {"role": "user", "content": "Hello, how are you?"},
14 | ]
15 | response = chat_api.make_llm_call(chat_messages)
16 | self.assertIsInstance(response, str)
17 | self.assertGreater(len(response), 0)
18 |
19 | def test__anthropic_chat_api(self):
20 | chat_api = AnthropicChatAPI()
21 | chat_messages = [
22 | {"role": "user", "content": "What's the weather like today?"},
23 | ]
24 | response = chat_api.make_llm_call(chat_messages)
25 | self.assertIsInstance(response, str)
26 | self.assertGreater(len(response), 0)
27 |
28 | def test__ollama_api(self):
29 | try:
30 | chat_api = OllamaAPI()
31 | except Exception as e:
32 | print(e)
33 | if e.__class__.__name__ == "ConnectError":
34 | print ("Connection failed")
35 | return
36 | chat_messages = [
37 | {"role": "user", "content": "Who's your daddy?"},
38 | ]
39 | response = chat_api.make_llm_call(chat_messages)
40 | self.assertIsInstance(response, str)
41 | self.assertGreater(len(response), 0)
42 |
43 | def test__gemini_api(self):
44 | try:
45 | chat_api = GeminiAPI()
46 | except ValueError as e:
47 | # Skip test if API key is not set
48 | if "GEMINI_API_KEY environment variable not set" in str(e):
49 | print(f"Skipping Gemini test: {e}")
50 | return
51 | else:
52 | raise # Re-raise other ValueErrors
53 | except Exception as e:
54 | # Handle other potential exceptions during init (e.g., genai configuration issues)
55 | print(f"Skipping Gemini test due to initialization error: {e}")
56 | return
57 |
58 | chat_messages = [
59 | {"role": "user", "content": "Explain the concept of RAG briefly."},
60 | ]
61 | try:
62 | response = chat_api.make_llm_call(chat_messages)
63 | self.assertIsInstance(response, str)
64 | self.assertGreater(len(response), 0)
65 | except Exception as e:
66 | # Catch potential API call errors (e.g., connection, authentication within genai)
67 | self.fail(f"GeminiAPI make_llm_call failed with exception: {e}")
68 |
69 | def test__save_and_load_from_dict(self):
70 | chat_api = OpenAIChatAPI(temperature=0.5, max_tokens=2000)
71 | config = chat_api.to_dict()
72 | chat_api_loaded = LLM.from_dict(config)
73 | self.assertEqual(chat_api_loaded.model, chat_api.model)
74 | self.assertEqual(chat_api_loaded.temperature, chat_api.temperature)
75 | self.assertEqual(chat_api_loaded.max_tokens, chat_api.max_tokens)
76 |
77 | if __name__ == "__main__":
78 | unittest.main()
--------------------------------------------------------------------------------
/tests/unit/test_reranker.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import unittest
4 | sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
5 |
6 | from dsrag.reranker import Reranker, CohereReranker, NoReranker
7 |
8 |
9 | class TestReranker(unittest.TestCase):
10 | def test_rerank_search_results_cohere(self):
11 | query = "Hello, world!"
12 | search_results = [
13 | {
14 | "metadata": {
15 | "chunk_header": "",
16 | "chunk_text": "Hello, world!"
17 | }
18 | },
19 | {
20 | "metadata": {
21 | "chunk_header": "",
22 | "chunk_text": "Goodbye, world!"
23 | }
24 | }
25 | ]
26 | reranker = CohereReranker()
27 | reranked_search_results = reranker.rerank_search_results(query, search_results)
28 | self.assertEqual(len(reranked_search_results), 2)
29 | self.assertEqual(reranked_search_results[0]["metadata"]["chunk_text"], "Hello, world!")
30 | self.assertEqual(reranked_search_results[1]["metadata"]["chunk_text"], "Goodbye, world!")
31 |
32 | def test_save_and_load_from_dict(self):
33 | reranker = CohereReranker(model="embed-english-v3.0")
34 | config = reranker.to_dict()
35 | reranker_instance = Reranker.from_dict(config)
36 | self.assertEqual(reranker_instance.model, "embed-english-v3.0")
37 |
38 | def test_no_reranker(self):
39 | reranker = NoReranker()
40 | query = "Hello, world!"
41 | search_results = [
42 | {
43 | "metadata": {
44 | "chunk_header": "",
45 | "chunk_text": "Hello, world!"
46 | }
47 | },
48 | {
49 | "metadata": {
50 | "chunk_header": "",
51 | "chunk_text": "Goodbye, world!"
52 | }
53 | }
54 | ]
55 | reranked_search_results = reranker.rerank_search_results(query, search_results)
56 | self.assertEqual(len(reranked_search_results), 2)
57 | self.assertEqual(reranked_search_results[0]["metadata"]["chunk_text"], "Hello, world!")
58 | self.assertEqual(reranked_search_results[1]["metadata"]["chunk_text"], "Goodbye, world!")
59 |
60 |
61 | if __name__ == "__main__":
62 | unittest.main()
--------------------------------------------------------------------------------