├── .gitignore ├── LICENSE ├── README.md ├── config.yaml ├── examples └── chromadb_integration.ipynb ├── qkb_logo.png ├── requirements.txt ├── setup.py └── src ├── chunking ├── __init__.py ├── base_chunker.py ├── cluster_semantic_chunker.py ├── fixed_token_chunker.py ├── kamradt_modified_chunker.py ├── llm_semantic_chunker.py ├── recursive_token_chunker.py ├── registry.py └── utils.py ├── hub_upload ├── card_generator.py ├── dataset_pusher.py └── template.md ├── main.py ├── prompts └── question_generation.txt ├── synth_dataset ├── deduplicator.py ├── question_generator.py └── rate_limiter.py └── training ├── __init__.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | quickb_venv/ 3 | knowledgebase/ 4 | quickb-env/ 5 | kb_env/ 6 | testing/ 7 | __pycache__/ 8 | *.py[cod] 9 | output/ 10 | *.egg-info/ 11 | .env 12 | .pyc -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Adam Łucek 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # QuicKB 2 | 3 | 4 | 5 | Optimize Document Retrieval with Fine-Tuned KnowledgeBases 6 | 7 | ## Overview 8 | 9 | QuicKB optimizes document retrieval by creating fine-tuned knowledgebases through an end-to-end machine learning pipeline that handles document chunking, synthetic data generation, and embedding model training. 10 | 11 | ## Key Features 12 | 13 | ### Document Chunking 14 | Implement state-of-the-art chunking strategies based on [ChromaDB's research](https://research.trychroma.com/evaluating-chunking): 15 | - **Semantic Approaches**: 16 | - LLM-guided splits for natural language breakpoints 17 | - Content-aware clustering for thematic coherence 18 | - Hybrid semantic-token methods for balanced chunking 19 | - **Token/Character-Based Methods**: 20 | - Recursive chunking with custom separator hierarchies 21 | - Precise length-based splitting 22 | 23 | ### Training Data Generation 24 | Automatically create domain-specific training datasets: 25 | - Generate synthetic question-answer pairs from your content 26 | - Intelligent deduplication using semantic similarity 27 | - Parallel processing for large-scale document sets 28 | - Support for local and cloud LLMs via [LiteLLM](https://docs.litellm.ai/docs/) 29 | 30 | ### Embedding Optimization 31 | Fine-tune embedding models for your specific domain: 32 | - Custom training using [Sentence Transformers](https://sbert.net/) 33 | - Cross-platform training support (CUDA, MPS, CPU) 34 | - Dimension reduction techniques with Matryoshka Representation Learning 35 | - 
Comprehensive evaluation across multiple metrics 36 | - Detailed performance comparisons with baseline models 37 | 38 | ## Installation 39 | 40 | ```bash 41 | git clone https://github.com/ALucek/QuicKB.git 42 | cd QuicKB 43 | 44 | python -m venv quickb-env 45 | source quickb-env/bin/activate # Windows: quickb-env\Scripts\activate 46 | 47 | pip install -e . 48 | ``` 49 | 50 | ## Usage 51 | 52 | 1. Prepare your text documents in a directory 53 | 2. Configure the pipeline in `config.yaml` 54 | 3. Run: 55 | ```bash 56 | python src/main.py 57 | ``` 58 | 4. Enjoy! 59 | 60 | ## Configuration Guide 61 | 62 | The pipeline is controlled through a single `config.yaml` file. Here's a complete configuration example: 63 | 64 | ```yaml 65 | # ========== Pipeline Control ========== 66 | pipeline: 67 | from_stage: "CHUNK" # Options: "CHUNK", "GENERATE", "TRAIN" 68 | to_stage: "TRAIN" # Run pipeline from from_stage to to_stage 69 | 70 | # ========== Global Settings ========== 71 | path_to_knowledgebase: "./knowledgebase" # Directory containing source documents 72 | hub_username: "AdamLucek" # Hugging Face username 73 | hub_token: null # Optional: Use HF_TOKEN env variable instead 74 | 75 | # ========== Document Chunking ========== 76 | chunker_config: 77 | output_path: "./output/knowledgebase.json" 78 | 79 | # Chunker Config: 80 | chunker: "RecursiveTokenChunker" 81 | 82 | chunker_arguments: 83 | chunk_size: 400 84 | chunk_overlap: 0 85 | length_type: "character" 86 | separators: ["\n\n", "\n", ".", "?", "!", " ", ""] 87 | keep_separator: true 88 | is_separator_regex: false 89 | 90 | # Optional: Push chunks to Hugging Face Hub 91 | upload_config: 92 | push_to_hub: true 93 | hub_private: false 94 | hub_dataset_id: "AdamLucek/quickb-kb" 95 | 96 | # ========== Question Generation ========== 97 | question_generation: 98 | output_path: "./output/train_data.json" 99 | 100 | # LLM/Embedding Configuration 101 | litellm_config: 102 | model: "openai/gpt-4o-mini" 103 | model_api_base: null # 
Optional: Custom API endpoint 104 | 105 | embedding_model: "text-embedding-3-large" 106 | embedding_api_base: null # Optional: Custom embedding endpoint 107 | 108 | # Input dataset settings 109 | input_dataset_config: 110 | dataset_source: "local" # Options: "local", "hub" 111 | local_knowledgebase_path: "./output/knowledgebase.json" 112 | # Hub alternative: 113 | # knowledgebase_dataset_id: "username/quickb-kb" 114 | 115 | # Performance settings 116 | max_workers: 150 # Parallel question generation 117 | llm_calls_per_minute: null # null = no limit 118 | embedding_calls_per_minute: null # null = no limit 119 | 120 | # Question deduplication 121 | deduplication_enabled: true 122 | dedup_embedding_batch_size: 2048 # Batch size for embedding calculation 123 | similarity_threshold: 0.85 # Semantic Similarity Threshold 124 | 125 | # Optional: Push training data to Hub 126 | upload_config: 127 | push_to_hub: true 128 | hub_private: false 129 | hub_dataset_id: "AdamLucek/quickb-qa" 130 | 131 | # ========== Model Training ========== 132 | training: 133 | # Model configuration 134 | model_settings: 135 | # Base model: 136 | model_id: "nomic-ai/modernbert-embed-base" 137 | 138 | # Matryoshka dimensions (must be descending) 139 | matryoshka_dimensions: [768, 512, 256, 128, 64] 140 | metric_for_best_model: "eval_dim_128_cosine_ndcg@10" 141 | max_seq_length: 1024 142 | trust_remote_code: false 143 | 144 | # Training data configuration 145 | train_dataset_config: 146 | dataset_source: "local" # Options: "local", "hub" 147 | local_train_path: "./output/train_data.json" 148 | local_knowledgebase_path: "./output/knowledgebase.json" 149 | # Hub alternatives: 150 | # train_dataset_id: "AdamLucek/quickb-qa" 151 | # knowledgebase_dataset_id: "AdamLucek/quickb-kb" 152 | 153 | # Training hyperparameters 154 | training_arguments: 155 | output_path: "./output/modernbert_quickb" 156 | device: "cuda" # Options: "cuda", "mps", "cpu" 157 | epochs: 4 158 | batch_size: 32 159 | 
gradient_accumulation_steps: 16 160 | learning_rate: 2.0e-5 161 | warmup_ratio: 0.1 162 | lr_scheduler_type: "cosine" 163 | optim: "adamw_torch_fused" 164 | tf32: true 165 | bf16: true 166 | batch_sampler: "no_duplicates" # Options: "batch_sampler", "no_duplicates", "group_by_label" 167 | eval_strategy: "epoch" 168 | save_strategy: "epoch" 169 | logging_steps: 10 170 | save_total_limit: 3 171 | load_best_model_at_end: true 172 | report_to: "none" 173 | 174 | # Optional: Push trained model to Hub 175 | upload_config: 176 | push_to_hub: true 177 | hub_private: false 178 | hub_model_id: "AdamLucek/modernbert-embed-quickb" 179 | ``` 180 | 181 | ### Alternative Chunker Configurations 182 | 183 | 1. **Fixed Token Chunker** 184 | ```yaml 185 | chunker_config: 186 | output_path: "./output/fixed_token_chunks.json" 187 | chunker: "FixedTokenChunker" 188 | chunker_arguments: 189 | encoding_name: "cl100k_base" 190 | chunk_size: 400 191 | chunk_overlap: 50 192 | length_type: "token" 193 | ``` 194 | 195 | 2. **Cluster Semantic Chunker** 196 | ```yaml 197 | chunker_config: 198 | output_path: "./output/semantic_clusters.json" 199 | chunker: "ClusterSemanticChunker" 200 | chunker_arguments: 201 | max_chunk_size: 400 # Max tokens after clustering 202 | min_chunk_size: 50 # Initial split size 203 | length_type: "token" 204 | litellm_config: 205 | embedding_model: "text-embedding-3-large" # Required for semantic clustering 206 | ``` 207 | 208 | 3. **LLM Semantic Chunker** 209 | ```yaml 210 | chunker_config: 211 | output_path: "./output/llm_semantic_chunks.json" 212 | chunker: "LLMSemanticChunker" 213 | chunker_arguments: 214 | length_type: "token" 215 | litellm_config: 216 | model: "openai/gpt-4o" # LLM for split decisions 217 | ``` 218 | 219 | 4. 
**Kamradt Modified Semantic Chunker** 220 | ```yaml 221 | chunker_config: 222 | output_path: "./output/kamradt_chunks.json" 223 | chunker: "KamradtModifiedChunker" 224 | chunker_arguments: 225 | avg_chunk_size: 400 # Target average size 226 | min_chunk_size: 50 # Minimum initial split 227 | length_type: "token" 228 | litellm_config: 229 | embedding_model: "text-embedding-3-large" # For similarity calculations 230 | ``` 231 | ### Device Support 232 | 233 | QuicKB supports multiple compute devices for model training: 234 | 235 | - **CUDA**: NVIDIA GPUs (default and recommended for best performance) 236 | - **MPS**: Apple Silicon on macOS 237 | - **CPU**: Available on all systems, but significantly slower for training 238 | 239 | To specify your preferred device, use the `device` parameter in your training configuration: 240 | 241 | ```yaml 242 | training: 243 | training_arguments: 244 | device: "cuda" # Options: "cuda", "mps", "cpu" 245 | ``` 246 | 247 | **Note**: The default configuration and hyperparameters are optimized for CUDA GPUs. When using MPS or CPU, you may need to install a PyTorch build that supports your device, or adjust the following parameters for optimal performance: 248 | 249 | - Reduce `batch_size` (e.g., 8-16 for MPS, 4-8 for CPU) 250 | - Reduce `gradient_accumulation_steps` for CPU training 251 | - Set `tf32: false` for MPS and CPU training (TF32 is only available on NVIDIA Ampere or newer GPUs), and `bf16: false` for CPU training 252 | - Use a non-fused optimizer such as `adamw_torch` if `adamw_torch_fused` is not supported on your device 253 | - Consider using smaller base models 254 | 255 | Device selection is automatic if you don't specify a device (prioritizing CUDA > MPS > CPU), but explicit configuration is recommended for reproducibility. 256 | 257 | ### LiteLLM Integration 258 | 259 | QuicKB uses [LiteLLM](https://docs.litellm.ai/docs/) for flexible model provider integration, allowing you to use any supported LLM or embedding provider for question generation and semantic chunking. This enables both cloud-based and local model deployment.
260 | 261 | The LiteLLM configuration is managed through the `litellm_config` section in both the chunking and question generation configurations: 262 | 263 | ```yaml 264 | litellm_config: 265 | model: "openai/gpt-4o" # LLM model identifier 266 | model_api_base: null # Optional API base URL for LLM 267 | embedding_model: "text-embedding-3-large" # Embedding model identifier 268 | embedding_api_base: null # Optional API base URL for embeddings 269 | ``` 270 | 271 | **Using Local Models**: 272 | 273 | 1. Set up an OpenAI-compatible API endpoint (e.g., Ollama, vLLM) 274 | 2. Point `model_api_base` and/or `embedding_api_base` in your config at that endpoint 275 | 3. Prefix each model identifier with its LiteLLM provider, e.g. `openai/<model-name>` for a generic OpenAI-compatible server (or `ollama/<model-name>` for Ollama) 276 | 277 | Example local setup: 278 | 279 | ```yaml 280 | # For question generation 281 | question_generation: 282 | litellm_config: 283 | model: "openai/llama-7b" 284 | model_api_base: "http://localhost:8000" 285 | embedding_model: "openai/bge-small" 286 | embedding_api_base: "http://localhost:8000" 287 | 288 | # For semantic chunkers 289 | chunker_config: 290 | chunker: "ClusterSemanticChunker" # or other semantic chunkers 291 | chunker_arguments: 292 | litellm_config: 293 | model: "openai/llama-7b" 294 | model_api_base: "http://localhost:8000" 295 | embedding_model: "openai/bge-small" 296 | embedding_api_base: "http://localhost:8000" 297 | ``` 298 | 299 | For more details on setting up local models and supported providers, refer to the [LiteLLM documentation](https://docs.litellm.ai/docs/providers). 300 | 301 | ### Hugging Face Hub Integration 302 | 303 | QuicKB integrates directly with [Hugging Face](https://huggingface.co/) for storing and loading datasets and models. Each pipeline stage can optionally push its outputs to the Hub, and subsequent stages can load data directly from there.
304 | 305 | The Hub integration is configured through `upload_config` sections and dataset source settings: 306 | 307 | ```yaml 308 | # Example Hub configuration for chunking 309 | chunker_config: 310 | upload_config: 311 | push_to_hub: true 312 | hub_private: false 313 | hub_dataset_id: "username/quickb-kb" 314 | 315 | # Loading data from Hub for question generation 316 | question_generation: 317 | input_dataset_config: 318 | dataset_source: "hub" 319 | knowledgebase_dataset_id: "username/quickb-kb" 320 | 321 | upload_config: 322 | push_to_hub: true 323 | hub_private: false 324 | hub_dataset_id: "username/quickb-qa" 325 | 326 | # Loading from Hub for training 327 | training: 328 | train_dataset_config: 329 | dataset_source: "hub" 330 | train_dataset_id: "username/quickb-qa" 331 | knowledgebase_dataset_id: "username/quickb-kb" 332 | 333 | upload_config: 334 | push_to_hub: true 335 | hub_private: false 336 | hub_model_id: "username/modernbert-embed-quickb" 337 | ``` 338 | **Authentication** 339 | - Set your Hugging Face token using the `HF_TOKEN` environment variable or specify it in the config using `hub_token` 340 | 341 | ## Output Format 342 | 343 | ### Knowledgebase Dataset 344 | ```json 345 | { 346 | "id": "3fa85f64-5717-4562-b3fc-2c963f66afa6", 347 | "text": "Section 12.1: Termination clauses...", 348 | "source": "docs/contracts/2024/Q1-agreement.txt" 349 | } 350 | ``` 351 | 352 | ### Training Dataset 353 | ```json 354 | { 355 | "anchor": "What are the termination notice requirements?", 356 | "positive": "Section 12.1: Either party may terminate...", 357 | "question_id": "a3b8c7d0-e83a-4b5c-b12d-3f7a8d4c9e1b", 358 | "chunk_id": "3fa85f64-5717-4562-b3fc-2c963f66afa6" 359 | } 360 | ``` 361 | 362 | ## Environment Variables 363 | 364 | - `<PROVIDER>_API_KEY` (e.g. `OPENAI_API_KEY`): Required by your LLM/embedding provider for question generation, deduplication, and semantic chunking 365 | - `HF_TOKEN`: Required for Hugging Face Hub uploads and for downloading private datasets or models 366 | 367 | ## Citations 368 | 369 | QuicKB builds upon these foundational works:
370 | 371 | ChromaDB: [Evaluating Chunking Strategies for Retrieval](https://research.trychroma.com/evaluating-chunking) 372 | ```bibtex 373 | @techreport{smith2024evaluating, 374 | title = {Evaluating Chunking Strategies for Retrieval}, 375 | author = {Smith, Brandon and Troynikov, Anton}, 376 | year = {2024}, 377 | month = {July}, 378 | institution = {Chroma}, 379 | url = {https://research.trychroma.com/evaluating-chunking}, 380 | } 381 | ``` 382 | 383 | Sentence Transformers: [Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks](https://arxiv.org/abs/1908.10084) 384 | ```bibtex 385 | @inproceedings{reimers-2019-sentence-bert, 386 | title = "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks", 387 | author = "Reimers, Nils and Gurevych, Iryna", 388 | booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing", 389 | month = "11", 390 | year = "2019", 391 | publisher = "Association for Computational Linguistics", 392 | url = "https://arxiv.org/abs/1908.10084", 393 | } 394 | ``` 395 | 396 | ## Contributing 397 | 398 | Contributions welcome! Please feel free to submit a Pull Request.
399 | 400 | Todo List: 401 | 402 | - Cleaner handling of config arguments and validation at pipeline stages 403 | - pydantic v2 fields warning 404 | - Custom Model Card (Using base from SBERT currently) 405 | 406 | ## License 407 | 408 | MIT License - See [LICENSE](LICENSE) 409 | 410 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) 411 | -------------------------------------------------------------------------------- /config.yaml: -------------------------------------------------------------------------------- 1 | # ========== Pipeline Control ========== 2 | pipeline: 3 | from_stage: "CHUNK" # Options: "CHUNK", "GENERATE", "TRAIN" 4 | to_stage: "TRAIN" # Run pipeline from from_stage to to_stage 5 | 6 | # ========== Global Settings ========== 7 | path_to_knowledgebase: "./knowledgebase" # Directory containing source documents 8 | hub_username: "AdamLucek" # Hugging Face username 9 | hub_token: null # Optional: Use HF_TOKEN env variable instead 10 | 11 | # ========== Document Chunking ========== 12 | chunker_config: 13 | output_path: "./output/knowledgebase.json" 14 | 15 | # Chunker Config: 16 | chunker: "RecursiveTokenChunker" 17 | 18 | chunker_arguments: 19 | chunk_size: 400 20 | chunk_overlap: 0 21 | length_type: "character" 22 | separators: ["\n\n", "\n", ".", "?", "!", " ", ""] 23 | keep_separator: true 24 | is_separator_regex: false 25 | 26 | # Optional: Push chunks to Hugging Face Hub 27 | upload_config: 28 | push_to_hub: true 29 | hub_private: false 30 | hub_dataset_id: "AdamLucek/quickb-kb" 31 | 32 | # ========== Question Generation ========== 33 | question_generation: 34 | output_path: "./output/train_data.json" 35 | 36 | # LLM/Embedding Configuration 37 | litellm_config: 38 | model: "openai/gpt-4o-mini" 39 | model_api_base: null # Optional: Custom API endpoint 40 | 41 | embedding_model: "text-embedding-3-large" 42 | embedding_api_base: null # Optional: Custom embedding endpoint 43 | 44 | # Input 
dataset settings 45 | input_dataset_config: 46 | dataset_source: "local" # Options: "local", "hub" 47 | local_knowledgebase_path: "./output/knowledgebase.json" 48 | # Hub alternative: 49 | # knowledgebase_dataset_id: "AdamLucek/quickb-kb" 50 | 51 | # Performance settings 52 | max_workers: 150 # Parallel question generation 53 | llm_calls_per_minute: null # null = no limit 54 | embedding_calls_per_minute: null # null = no limit 55 | 56 | # Question deduplication 57 | deduplication_enabled: true 58 | dedup_embedding_batch_size: 2048 # Batch size for embedding calculation 59 | similarity_threshold: 0.85 # Semantic Similarity Threshold 60 | 61 | # Optional: Push training data to Hub 62 | upload_config: 63 | push_to_hub: true 64 | hub_private: false 65 | hub_dataset_id: "AdamLucek/quickb-qa" 66 | 67 | # ========== Model Training ========== 68 | training: 69 | # Model configuration 70 | model_settings: 71 | # Base model: 72 | model_id: "nomic-ai/modernbert-embed-base" 73 | 74 | # Matryoshka dimensions (must be descending) 75 | matryoshka_dimensions: [768, 512, 256, 128, 64] 76 | metric_for_best_model: "eval_dim_128_cosine_ndcg@10" 77 | max_seq_length: 1024 78 | trust_remote_code: false 79 | 80 | # Training data configuration 81 | train_dataset_config: 82 | dataset_source: "local" # Options: "local", "hub" 83 | local_train_path: "./output/train_data.json" 84 | local_knowledgebase_path: "./output/knowledgebase.json" 85 | # Hub alternatives: 86 | # train_dataset_id: "AdamLucek/quickb-qa" 87 | # knowledgebase_dataset_id: "AdamLucek/quickb-kb" 88 | 89 | # Training hyperparameters 90 | training_arguments: 91 | output_path: "./output/modernbert_quickb" 92 | device: "cuda" # Options: "cuda", "mps", "cpu" 93 | epochs: 4 94 | batch_size: 32 95 | gradient_accumulation_steps: 16 96 | learning_rate: 2.0e-5 97 | warmup_ratio: 0.1 98 | lr_scheduler_type: "cosine" 99 | optim: "adamw_torch_fused" 100 | tf32: true 101 | bf16: true 102 | batch_sampler: "no_duplicates" # Options: 
"batch_sampler", "no_duplicates", "group_by_label" 103 | eval_strategy: "epoch" 104 | save_strategy: "epoch" 105 | logging_steps: 10 106 | save_total_limit: 3 107 | load_best_model_at_end: true 108 | report_to: "none" 109 | 110 | # Optional: Push trained model to Hub 111 | upload_config: 112 | push_to_hub: true 113 | hub_private: false 114 | hub_model_id: "AdamLucek/modernbert-embed-quickb" -------------------------------------------------------------------------------- /examples/chromadb_integration.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "2abfd742-3602-4d97-894b-28f5f1edfd5a", 6 | "metadata": {}, 7 | "source": [ 8 | "# QuicKB Integration - ChromaDB\n", 9 | "\n", 10 | "This example notebook shows you how to implement your knowledgebase and fine-tuned model with ChromaDB" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 1, 16 | "id": "6dae01d8-72f9-47a8-af45-a57785beec6b", 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "# Install required packages if needed:\n", 21 | "# !pip install chromadb datasets sentence-transformers\n", 22 | "\n", 23 | "import chromadb\n", 24 | "from chromadb.utils import embedding_functions\n", 25 | "from datasets import load_dataset\n", 26 | "\n", 27 | "from sentence_transformers import SentenceTransformer" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 2, 33 | "id": "230a0175-978f-4fe0-9aa5-b971f9047a6e", 34 | "metadata": {}, 35 | "outputs": [ 36 | { 37 | "data": { 38 | "application/vnd.jupyter.widget-view+json": { 39 | "model_id": "baa45dc1fbf046aa95efef6c123a6109", 40 | "version_major": 2, 41 | "version_minor": 0 42 | }, 43 | "text/plain": [ 44 | "modules.json: 0%| | 0.00/349 [00:00=3.10', 15 | ) 16 | -------------------------------------------------------------------------------- /src/chunking/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .registry import ChunkerRegistry 2 | from .fixed_token_chunker import FixedTokenChunker 3 | from .recursive_token_chunker import RecursiveTokenChunker 4 | from .cluster_semantic_chunker import ClusterSemanticChunker 5 | from .llm_semantic_chunker import LLMSemanticChunker 6 | from .kamradt_modified_chunker import KamradtModifiedChunker 7 | from .utils import get_length_function, get_token_count, get_character_count 8 | 9 | __all__ = [ 10 | 'ClusterSemanticChunker', 11 | 'LLMSemanticChunker', 12 | 'FixedTokenChunker', 13 | 'RecursiveTokenChunker', 14 | 'KamradtModifiedChunker', 15 | 'ChunkerRegistry', 16 | 'get_length_function', 17 | 'get_token_count', 18 | 'get_character_count' 19 | ] -------------------------------------------------------------------------------- /src/chunking/base_chunker.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import List, Any 3 | from .utils import get_length_function 4 | 5 | class BaseChunker(ABC): 6 | """Base class for all chunking implementations.""" 7 | 8 | def __init__(self, *args, **kwargs): 9 | """ 10 | Initialize the chunker with length function configuration. 11 | 12 | Args: 13 | *args: Variable length argument list 14 | **kwargs: Arbitrary keyword arguments 15 | """ 16 | self._initialize_length_function(kwargs) 17 | 18 | def _initialize_length_function(self, kwargs): 19 | """ 20 | Initialize the length function based on provided configuration. 
21 | 22 | Args: 23 | kwargs: Keyword arguments that may contain length function configuration 24 | """ 25 | # Get length function type and parameters 26 | length_type = kwargs.pop('length_type', 'token') 27 | encoding_name = kwargs.pop('encoding_name', 'cl100k_base') 28 | 29 | # Set up default length function 30 | self.length_function = get_length_function( 31 | length_type=length_type, 32 | encoding_name=encoding_name 33 | ) 34 | 35 | # Override with custom length function if provided 36 | if 'length_function' in kwargs: 37 | self.length_function = kwargs.pop('length_function') 38 | 39 | @abstractmethod 40 | def split_text(self, text: str) -> List[str]: 41 | """ 42 | Split input text into chunks. 43 | 44 | Args: 45 | text: The input text to split 46 | 47 | Returns: 48 | List of text chunks 49 | """ 50 | pass -------------------------------------------------------------------------------- /src/chunking/cluster_semantic_chunker.py: -------------------------------------------------------------------------------- 1 | 2 | # This script is adapted from the chunking_evaluation package, developed by ChromaDB Research. 
3 | # Original code can be found at: https://github.com/brandonstarxel/chunking_evaluation/blob/main/chunking_evaluation/chunking/cluster_semantic_chunker.py 4 | # License: MIT License 5 | 6 | from .base_chunker import BaseChunker 7 | from typing import List, Optional 8 | import numpy as np 9 | from litellm import embedding 10 | from .recursive_token_chunker import RecursiveTokenChunker 11 | from .registry import ChunkerRegistry 12 | from .utils import get_length_function 13 | import logging 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | @ChunkerRegistry.register("ClusterSemanticChunker") 18 | class ClusterSemanticChunker(BaseChunker): 19 | def __init__( 20 | self, 21 | max_chunk_size: int = 400, 22 | min_chunk_size: int = 50, 23 | length_type: str = 'token', 24 | litellm_config: Optional[dict] = None, 25 | **kwargs 26 | ): 27 | super().__init__(length_type=length_type, **kwargs) 28 | 29 | self.length_function = get_length_function( 30 | length_type=length_type, 31 | encoding_name=kwargs.get('encoding_name', 'cl100k_base'), 32 | model_name=kwargs.get('model_name') 33 | ) 34 | 35 | self.splitter = RecursiveTokenChunker( 36 | chunk_size=min_chunk_size, 37 | chunk_overlap=0, 38 | length_function=self.length_function, 39 | separators=["\n\n", "\n", ".", "?", "!", " ", ""] 40 | ) 41 | 42 | self._litellm_config = litellm_config or {} 43 | self.max_cluster = max_chunk_size // min_chunk_size 44 | 45 | def _get_embeddings(self, texts: List[str]) -> List[List[float]]: 46 | """Universal embedding response handler""" 47 | try: 48 | response = embedding( 49 | model=self._litellm_config.get('embedding_model', 'text-embedding-3-large'), 50 | input=texts, 51 | api_base=self._litellm_config.get('embedding_api_base') 52 | ) 53 | 54 | # Handle all possible response formats 55 | if isinstance(response, dict): 56 | items = response.get('data', []) 57 | elif hasattr(response, 'data'): 58 | items = response.data 59 | else: 60 | items = response 61 | 62 | return [ 63 | 
item['embedding'] if isinstance(item, dict) else item.embedding 64 | for item in items 65 | ] 66 | 67 | except Exception as e: 68 | logger.error(f"Embedding failed for batch: {str(e)}") 69 | raise RuntimeError(f"Embedding error: {str(e)}") 70 | 71 | def _calculate_similarity_matrix(self, sentences: List[str]) -> np.ndarray: 72 | """Batch processing with error logging""" 73 | if not sentences: 74 | return np.zeros((0, 0)) 75 | 76 | embeddings = [] 77 | for batch_idx in range(0, len(sentences), 500): 78 | try: 79 | batch = sentences[batch_idx:batch_idx+500] 80 | embeddings.extend(self._get_embeddings(batch)) 81 | except Exception as e: 82 | logger.error(f"Failed processing batch {batch_idx//500}: {str(e)}") 83 | raise 84 | 85 | embedding_matrix = np.array(embeddings) 86 | return np.dot(embedding_matrix, embedding_matrix.T) 87 | 88 | def _optimal_segmentation(self, matrix: np.ndarray) -> List[tuple]: 89 | """Original Chroma algorithm implementation""" 90 | n = matrix.shape[0] 91 | if n < 1: 92 | return [] 93 | 94 | # Calculate mean of off-diagonal elements 95 | triu = np.triu_indices(n, k=1) 96 | tril = np.tril_indices(n, k=-1) 97 | mean_value = (matrix[triu].sum() + matrix[tril].sum()) / (n * (n - 1)) if n > 1 else 0 98 | 99 | matrix = matrix - mean_value 100 | np.fill_diagonal(matrix, 0) 101 | 102 | dp = np.zeros(n) 103 | segmentation = np.zeros(n, dtype=int) 104 | 105 | for i in range(n): 106 | for size in range(1, min(self.max_cluster + 1, i + 2)): 107 | start_idx = i - size + 1 108 | if start_idx >= 0: 109 | current_reward = matrix[start_idx:i+1, start_idx:i+1].sum() 110 | if start_idx > 0: 111 | current_reward += dp[start_idx - 1] 112 | if current_reward > dp[i]: 113 | dp[i] = current_reward 114 | segmentation[i] = start_idx 115 | 116 | clusters = [] 117 | i = n - 1 118 | while i >= 0: 119 | start = segmentation[i] 120 | clusters.append((start, i)) 121 | i = start - 1 122 | 123 | return list(reversed(clusters)) 124 | 125 | def split_text(self, text: str) -> 
List[str]: 126 | """Main processing pipeline""" 127 | if not text.strip(): 128 | return [] 129 | 130 | # First-stage splitting 131 | sentences = self.splitter.split_text(text) 132 | if len(sentences) < 2: 133 | return [text] 134 | 135 | # Semantic clustering 136 | similarity_matrix = self._calculate_similarity_matrix(sentences) 137 | clusters = self._optimal_segmentation(similarity_matrix) 138 | 139 | return [' '.join(sentences[start:end+1]) for start, end in clusters] -------------------------------------------------------------------------------- /src/chunking/fixed_token_chunker.py: -------------------------------------------------------------------------------- 1 | 2 | # This script is adapted from the LangChain package, developed by LangChain AI. 3 | # Original code can be found at: https://github.com/langchain-ai/langchain/blob/master/libs/text-splitters/langchain_text_splitters/base.py 4 | # chunking_evaluation modification: https://github.com/brandonstarxel/chunking_evaluation/blob/main/chunking_evaluation/chunking/fixed_token_chunker.py 5 | # License: MIT License 6 | 7 | from abc import ABC, abstractmethod 8 | import logging 9 | from typing import ( 10 | AbstractSet, 11 | Any, 12 | Callable, 13 | Collection, 14 | List, 15 | Literal, 16 | Optional, 17 | Type, 18 | TypeVar, 19 | Union, 20 | ) 21 | from .base_chunker import BaseChunker 22 | from attr import dataclass 23 | from .registry import ChunkerRegistry 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | TS = TypeVar("TS", bound="TextSplitter") 28 | 29 | class TextSplitter(BaseChunker, ABC): 30 | """Interface for splitting text into chunks.""" 31 | 32 | def __init__( 33 | self, 34 | chunk_size: int = 4000, 35 | chunk_overlap: int = 200, 36 | length_function: Callable[[str], int] = len, 37 | keep_separator: bool = False, 38 | add_start_index: bool = False, 39 | strip_whitespace: bool = True, 40 | **kwargs 41 | ) -> None: 42 | """ 43 | Args: 44 | chunk_size: Maximum size of chunks to return 45 | 
chunk_overlap: Overlap in characters (tokens) between chunks 46 | length_function: Function that measures the length of given chunks 47 | keep_separator: Whether to keep the separator in the chunks 48 | add_start_index: If `True`, includes chunk's start index in metadata 49 | strip_whitespace: If `True`, strips whitespace from the start/end 50 | """ 51 | super().__init__(**kwargs) 52 | if chunk_overlap > chunk_size: 53 | raise ValueError( 54 | f"Got a larger chunk overlap ({chunk_overlap}) than chunk size " 55 | f"({chunk_size}), should be smaller." 56 | ) 57 | self._chunk_size = chunk_size 58 | self._chunk_overlap = chunk_overlap 59 | # If BaseChunker sets self.length_function, we override with the user's length_function param if given 60 | if length_function is not None: 61 | self._length_function = length_function 62 | else: 63 | self._length_function = self.length_function 64 | 65 | self._keep_separator = keep_separator 66 | self._add_start_index = add_start_index 67 | self._strip_whitespace = strip_whitespace 68 | 69 | @abstractmethod 70 | def split_text(self, text: str) -> List[str]: 71 | pass 72 | 73 | def _join_docs(self, docs: List[str], separator: str) -> Optional[str]: 74 | text = separator.join(docs) 75 | if self._strip_whitespace: 76 | text = text.strip() 77 | return text if text != "" else None 78 | 79 | def _merge_splits(self, splits: List[str], separator: str) -> List[str]: 80 | """Combine smaller pieces into medium chunks.""" 81 | separator_len = self._length_function(separator) 82 | 83 | docs = [] 84 | current_doc: List[str] = [] 85 | total = 0 86 | for d in splits: 87 | _len = self._length_function(d) 88 | if total + _len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size: 89 | if total > self._chunk_size: 90 | logger.warning( 91 | f"Created a chunk of size {total}, " 92 | f"which is longer than the specified {self._chunk_size}" 93 | ) 94 | if len(current_doc) > 0: 95 | doc = self._join_docs(current_doc, separator) 96 | if doc is 
not None: 97 | docs.append(doc) 98 | # Pop from the front while chunk > overlap 99 | while total > self._chunk_overlap or ( 100 | total + _len + (separator_len if len(current_doc) > 0 else 0) 101 | > self._chunk_size 102 | and total > 0 103 | ): 104 | total -= self._length_function(current_doc[0]) + ( 105 | separator_len if len(current_doc) > 1 else 0 106 | ) 107 | current_doc = current_doc[1:] 108 | current_doc.append(d) 109 | total += _len + (separator_len if len(current_doc) > 1 else 0) 110 | doc = self._join_docs(current_doc, separator) 111 | if doc is not None: 112 | docs.append(doc) 113 | return docs 114 | 115 | @ChunkerRegistry.register("FixedTokenChunker") 116 | class FixedTokenChunker(TextSplitter): 117 | """Splitting text to tokens using model tokenizer.""" 118 | 119 | def __init__( 120 | self, 121 | encoding_name: str = "cl100k_base", 122 | model_name: Optional[str] = None, 123 | chunk_size: int = 4000, 124 | chunk_overlap: int = 200, 125 | allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), 126 | disallowed_special: Union[Literal["all"], Collection[str]] = "all", 127 | **kwargs: Any, 128 | ) -> None: 129 | super().__init__(chunk_size=chunk_size, chunk_overlap=chunk_overlap, **kwargs) 130 | 131 | try: 132 | import tiktoken 133 | except ImportError: 134 | raise ImportError( 135 | "Could not import tiktoken python package. " 136 | "This is needed for FixedTokenChunker. " 137 | "Please install it with `pip install tiktoken`." 
138 | ) 139 | 140 | if model_name is not None: 141 | enc = tiktoken.encoding_for_model(model_name) 142 | else: 143 | enc = tiktoken.get_encoding(encoding_name) 144 | self._tokenizer = enc 145 | self._allowed_special = allowed_special 146 | self._disallowed_special = disallowed_special 147 | 148 | def split_text(self, text: str) -> List[str]: 149 | def _encode(_text: str) -> List[int]: 150 | return self._tokenizer.encode( 151 | _text, 152 | allowed_special=self._allowed_special, 153 | disallowed_special=self._disallowed_special, 154 | ) 155 | 156 | tokenizer = Tokenizer( 157 | chunk_overlap=self._chunk_overlap, 158 | tokens_per_chunk=self._chunk_size, 159 | decode=self._tokenizer.decode, 160 | encode=_encode, 161 | ) 162 | 163 | return split_text_on_tokens(text=text, tokenizer=tokenizer) 164 | 165 | 166 | @dataclass(frozen=True) 167 | class Tokenizer: 168 | """Tokenizer data class.""" 169 | chunk_overlap: int 170 | tokens_per_chunk: int 171 | decode: Callable[[List[int]], str] 172 | encode: Callable[[str], List[int]] 173 | 174 | 175 | def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> List[str]: 176 | """Split incoming text and return chunks using tokenizer.""" 177 | splits: List[str] = [] 178 | input_ids = tokenizer.encode(text) 179 | start_idx = 0 180 | cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids)) 181 | chunk_ids = input_ids[start_idx:cur_idx] 182 | while start_idx < len(input_ids): 183 | splits.append(tokenizer.decode(chunk_ids)) 184 | if cur_idx == len(input_ids): 185 | break 186 | start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap 187 | cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids)) 188 | chunk_ids = input_ids[start_idx:cur_idx] 189 | return splits -------------------------------------------------------------------------------- /src/chunking/kamradt_modified_chunker.py: -------------------------------------------------------------------------------- 1 | 2 | # This script is adapted from 
the Greg Kamradt's notebook on chunking. 3 | # Original code can be found at: https://github.com/FullStackRetrieval-com/RetrievalTutorials/blob/main/tutorials/LevelsOfTextSplitting/5_Levels_Of_Text_Splitting.ipynb 4 | # chunking_evaluation modification: https://github.com/brandonstarxel/chunking_evaluation/blob/main/chunking_evaluation/chunking/kamradt_modified_chunker.py 5 | 6 | from typing import Optional, List, Any 7 | import numpy as np 8 | from .base_chunker import BaseChunker 9 | from .recursive_token_chunker import RecursiveTokenChunker 10 | from litellm import embedding 11 | from .registry import ChunkerRegistry 12 | 13 | @ChunkerRegistry.register("KamradtModifiedChunker") 14 | class KamradtModifiedChunker(BaseChunker): 15 | def __init__( 16 | self, 17 | avg_chunk_size: int = 400, 18 | min_chunk_size: int = 50, 19 | litellm_config: Optional[dict] = None, 20 | length_type: str = 'token', 21 | **kwargs 22 | ): 23 | super().__init__(length_type=length_type, **kwargs) 24 | 25 | self.splitter = RecursiveTokenChunker( 26 | chunk_size=min_chunk_size, 27 | chunk_overlap=0, 28 | length_function=self.length_function 29 | ) 30 | 31 | self._litellm_config = litellm_config or {} 32 | self.avg_chunk_size = avg_chunk_size 33 | 34 | def _get_embeddings(self, texts: List[str]) -> List[List[float]]: 35 | """Get embeddings using LiteLLM.""" 36 | response = embedding( 37 | model=self._litellm_config.get('embedding_model', 'text-embedding-3-large'), 38 | input=texts, 39 | api_base=self._litellm_config.get('embedding_api_base') 40 | ) 41 | # Extract embeddings from the response data 42 | if hasattr(response, 'data'): 43 | return [item['embedding'] for item in response.data] 44 | elif isinstance(response, dict) and 'data' in response: 45 | return [item['embedding'] for item in response['data']] 46 | raise ValueError(f"Unexpected response format from LiteLLM: {response}") 47 | 48 | def combine_sentences(self, sentences: List[dict], buffer_size: int = 1) -> List[dict]: 49 | for i 
in range(len(sentences)): 50 | combined = [] 51 | for j in range(max(0, i - buffer_size), min(len(sentences), i + buffer_size + 1)): 52 | combined.append(sentences[j]['sentence']) 53 | sentences[i]['combined_sentence'] = ' '.join(combined) 54 | return sentences 55 | 56 | def calculate_cosine_distances(self, sentences: List[dict]): 57 | embeddings = [] 58 | for i in range(0, len(sentences), 500): 59 | batch = [s['combined_sentence'] for s in sentences[i:i + 500]] 60 | embeddings.extend(self._get_embeddings(batch)) 61 | 62 | embedding_matrix = np.array(embeddings) 63 | norms = np.linalg.norm(embedding_matrix, axis=1, keepdims=True) 64 | embedding_matrix /= norms 65 | similarity_matrix = np.dot(embedding_matrix, embedding_matrix.T) 66 | 67 | distances = [] 68 | for i in range(len(sentences) - 1): 69 | distance = 1 - similarity_matrix[i, i + 1] 70 | distances.append(distance) 71 | sentences[i]['distance_to_next'] = distance 72 | return distances, sentences 73 | 74 | def split_text(self, text: str) -> List[str]: 75 | s_list = self.splitter.split_text(text) 76 | sentences = [{'sentence': s, 'index': i} for i, s in enumerate(s_list)] 77 | if not sentences: 78 | return [] 79 | 80 | sentences = self.combine_sentences(sentences, 3) 81 | distances, sentences = self.calculate_cosine_distances(sentences) 82 | 83 | total_tokens = sum(self.length_function(s['sentence']) for s in sentences) 84 | target_splits = total_tokens // self.avg_chunk_size if self.avg_chunk_size else 1 85 | distances = np.array(distances) 86 | 87 | low, high = 0.0, 1.0 88 | while high - low > 1e-6: 89 | mid = (low + high) / 2 90 | if (distances > mid).sum() > target_splits: 91 | low = mid 92 | else: 93 | high = mid 94 | 95 | split_indices = [i for i, d in enumerate(distances) if d > high] 96 | chunks = [] 97 | start = 0 98 | 99 | for idx in split_indices: 100 | chunks.append(' '.join(s['sentence'] for s in sentences[start:idx + 1])) 101 | start = idx + 1 102 | 103 | if start < len(sentences): 104 | 
chunks.append(' '.join(s['sentence'] for s in sentences[start:])) 105 | 106 | return chunks -------------------------------------------------------------------------------- /src/chunking/llm_semantic_chunker.py: -------------------------------------------------------------------------------- 1 | 2 | # This script is adapted from the chunking_evaluation package, developed by ChromaDB Research. 3 | # Original code can be found at: https://github.com/brandonstarxel/chunking_evaluation/blob/main/chunking_evaluation/chunking/llm_semantic_chunker.py 4 | # License: MIT License 5 | 6 | from .base_chunker import BaseChunker 7 | from .recursive_token_chunker import RecursiveTokenChunker 8 | import backoff 9 | from tqdm import tqdm 10 | from typing import List, Optional 11 | import re 12 | from .registry import ChunkerRegistry 13 | from litellm import completion 14 | 15 | @ChunkerRegistry.register("LLMSemanticChunker") 16 | class LLMSemanticChunker(BaseChunker): 17 | def __init__( 18 | self, 19 | litellm_config: Optional[dict] = None, 20 | length_type: str = 'token', 21 | **kwargs 22 | ): 23 | super().__init__(length_type=length_type, **kwargs) 24 | 25 | self._litellm_config = litellm_config or {} 26 | 27 | # Initialize the base splitter for initial text splitting 28 | self.splitter = RecursiveTokenChunker( 29 | chunk_size=50, 30 | chunk_overlap=0, 31 | length_function=self.length_function 32 | ) 33 | 34 | def get_prompt(self, chunked_input, current_chunk=0, invalid_response=None): 35 | """Generate the prompt for the LLM.""" 36 | base_prompt = ( 37 | "You are an assistant specialized in splitting text into thematically consistent sections. " 38 | "The text has been divided into chunks, each marked with <|start_chunk_X|> and <|end_chunk_X|> tags, where X is the chunk number. " 39 | "Your task is to identify the points where splits should occur, such that consecutive chunks of similar themes stay together. 
" 40 | "Respond with a list of chunk IDs where you believe a split should be made. For example, if chunks 1 and 2 belong together but chunk 3 starts a new topic, you would suggest a split after chunk 2. THE CHUNKS MUST BE IN ASCENDING ORDER." 41 | "Your response should be in the form: 'split_after: 3, 5'." 42 | ) 43 | 44 | user_content = ( 45 | f"CHUNKED_TEXT: {chunked_input}\n\n" 46 | f"Respond with split points (ascending, ≥{current_chunk}). " 47 | "Respond only with the IDs of the chunks where you believe a split should occur. YOU MUST RESPOND WITH AT LEAST ONE SPLIT. THESE SPLITS MUST BE IN ASCENDING ORDER" 48 | ) 49 | if invalid_response: 50 | user_content += ( 51 | f"\n\\Previous invalid response: {invalid_response}. " 52 | "DO NOT REPEAT THIS ARRAY OF NUMBERS. Please try again." 53 | ) 54 | 55 | return [ 56 | {"role": "system", "content": base_prompt}, 57 | {"role": "user", "content": user_content} 58 | ] 59 | 60 | @backoff.on_exception(backoff.expo, Exception, max_tries=3) 61 | def _get_llm_response(self, context: str, current: int) -> str: 62 | """Get chunking suggestions from LLM using LiteLLM.""" 63 | try: 64 | response = completion( 65 | model=self._litellm_config.get('model', 'openai/gpt-4o'), 66 | messages=self.get_prompt(context, current), 67 | temperature=0.2, 68 | max_tokens=200, 69 | api_base=self._litellm_config.get('model_api_base') 70 | ) 71 | return response.choices[0].message.content 72 | except Exception as e: 73 | print(f"LLM API error: {str(e)}") 74 | return "" 75 | 76 | def _parse_response(self, response: str, current_chunk: int) -> List[int]: 77 | numbers = [] 78 | if 'split_after:' in response: 79 | numbers = list(map(int, re.findall(r'\d+', response.split('split_after:')[1]))) 80 | return sorted(n for n in numbers if n > current_chunk) # Ensure 1-based > current 0-based 81 | 82 | def _merge_chunks(self, chunks: List[str], indices: List[int]) -> List[str]: 83 | """Merge chunks based on split indices (indices are 1-based from LLM)""" 84 
| merged = [] 85 | current = [] 86 | # Convert to 0-based indices and sort 87 | split_points = sorted([i-1 for i in indices if i > 0]) 88 | 89 | for i, chunk in enumerate(chunks): 90 | current.append(chunk) 91 | if i in split_points: 92 | merged.append(" ".join(current).strip()) 93 | current = [] 94 | if current: 95 | merged.append(" ".join(current).strip()) 96 | return merged 97 | 98 | def split_text(self, text: str) -> List[str]: 99 | """Split input text into coherent chunks using LLM guidance.""" 100 | chunks = self.splitter.split_text(text) 101 | split_indices = [] 102 | current_chunk = 0 103 | 104 | with tqdm(total=len(chunks), desc="Processing chunks") as pbar: 105 | while current_chunk < len(chunks) - 4: 106 | context_window = [] 107 | token_count = 0 108 | 109 | for i in range(current_chunk, len(chunks)): 110 | token_count += self.length_function(chunks[i]) 111 | if token_count > 800: 112 | break 113 | context_window.append(f"<|start_chunk_{i+1}|>{chunks[i]}<|end_chunk_{i+1}|>") 114 | 115 | response = self._get_llm_response("\n".join(context_window), current_chunk) 116 | numbers = self._parse_response(response, current_chunk) # FIXED: Added current_chunk argument 117 | 118 | if numbers: 119 | split_indices.extend(numbers) 120 | current_chunk = numbers[-1] 121 | pbar.update(current_chunk - pbar.n) 122 | else: 123 | break 124 | 125 | return self._merge_chunks(chunks, split_indices) -------------------------------------------------------------------------------- /src/chunking/recursive_token_chunker.py: -------------------------------------------------------------------------------- 1 | 2 | # This script is adapted from the LangChain package, developed by LangChain AI. 
3 | # Original code can be found at: https://github.com/langchain-ai/langchain/blob/master/libs/text-splitters/langchain_text_splitters/character.py 4 | # chunking_evaluation modification: https://github.com/brandonstarxel/chunking_evaluation/blob/main/chunking_evaluation/chunking/recursive_token_chunker.py 5 | # License: MIT License 6 | 7 | from typing import Any, List, Optional 8 | from .utils import Language 9 | from .fixed_token_chunker import TextSplitter 10 | import re 11 | from .registry import ChunkerRegistry 12 | 13 | def _split_text_with_regex(text: str, separator: str, keep_separator: bool) -> List[str]: 14 | if separator: 15 | if keep_separator: 16 | # The parentheses in the pattern keep the delimiters in the result. 17 | _splits = re.split(f"({separator})", text) 18 | splits = [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)] 19 | if len(_splits) % 2 == 0: 20 | splits += _splits[-1:] 21 | splits = [_splits[0]] + splits 22 | else: 23 | splits = re.split(separator, text) 24 | else: 25 | splits = list(text) 26 | return [s for s in splits if s != ""] 27 | 28 | @ChunkerRegistry.register("RecursiveTokenChunker") 29 | class RecursiveTokenChunker(TextSplitter): 30 | """Splitting text by recursively looking at characters / tokens.""" 31 | 32 | def __init__( 33 | self, 34 | chunk_size: int = 4000, 35 | chunk_overlap: int = 200, 36 | separators: Optional[List[str]] = None, 37 | keep_separator: bool = True, 38 | is_separator_regex: bool = False, 39 | length_type: str = 'token', 40 | **kwargs: Any, 41 | ) -> None: 42 | super().__init__( 43 | chunk_size=chunk_size, 44 | chunk_overlap=chunk_overlap, 45 | keep_separator=keep_separator, 46 | length_type=length_type, 47 | **kwargs 48 | ) 49 | self._separators = separators or ["\n\n", "\n", ".", "?", "!", " ", ""] 50 | self._is_separator_regex = is_separator_regex 51 | 52 | def _split_text(self, text: str, separators: List[str]) -> List[str]: 53 | final_chunks = [] 54 | separator = separators[-1] 55 | 
new_separators = [] 56 | 57 | for i, _s in enumerate(separators): 58 | _separator = _s if self._is_separator_regex else re.escape(_s) 59 | if _s == "": 60 | separator = _s 61 | break 62 | if re.search(_separator, text): 63 | separator = _s 64 | new_separators = separators[i + 1:] 65 | break 66 | 67 | _separator = separator if self._is_separator_regex else re.escape(separator) 68 | splits = _split_text_with_regex(text, _separator, self._keep_separator) 69 | 70 | _good_splits = [] 71 | actual_separator = "" if self._keep_separator else separator 72 | 73 | for s in splits: 74 | if self.length_function(s) < self._chunk_size: 75 | _good_splits.append(s) 76 | else: 77 | if _good_splits: 78 | merged_text = self._merge_splits(_good_splits, actual_separator) 79 | final_chunks.extend(merged_text) 80 | _good_splits = [] 81 | if not new_separators: 82 | final_chunks.append(s) 83 | else: 84 | other_info = self._split_text(s, new_separators) 85 | final_chunks.extend(other_info) 86 | if _good_splits: 87 | merged_text = self._merge_splits(_good_splits, actual_separator) 88 | final_chunks.extend(merged_text) 89 | 90 | return final_chunks 91 | 92 | def split_text(self, text: str) -> List[str]: 93 | return self._split_text(text, self._separators) 94 | 95 | @staticmethod 96 | def get_separators_for_language(language: Language) -> List[str]: 97 | if language == Language.PYTHON: 98 | return [ 99 | "\nclass ", 100 | "\ndef ", 101 | "\n\tdef ", 102 | "\n\n", 103 | "\n", 104 | " ", 105 | "", 106 | ] 107 | raise ValueError( 108 | f"Language {language} is not supported! 
" 109 | f"Please choose from {list(Language)}" 110 | ) -------------------------------------------------------------------------------- /src/chunking/registry.py: -------------------------------------------------------------------------------- 1 | class ChunkerRegistry: 2 | _chunkers = {} 3 | 4 | @classmethod 5 | def register(cls, name: str): 6 | def decorator(chunker_class): 7 | cls._chunkers[name] = chunker_class 8 | return chunker_class 9 | return decorator 10 | 11 | @classmethod 12 | def get_chunker(cls, name: str): 13 | if name not in cls._chunkers: 14 | available = list(cls._chunkers.keys()) 15 | raise ValueError(f"Unknown chunker: {name}. Available chunkers: {available}") 16 | return cls._chunkers[name] -------------------------------------------------------------------------------- /src/chunking/utils.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | import tiktoken 3 | from typing import Callable, Optional 4 | 5 | class Language(str, Enum): 6 | """Supported languages for language-specific chunking.""" 7 | CPP = "cpp" 8 | GO = "go" 9 | JAVA = "java" 10 | KOTLIN = "kotlin" 11 | JS = "js" 12 | TS = "ts" 13 | PHP = "php" 14 | PROTO = "proto" 15 | PYTHON = "python" 16 | RST = "rst" 17 | RUBY = "ruby" 18 | RUST = "rust" 19 | SCALA = "scala" 20 | SWIFT = "swift" 21 | MARKDOWN = "markdown" 22 | LATEX = "latex" 23 | HTML = "html" 24 | SOL = "sol" 25 | CSHARP = "csharp" 26 | COBOL = "cobol" 27 | C = "c" 28 | LUA = "lua" 29 | PERL = "perl" 30 | 31 | def get_token_count( 32 | string: str, 33 | encoding_name: str = "cl100k_base", 34 | model_name: Optional[str] = None, 35 | **kwargs 36 | ) -> int: 37 | """ 38 | Count the number of tokens in a string using tiktoken. 
39 | 40 | Args: 41 | string: The text to count tokens for 42 | encoding_name: The name of the tiktoken encoding to use 43 | model_name: Optional model name to use specific encoding 44 | **kwargs: Additional arguments passed to tiktoken encoder 45 | 46 | Returns: 47 | Number of tokens in the string 48 | """ 49 | try: 50 | if model_name: 51 | enc = tiktoken.encoding_for_model(model_name) 52 | else: 53 | enc = tiktoken.get_encoding(encoding_name) 54 | 55 | allowed_special = kwargs.get('allowed_special', set()) 56 | disallowed_special = kwargs.get('disallowed_special', 'all') 57 | 58 | return len(enc.encode( 59 | string, 60 | allowed_special=allowed_special, 61 | disallowed_special=disallowed_special 62 | )) 63 | except Exception as e: 64 | raise ValueError(f"Error counting tokens: {str(e)}") 65 | 66 | def get_character_count(text: str) -> int: 67 | """ 68 | Count the number of characters in a string. 69 | 70 | Args: 71 | text: The text to count characters for 72 | 73 | Returns: 74 | Number of characters in the string 75 | """ 76 | return len(text) 77 | 78 | def get_length_function(length_type: str = "token", **kwargs) -> Callable[[str], int]: 79 | """ 80 | Get a length function based on the specified type. 81 | 82 | Args: 83 | length_type: Type of length function ('token' or 'character') 84 | **kwargs: Additional arguments passed to token counter 85 | 86 | Returns: 87 | A callable that takes a string and returns its length 88 | """ 89 | if length_type == "token": 90 | return lambda x: get_token_count(x, **kwargs) 91 | elif length_type == "character": 92 | return get_character_count 93 | else: 94 | raise ValueError( 95 | f"Unknown length type: {length_type}. 
" 96 | "Choose 'token' or 'character'" 97 | ) -------------------------------------------------------------------------------- /src/hub_upload/card_generator.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | import os 4 | from typing import Dict, Any, Optional, List 5 | 6 | class DatasetCardGenerator: 7 | """Handles dataset card generation for quickb datasets.""" 8 | 9 | def __init__(self, template_path: str = "src/hub_upload/template.md"): 10 | """Initialize with path to card template.""" 11 | self.template_path = template_path 12 | with open(template_path, 'r', encoding='utf-8') as f: # a relative template_path is resolved against the current working directory 13 | self.template = f.read() 14 | 15 | def _get_size_category(self, num_entries: Optional[int]) -> str: 16 | """Determine the size category based on number of entries.""" 17 | if not num_entries: # None and 0 both map to "unknown" 18 | return "unknown" 19 | elif num_entries < 1000: 20 | return "n<1K" 21 | elif num_entries < 10000: 22 | return "1K1M" 29 | 30 | def _format_chunker_params(self, params: Dict[str, Any]) -> str: 31 | """Simple Markdown-safe parameter formatting""" 32 | return "\n ".join( 33 | f"- **{key}**: `{repr(value)}`" 34 | for key, value in params.items() 35 | if value is not None and not key.startswith('_') 36 | ) 37 | 38 | def _format_question_generation(self, 39 | model_name: str, 40 | similarity_threshold: float, 41 | num_questions: int, 42 | num_deduped: int 43 | ) -> str: 44 | """Format question generation section if enabled.""" 45 | return f"""### Question Generation 46 | - **Model**: {model_name} 47 | - **Deduplication threshold**: {similarity_threshold} 48 | - **Results**: 49 | - Total questions generated: {num_questions} 50 | - Questions after deduplication: {num_deduped}""" 51 | 52 | def _format_dataset_structure(self, has_chunks: bool, has_questions: bool) -> str: # has_chunks is not read; the layout depends only on has_questions 53 | """Format dataset structure section based on available configurations.""" 54 | if has_questions: 55 | return """### Dataset Structure 56 |
- `anchor`: The generated question 57 | - `positive`: The text chunk containing the answer 58 | - `question_id`: Unique identifier for the question 59 | - `chunk_id`: Reference to the source chunk""" 60 | else: 61 | return """### Dataset Structure 62 | This dataset contains the following fields: 63 | 64 | - `text`: The content of each text chunk 65 | - `source`: The source file path for the chunk 66 | - `id`: Unique identifier for each chunk""" 67 | 68 | def _format_chunking_section(self, 69 | chunker_name: str, 70 | chunker_params: Dict[str, Any], 71 | num_chunks: int, 72 | avg_chunk_size: float, 73 | num_files: int 74 | ) -> str: 75 | """Format chunking section if enabled.""" 76 | return f"""### Chunking Configuration 77 | - **Chunker**: {chunker_name} 78 | - **Parameters**: 79 | {self._format_chunker_params(chunker_params)} 80 | 81 | ### Dataset Statistics 82 | - Total chunks: {num_chunks:,} 83 | - Average chunk size: {avg_chunk_size:.1f} words 84 | - Source files: {num_files}""" 85 | 86 | def generate_card(self, 87 | dataset_name: str, 88 | chunker_name: Optional[str] = None, 89 | chunker_params: Optional[Dict[str, Any]] = None, 90 | num_chunks: Optional[int] = None, 91 | avg_chunk_size: Optional[float] = None, 92 | num_files: Optional[int] = None, 93 | question_generation: Optional[Dict[str, Any]] = None 94 | ) -> str: 95 | """Generate a dataset card with the provided information.""" 96 | 97 | # Derive the size category from the chunk count; nothing is loaded here 98 | size_category = self._get_size_category(num_chunks) 99 | 100 | # Determine components and tags 101 | has_chunks = all(x is not None for x in [chunker_name, chunker_params, num_chunks, avg_chunk_size, num_files]) 102 | has_questions = question_generation is not None 103 | 104 | gen_tag = "\n- question-generation" if has_questions else "" 105 | 106 | # Format sections based on available data 107 | chunking_section = "" 108 | if has_chunks: 109 | chunking_section = self._format_chunking_section( 110 |
chunker_name=chunker_name, 111 | chunker_params=chunker_params, 112 | num_chunks=num_chunks, 113 | avg_chunk_size=avg_chunk_size, 114 | num_files=num_files 115 | ) 116 | 117 | qg_section = "" 118 | if has_questions: 119 | qg_section = self._format_question_generation( 120 | model_name=question_generation["model_name"], 121 | similarity_threshold=question_generation["similarity_threshold"], 122 | num_questions=question_generation["num_questions"], 123 | num_deduped=question_generation["num_deduped"] 124 | ) 125 | 126 | # Generate dataset structure section 127 | dataset_structure = self._format_dataset_structure(has_chunks, has_questions) 128 | 129 | # Fill template 130 | return self.template.format( 131 | dataset_name=dataset_name, 132 | gen_tag=gen_tag, 133 | size_category=size_category, 134 | chunker_section=chunking_section, 135 | question_generation=qg_section, 136 | dataset_structure=dataset_structure 137 | ) -------------------------------------------------------------------------------- /src/hub_upload/dataset_pusher.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | from pathlib import Path 5 | from typing import Optional, Dict, Any 6 | from datasets import Dataset 7 | from huggingface_hub import create_repo, upload_file, repo_exists 8 | from .card_generator import DatasetCardGenerator 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | class DatasetPusher: 13 | """Handles uploading datasets to the Hugging Face Hub.""" 14 | 15 | def __init__(self, username: str, token: Optional[str] = None): 16 | """Initialize with HF credentials.""" 17 | self.username = username 18 | self.token = token or os.getenv("HF_TOKEN") 19 | 20 | if not self.token: 21 | raise ValueError("No Hugging Face token provided or found in environment") 22 | 23 | self.card_generator = DatasetCardGenerator() 24 | 25 | def _repository_exists(self, repo_id: str) -> bool: 26 | """Check if repository already
exists.""" 27 | try: 28 | return repo_exists(repo_id, repo_type="dataset") 29 | except Exception as e: 30 | logger.error(f"Error checking repository: {str(e)}") 31 | return False 32 | 33 | def _load_json_file(self, file_path: str) -> list: 34 | """Load JSON file and ensure it's a list of records.""" 35 | try: 36 | with open(file_path, 'r', encoding='utf-8') as f: 37 | data = json.load(f) 38 | if not isinstance(data, list): 39 | raise ValueError(f"Expected JSON array in {file_path}") 40 | return data 41 | except Exception as e: 42 | logger.error(f"Error loading {file_path}: {str(e)}") 43 | raise 44 | 45 | def _calculate_dataset_stats(self, knowledgebase_data: list) -> Dict[str, Any]: 46 | """Calculate statistics for dataset card.""" 47 | text_lengths = [len(item['text'].split()) for item in knowledgebase_data] 48 | unique_sources = len(set(item['source'] for item in knowledgebase_data)) 49 | 50 | return { 51 | 'num_chunks': len(knowledgebase_data), 52 | 'avg_chunk_size': sum(text_lengths) / len(text_lengths) if text_lengths else 0, 53 | 'num_files': unique_sources 54 | } 55 | 56 | def push_dataset( 57 | self, 58 | hub_dataset_id: str, 59 | knowledgebase_path: Optional[str] = None, 60 | chunker_info: Optional[Dict[str, Any]] = None, 61 | train_path: Optional[str] = None, 62 | question_gen_info: Optional[Dict[str, Any]] = None, 63 | private: bool = True 64 | ) -> None: 65 | """Push dataset to the Hugging Face Hub, overwriting existing data.""" 66 | try: 67 | # Create repository if it doesn't exist 68 | if not self._repository_exists(hub_dataset_id): 69 | create_repo( 70 | hub_dataset_id, 71 | repo_type="dataset", 72 | private=private, 73 | token=self.token 74 | ) 75 | logger.info(f"Created new dataset repository: {hub_dataset_id}") 76 | else: 77 | logger.info(f"Dataset repository exists: {hub_dataset_id}") 78 | 79 | # Load and push knowledgebase if provided 80 | kb_data = None 81 | if knowledgebase_path: 82 | kb_data = self._load_json_file(knowledgebase_path) 83 | 
kb_dataset = Dataset.from_list(kb_data) 84 | kb_dataset.push_to_hub( 85 | hub_dataset_id, 86 | token=self.token, 87 | private=private 88 | ) 89 | logger.info(f"Pushed knowledgebase to {hub_dataset_id}") 90 | 91 | # Load and push training data if provided 92 | if train_path: 93 | train_data = self._load_json_file(train_path) 94 | train_dataset = Dataset.from_list(train_data) 95 | train_dataset.push_to_hub( 96 | hub_dataset_id, 97 | token=self.token, 98 | private=private 99 | ) 100 | logger.info(f"Pushed training data to {hub_dataset_id}") 101 | 102 | # Generate and upload README 103 | repository_name = hub_dataset_id.split('/')[-1] 104 | card_content = self.card_generator.generate_card( 105 | dataset_name=repository_name, 106 | chunker_name=chunker_info.get('chunker_name') if chunker_info else None, 107 | chunker_params=chunker_info.get('chunker_params') if chunker_info else None, 108 | num_chunks=self._calculate_dataset_stats(kb_data)['num_chunks'] if kb_data else None, 109 | avg_chunk_size=self._calculate_dataset_stats(kb_data)['avg_chunk_size'] if kb_data else None, 110 | num_files=self._calculate_dataset_stats(kb_data)['num_files'] if kb_data else None, 111 | question_generation=question_gen_info 112 | ) 113 | 114 | upload_file( 115 | path_or_fileobj=card_content.encode('utf-8'), 116 | path_in_repo="README.md", 117 | repo_id=hub_dataset_id, 118 | repo_type="dataset", 119 | token=self.token 120 | ) 121 | logger.info(f"Uploaded README.md to {hub_dataset_id}") 122 | 123 | except Exception as e: 124 | logger.error(f"Error pushing dataset to Hub: {str(e)}") 125 | raise -------------------------------------------------------------------------------- /src/hub_upload/template.md: -------------------------------------------------------------------------------- 1 | --- 2 | language: 3 | - en 4 | pretty_name: "{dataset_name}" 5 | tags: 6 | - quickb 7 | - text-chunking{gen_tag} 8 | - {size_category} 9 | task_categories: 10 | - text-generation 11 | - text-retrieval 12 | 
task_ids: 13 | - document-retrieval 14 | library_name: quickb 15 | --- 16 | 17 | # {dataset_name} 18 | 19 | Generated using [QuicKB](https://github.com/AdamLucek/quickb), a tool developed by [Adam Lucek](https://huggingface.co/AdamLucek). 20 | 21 | QuicKB optimizes document retrieval by creating fine-tuned knowledge bases through an end-to-end pipeline that handles document chunking, training data generation, and embedding model optimization. 22 | 23 | {chunker_section} 24 | 25 | {question_generation} 26 | 27 | {dataset_structure} -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import json 4 | import uuid 5 | from enum import Enum, auto 6 | from pathlib import Path 7 | from typing import List, Dict, Optional, Any, Literal 8 | 9 | import yaml 10 | from pydantic import BaseModel, ConfigDict, field_validator 11 | from datasets import load_dataset, Dataset 12 | 13 | from chunking import ChunkerRegistry 14 | from hub_upload.dataset_pusher import DatasetPusher 15 | from synth_dataset.question_generator import QuestionGenerator 16 | from training.train import main as train_main 17 | 18 | logger = logging.getLogger(__name__) 19 | logging.getLogger("openai").setLevel(logging.WARNING) 20 | logging.getLogger("httpcore").setLevel(logging.WARNING) 21 | logging.getLogger("httpx").setLevel(logging.WARNING) 22 | 23 | class PipelineStage(Enum): 24 | CHUNK = auto() 25 | GENERATE = auto() 26 | TRAIN = auto() 27 | 28 | class BatchSamplers(str, Enum): 29 | BATCH_SAMPLER = "batch_sampler" 30 | NO_DUPLICATES = "no_duplicates" 31 | GROUP_BY_LABEL = "group_by_label" 32 | 33 | class LiteLLMConfig(BaseModel): 34 | """Configuration for LiteLLM model and embedding settings.""" 35 | model_config = ConfigDict(extra='forbid', validate_default=True) 36 | 37 | model: Optional[str] = "openai/gpt-4o" 38 | model_api_base: Optional[str] = 
None 39 | embedding_model: Optional[str] = "openai/text-embedding-3-large" 40 | embedding_api_base: Optional[str] = None 41 | 42 | class QuestionGenInputConfig(BaseModel): 43 | """Configuration for question generation input dataset.""" 44 | model_config = ConfigDict(extra='forbid', validate_default=True) 45 | 46 | dataset_source: Literal["local", "hub"] = "local" 47 | knowledgebase_dataset_id: Optional[str] = None 48 | local_knowledgebase_path: Optional[str] = None 49 | 50 | class TrainInputConfig(BaseModel): 51 | """Configuration for training input datasets.""" 52 | model_config = ConfigDict(extra='forbid', validate_default=True) 53 | 54 | dataset_source: Literal["local", "hub"] = "local" 55 | train_dataset_id: Optional[str] = None 56 | knowledgebase_dataset_id: Optional[str] = None 57 | local_train_path: Optional[str] = None 58 | local_knowledgebase_path: Optional[str] = None 59 | 60 | class UploadConfig(BaseModel): 61 | """Configuration for Hugging Face Hub uploads.""" 62 | model_config = ConfigDict(extra='forbid', validate_default=True) 63 | 64 | push_to_hub: bool = False 65 | hub_private: bool = False 66 | hub_dataset_id: Optional[str] = None 67 | hub_model_id: Optional[str] = None 68 | 69 | class ChunkerConfig(BaseModel): 70 | """Configuration for text chunking.""" 71 | model_config = ConfigDict(extra='forbid', validate_default=True) 72 | 73 | chunker: str 74 | chunker_arguments: Dict[str, Any] 75 | output_path: str 76 | upload_config: Optional[UploadConfig] = None 77 | 78 | @property 79 | def litellm_config(self) -> Optional[LiteLLMConfig]: 80 | """Extract ModelConfig from chunker_arguments if present.""" 81 | if "model_config" in self.chunker_arguments: 82 | return LiteLLMConfig.model_validate(self.chunker_arguments["litellm_config"]) 83 | return None 84 | 85 | class QuestionGeneratorConfig(BaseModel): 86 | """Configuration for question generation.""" 87 | model_config = ConfigDict(extra='forbid', validate_default=True) 88 | 89 | output_path: str 90 | 
input_dataset_config: QuestionGenInputConfig 91 | litellm_config: Optional[LiteLLMConfig] 92 | max_workers: Optional[int] = 20 93 | deduplication_enabled: Optional[bool] = True 94 | dedup_embedding_batch_size: Optional[int] = 500 95 | similarity_threshold: Optional[float] = 0.85 96 | upload_config: Optional[UploadConfig] = None 97 | llm_calls_per_minute: Optional[int] = 15 98 | embedding_calls_per_minute: Optional[int] = 15 99 | 100 | class ModelSettings(BaseModel): 101 | """Settings for the embedding model training.""" 102 | model_config = ConfigDict(extra='forbid', validate_default=True) 103 | 104 | model_id: str 105 | matryoshka_dimensions: List[int] = [768, 512, 256, 128, 64] 106 | metric_for_best_model: str = "eval_dim_128_cosine_ndcg@10" 107 | max_seq_length: Optional[int] = 768 108 | trust_remote_code: Optional[bool] = False 109 | 110 | class TrainingArguments(BaseModel): 111 | """Arguments for model training.""" 112 | model_config = ConfigDict(extra='forbid', validate_default=True) 113 | 114 | # Required parameters 115 | output_path: str 116 | device: Optional[str] = "cuda" 117 | 118 | # Basic training parameters 119 | epochs: Optional[int] = 4 120 | batch_size: Optional[int] = 32 121 | gradient_accumulation_steps: Optional[int] = 16 122 | learning_rate: Optional[float] = 2.0e-5 123 | 124 | # Learning rate scheduler settings 125 | warmup_ratio: Optional[float] = 0.1 126 | lr_scheduler_type: Optional[str] = "cosine" 127 | 128 | # Optimizer settings 129 | optim: Optional[str] = "adamw_torch_fused" 130 | 131 | # Hardware optimization flags 132 | tf32: Optional[bool] = True 133 | bf16: Optional[bool] = True 134 | 135 | # Batch sampling strategy 136 | batch_sampler: Optional[BatchSamplers] = BatchSamplers.NO_DUPLICATES 137 | 138 | # Training and evaluation strategy 139 | eval_strategy: Optional[str] = "epoch" 140 | save_strategy: Optional[str] = "epoch" 141 | logging_steps: Optional[int] = 10 142 | save_total_limit: Optional[int] = 3 143 | 
load_best_model_at_end: Optional[bool] = True 144 | 145 | # Reporting 146 | report_to: Optional[str] = "none" 147 | 148 | class TrainingConfig(BaseModel): 149 | """Configuration for model training.""" 150 | model_config = ConfigDict(extra='forbid', validate_default=True) 151 | 152 | model_settings: ModelSettings 153 | training_arguments: TrainingArguments 154 | train_dataset_config: TrainInputConfig 155 | upload_config: Optional[UploadConfig] = None 156 | 157 | class PipelineConfig(BaseModel): 158 | """Main configuration for the QuicKB pipeline.""" 159 | model_config = ConfigDict(extra='forbid', validate_default=True) 160 | 161 | pipeline: Dict[str, str] 162 | hub_username: Optional[str] = None 163 | hub_token: Optional[str] = None 164 | path_to_knowledgebase: Optional[str] 165 | chunker_config: Optional[ChunkerConfig] = None 166 | question_generation: Optional[QuestionGeneratorConfig] = None 167 | training: Optional[TrainingConfig] = None 168 | 169 | def load_dataset_from_local(file_path: str) -> List[Dict[str, Any]]: 170 | """Load dataset from a local JSON file.""" 171 | try: 172 | with open(file_path, "r", encoding="utf-8") as f: 173 | data = json.load(f) 174 | if not isinstance(data, list): 175 | raise ValueError(f"Expected JSON array in {file_path}") 176 | return data 177 | except FileNotFoundError: 178 | logger.error(f"Dataset file not found: {file_path}") 179 | raise 180 | except json.JSONDecodeError: 181 | logger.error(f"Error decoding JSON in {file_path}") 182 | raise 183 | 184 | def load_dataset_from_hub(hub_dataset_id: str) -> List[Dict[str, Any]]: 185 | """Load dataset from Hugging Face Hub using default config.""" 186 | try: 187 | logger.info(f"Loading dataset from Hub: {hub_dataset_id}") 188 | dataset = load_dataset(hub_dataset_id, split="train") 189 | if dataset: 190 | return dataset.to_list() 191 | else: 192 | logger.error(f"No data found in dataset: {hub_dataset_id}") 193 | return [] 194 | except Exception as e: 195 | logger.error(f"Error loading 
dataset from Hub: {hub_dataset_id}. Error: {e}") 196 | raise 197 | 198 | 199 | def load_pipeline_config(config_path: str | Path = "config.yaml") -> PipelineConfig: 200 | """Load and validate pipeline configuration.""" 201 | 202 | config_path = Path(config_path) 203 | if not config_path.exists(): 204 | raise FileNotFoundError(f"Config file not found: {config_path}") 205 | 206 | try: 207 | with open(config_path, 'r', encoding='utf-8') as f: 208 | config_data = yaml.safe_load(f) 209 | 210 | return PipelineConfig.model_validate(config_data) 211 | except Exception as e: 212 | logger.error(f"Error loading config from {config_path}: {str(e)}") 213 | raise 214 | 215 | def process_chunks(config: PipelineConfig) -> List[Dict[str, Any]]: 216 | """Process documents into chunks and optionally upload to Hub.""" 217 | 218 | # Get chunker class 219 | chunker_class = ChunkerRegistry.get_chunker(config.chunker_config.chunker) 220 | args = config.chunker_config.chunker_arguments.copy() 221 | chunker = chunker_class(**args) 222 | 223 | logger.info(f"Initialized Chunker: {config.chunker_config.chunker}") 224 | 225 | # Process files 226 | base_path = Path(config.path_to_knowledgebase) 227 | results = [] 228 | total_chunks = 0 229 | 230 | for file_path in base_path.rglob('*.txt'): 231 | try: 232 | with open(file_path, 'r', encoding='utf-8') as f: 233 | text = f.read() 234 | chunks = chunker.split_text(text) 235 | source_path = str(file_path.relative_to(base_path)) 236 | 237 | for chunk in chunks: 238 | results.append({ 239 | "id": str(uuid.uuid4()), 240 | "text": chunk, 241 | "source": source_path 242 | }) 243 | 244 | logger.info(f"Created {len(chunks)} chunks from {file_path}") 245 | total_chunks += len(chunks) 246 | 247 | except Exception as e: 248 | logger.error(f"Error processing {file_path}: {str(e)}") 249 | continue 250 | 251 | logger.info(f"Created {total_chunks} chunks in total") 252 | 253 | # Save results 254 | output_path = Path(config.chunker_config.output_path) 255 | 
output_path.parent.mkdir(parents=True, exist_ok=True) 256 | 257 | with open(output_path, 'w', encoding='utf-8') as f: 258 | json.dump(results, f, indent=2, ensure_ascii=False) 259 | 260 | # Handle upload if configured 261 | if (config.hub_username and 262 | config.chunker_config.upload_config and 263 | config.chunker_config.upload_config.push_to_hub): 264 | try: 265 | pusher = DatasetPusher( 266 | username=config.hub_username, 267 | token=config.hub_token 268 | ) 269 | 270 | repository_id = (config.chunker_config.upload_config.hub_dataset_id 271 | or f"{config.hub_username}/{Path(config.chunker_config.output_path).stem}") 272 | 273 | chunker_info = { 274 | 'chunker_name': config.chunker_config.chunker, 275 | 'chunker_params': config.chunker_config.chunker_arguments 276 | } 277 | 278 | pusher.push_dataset( 279 | hub_dataset_id=repository_id, 280 | knowledgebase_path=config.chunker_config.output_path, 281 | chunker_info=chunker_info, 282 | private=config.chunker_config.upload_config.hub_private 283 | ) 284 | logger.info(f"Successfully uploaded chunks to Hub: {repository_id}") 285 | except Exception as e: 286 | logger.error(f"Failed to upload chunks to Hub: {str(e)}") 287 | 288 | return results 289 | 290 | def generate_questions( 291 | config: PipelineConfig, 292 | kb_dataset: List[Dict[str, Any]] 293 | ) -> tuple[List[Dict[str, Any]], Dict[str, int]]: 294 | """Generate questions and optionally upload to Hub.""" 295 | 296 | if not config.question_generation: 297 | raise ValueError("Question generation config is required but not provided") 298 | 299 | generator = QuestionGenerator( 300 | prompt_path="src/prompts/question_generation.txt", 301 | llm_model=config.question_generation.litellm_config.model, 302 | embedding_model=config.question_generation.litellm_config.embedding_model, 303 | dedup_enabled=config.question_generation.deduplication_enabled, 304 | similarity_threshold=config.question_generation.similarity_threshold, 305 | 
max_workers=config.question_generation.max_workers, 306 | model_api_base=config.question_generation.litellm_config.model_api_base, 307 | embedding_api_base=config.question_generation.litellm_config.embedding_api_base, 308 | embedding_batch_size=config.question_generation.dedup_embedding_batch_size, 309 | llm_calls_per_minute=config.question_generation.llm_calls_per_minute, 310 | embedding_calls_per_minute=config.question_generation.embedding_calls_per_minute 311 | ) 312 | 313 | # Get unique texts and build a text-to-id mapping 314 | text_to_chunk_map = {} 315 | for item in kb_dataset: 316 | text_val = item["text"] 317 | if text_val not in text_to_chunk_map: 318 | text_to_chunk_map[text_val] = [] 319 | text_to_chunk_map[text_val].append(item["id"]) 320 | 321 | unique_texts = list(text_to_chunk_map.keys()) 322 | logger.info(f"Found {len(unique_texts)} unique chunks") 323 | 324 | # Generate questions 325 | questions = generator.generate_for_chunks(unique_texts) 326 | logger.info(f"Generated {len(questions)} questions after deduplication") 327 | 328 | # Track metrics 329 | metrics = { 330 | "num_questions_original": sum(len(generator._question_cache[chunk]) for chunk in generator._question_cache), 331 | "num_questions_deduped": len(questions) 332 | } 333 | logger.info(f"Question generation metrics: {metrics}") 334 | 335 | # Create training records 336 | train_records = [] 337 | skipped_questions = 0 338 | for q in questions: 339 | chunk_text = q.get("chunk_text") 340 | if not chunk_text: 341 | skipped_questions += 1 342 | continue 343 | 344 | chunk_ids = text_to_chunk_map.get(chunk_text, []) 345 | if not chunk_ids: 346 | skipped_questions += 1 347 | logger.warning(f"Could not find chunk_id for question: {q['question'][:100]}...") 348 | continue 349 | 350 | # Create a record for each matching chunk 351 | for chunk_id in chunk_ids: 352 | train_records.append({ 353 | "anchor": q["question"], 354 | "positive": chunk_text, 355 | "question_id": q["id"], 356 | "chunk_id": 
chunk_id 357 | }) 358 | 359 | logger.info(f"Created {len(train_records)} training records (skipped {skipped_questions} questions)") 360 | 361 | # Save results 362 | if config.question_generation.output_path: 363 | output_path = Path(config.question_generation.output_path) 364 | output_path.parent.mkdir(parents=True, exist_ok=True) 365 | 366 | try: 367 | with open(output_path, 'w', encoding='utf-8') as f: 368 | json.dump(train_records, f, indent=2, ensure_ascii=False) 369 | logger.info(f"Saved training records to {output_path}") 370 | except Exception as e: 371 | logger.error(f"Failed to save training records: {str(e)}") 372 | 373 | # Handle upload if configured 374 | if (config.hub_username and 375 | config.question_generation.upload_config and 376 | config.question_generation.upload_config.push_to_hub): 377 | try: 378 | pusher = DatasetPusher( 379 | username=config.hub_username, 380 | token=config.hub_token 381 | ) 382 | 383 | repository_id = config.question_generation.upload_config.hub_dataset_id 384 | 385 | question_gen_info = { 386 | 'model_name': config.question_generation.litellm_config.model, 387 | 'similarity_threshold': config.question_generation.similarity_threshold, 388 | 'num_questions': metrics['num_questions_original'], 389 | 'num_deduped': metrics['num_questions_deduped'] 390 | } 391 | 392 | pusher.push_dataset( 393 | hub_dataset_id=repository_id, 394 | train_path=config.question_generation.output_path, 395 | question_gen_info=question_gen_info, 396 | private=config.question_generation.upload_config.hub_private 397 | ) 398 | logger.info(f"Successfully uploaded train dataset to Hub: {repository_id}") 399 | except Exception as e: 400 | logger.error(f"Failed to upload train dataset to Hub: {str(e)}") 401 | 402 | return train_records, metrics 403 | 404 | def train_model(config: PipelineConfig, kb_dataset: List[Dict[str, Any]], train_dataset: List[Dict[str, Any]]): 405 | """Train the embedding model.""" 406 | train_main(config, 
train_dataset=train_dataset, kb_dataset=kb_dataset) 407 | 408 | def upload_to_hub( 409 | config: PipelineConfig, 410 | kb_dataset: List[Dict[str, Any]], 411 | train_dataset: Optional[List[Dict[str, Any]]] = None, 412 | question_metrics: Optional[Dict[str, int]] = None 413 | ): 414 | """Upload datasets to Hugging Face Hub.""" 415 | 416 | if not config.hub_username: 417 | logger.warning("No 'hub_username' specified, skipping upload.") 418 | return 419 | 420 | try: 421 | # Initialize pusher 422 | pusher = DatasetPusher( 423 | username=config.hub_username, 424 | token=config.hub_token 425 | ) 426 | 427 | # Get repository name from output path 428 | repository_name = Path(config.output_path).stem 429 | 430 | # Collect chunker info 431 | chunker_info = { 432 | 'chunker_name': config.chunker, 433 | 'chunker_params': config.chunker_arguments 434 | } 435 | 436 | # Collect question generation info if enabled 437 | question_gen_info = None 438 | question_gen_info = { 439 | 'model_name': config.question_generation.model, 440 | 'similarity_threshold': config.deduplication.similarity_threshold, 441 | 'num_questions': question_metrics['num_questions_original'] if question_metrics else len(train_dataset), 442 | 'num_deduped': question_metrics['num_questions_deduped'] if question_metrics else len(train_dataset) 443 | } 444 | 445 | # Push dataset 446 | pusher.push_dataset( 447 | repository_name=repository_name, 448 | knowledgebase_path=config.output_path, 449 | chunker_info=chunker_info, 450 | train_path=config.question_output_path if train_dataset else None, 451 | question_gen_info=question_gen_info, 452 | private=config.hub_private 453 | ) 454 | except Exception as e: 455 | logger.error(f"Failed to upload to Hugging Face Hub: {str(e)}") 456 | 457 | def run_pipeline(config: PipelineConfig): 458 | """Run the QuicKB pipeline.""" 459 | from_stage = PipelineStage[config.pipeline["from_stage"]] 460 | to_stage = PipelineStage[config.pipeline["to_stage"]] 461 | 462 | kb_dataset = None 463 
| train_dataset = None 464 | question_metrics = None 465 | 466 | # 1. CHUNK 467 | if from_stage.value <= PipelineStage.CHUNK.value <= to_stage.value: 468 | logger.info("Running CHUNK stage.") 469 | kb_dataset = process_chunks(config) 470 | else: 471 | logger.info("Skipping CHUNK stage.") 472 | 473 | # 2. GENERATE 474 | if from_stage.value <= PipelineStage.GENERATE.value <= to_stage.value: 475 | # Load knowledgebase dataset if needed for GENERATE 476 | if not kb_dataset: 477 | input_config = config.question_generation.input_dataset_config 478 | if input_config.dataset_source == "hub": 479 | logger.info(f"Loading knowledgebase dataset from Hub: {input_config.knowledgebase_dataset_id}") 480 | kb_dataset = load_dataset_from_hub(input_config.knowledgebase_dataset_id) 481 | elif input_config.dataset_source == "local": 482 | local_kb_path = input_config.local_knowledgebase_path or config.chunker_config.output_path 483 | logger.info(f"Loading knowledgebase dataset from local path: {local_kb_path}") 484 | kb_dataset = load_dataset_from_local(local_kb_path) 485 | 486 | logger.info("Running GENERATE stage.") 487 | train_dataset, question_metrics = generate_questions(config, kb_dataset) 488 | 489 | # 3. 
TRAIN 490 | if from_stage.value <= PipelineStage.TRAIN.value <= to_stage.value: 491 | logger.info("Running TRAIN stage.") 492 | if not config.training: 493 | raise ValueError("No training config found, cannot run TRAIN stage.") 494 | 495 | train_config = config.training.train_dataset_config 496 | 497 | # Load datasets for training if needed 498 | if train_config.dataset_source == "hub": 499 | logger.info("Loading datasets from Hub for training...") 500 | if not train_dataset: 501 | logger.info(f"Loading training dataset from Hub: {train_config.train_dataset_id}") 502 | train_dataset = load_dataset_from_hub(train_config.train_dataset_id) 503 | if not kb_dataset: 504 | logger.info(f"Loading knowledgebase dataset from Hub: {train_config.knowledgebase_dataset_id}") 505 | kb_dataset = load_dataset_from_hub(train_config.knowledgebase_dataset_id) 506 | else: # local 507 | logger.info("Loading datasets from local files for training...") 508 | if not train_dataset: 509 | train_path = train_config.local_train_path or config.question_generation.output_path 510 | logger.info(f"Loading training dataset from local path: {train_path}") 511 | train_dataset = load_dataset_from_local(train_path) 512 | if not kb_dataset: 513 | kb_path = train_config.local_knowledgebase_path or config.chunker_config.output_path 514 | logger.info(f"Loading knowledgebase dataset from local path: {kb_path}") 515 | kb_dataset = load_dataset_from_local(kb_path) 516 | 517 | if not kb_dataset or not train_dataset: 518 | raise ValueError("Failed to load required datasets for training") 519 | 520 | train_model(config, kb_dataset, train_dataset) 521 | 522 | logger.info("Pipeline complete!") 523 | 524 | if __name__ == "__main__": 525 | logging.basicConfig(level=logging.INFO) 526 | try: 527 | config = load_pipeline_config("config.yaml") 528 | run_pipeline(config) 529 | except Exception as e: 530 | logger.error(f"Fatal error: {str(e)}") 531 | raise 
-------------------------------------------------------------------------------- /src/prompts/question_generation.txt: -------------------------------------------------------------------------------- 1 | You are a precise question generator that creates specific, factual questions from provided text. Each question must be answerable using only the explicit information in the text. 2 | 3 | Requirements for questions: 4 | 1. Must be answerable using ONLY information explicitly stated in the text 5 | 2. Must have a single, unambiguous answer found directly in the text 6 | 3. Must focus on concrete facts, not interpretation or inference 7 | 4. Must not require external knowledge or context 8 | 5. Must be specific to a single topic or point 9 | 6. Must avoid compound questions using "and" or "or" 10 | 11 | Output Format: 12 | { 13 | "questions": [ 14 | { 15 | "id": number, 16 | "question": "text of the question", 17 | "answer_location": "exact quote from the text containing the answer", 18 | "explanation": "brief explanation of why this is a good question" 19 | } 20 | ] 21 | } 22 | 23 | Examples: 24 | 25 | Text: "The city council voted 7-2 to approve the new parking ordinance on March 15, 2024. The ordinance will increase parking meter rates from $2.00 to $3.50 per hour in the downtown district, effective July 1, 2024." 
26 | 27 | Good Questions: 28 | { 29 | "questions": [ 30 | { 31 | "id": 1, 32 | "question": "What was the vote count for the parking ordinance?", 33 | "answer_location": "The city council voted 7-2 to approve the new parking ordinance", 34 | "explanation": "Asks about a specific, explicitly stated numerical fact" 35 | }, 36 | { 37 | "id": 2, 38 | "question": "What will be the new parking meter rate per hour?", 39 | "answer_location": "increase parking meter rates from $2.00 to $3.50 per hour", 40 | "explanation": "Focuses on a single, clearly stated numerical change" 41 | } 42 | ] 43 | } 44 | 45 | Bad Questions: 46 | - "Why did some council members vote against the ordinance?" (Requires interpretation) 47 | - "How will this affect local businesses?" (Not addressed in text) 48 | - "What are the current and new parking rates?" (Compound question) 49 | 50 | For the given user text, generate 4 questions following these requirements. Each question should focus on a different aspect of the text. Do not refer to "the text" or "the excerpt" in the questions themselves. Respond with only the JSON object described under Output Format, with no additional text or code fences.
-------------------------------------------------------------------------------- /src/synth_dataset/deduplicator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import List, Dict 3 | import logging 4 | 5 | logger = logging.getLogger(__name__) 6 | 7 | class QuestionDeduplicator: 8 | def __init__(self, similarity_threshold: float = 0.85): 9 | self.similarity_threshold = similarity_threshold 10 | 11 | def _calculate_similarity_matrix(self, embeddings: List[List[float]]) -> np.ndarray: 12 | """Calculate the cosine similarity matrix for the embeddings.""" 13 | embeddings_matrix = np.array(embeddings) 14 | # Normalize the vectors 15 | norms = np.linalg.norm(embeddings_matrix, axis=1, keepdims=True) 16 | embeddings_matrix = embeddings_matrix / norms 17 | return np.dot(embeddings_matrix, embeddings_matrix.T) 18 | 19 | def _filter_similar_questions(self, similarity_matrix: np.ndarray) -> List[int]: 20 | """ 21 | Filter out questions that are too similar to others. 22 | Returns indices of questions to keep. 23 | """ 24 | n = similarity_matrix.shape[0] 25 | keep_mask = np.ones(n, dtype=bool) 26 | 27 | # For each question 28 | for i in range(n): 29 | if keep_mask[i]: 30 | # Find questions that are too similar and remove them 31 | similar_questions = np.where(similarity_matrix[i] > self.similarity_threshold)[0] 32 | # Only look at questions after the current one 33 | similar_questions = similar_questions[similar_questions > i] 34 | keep_mask[similar_questions] = False 35 | 36 | return np.where(keep_mask)[0] 37 | 38 | def deduplicate(self, questions: List[Dict], embeddings: List[List[float]]) -> List[Dict]: 39 | """ 40 | Remove duplicate and similar questions based on embedding similarity. 
41 | 42 | Args: 43 | questions: List of question dictionaries 44 | embeddings: List of embedding vectors for questions 45 | 46 | Returns: 47 | List of filtered question dictionaries 48 | """ 49 | if not questions: 50 | return [] 51 | 52 | # First remove exact duplicates based on question text 53 | seen_questions = set() 54 | unique_questions = [] 55 | unique_embeddings = [] 56 | 57 | for q, emb in zip(questions, embeddings): 58 | if q["question"] not in seen_questions: 59 | seen_questions.add(q["question"]) 60 | unique_questions.append(q) 61 | unique_embeddings.append(emb) 62 | 63 | if not unique_questions: 64 | return [] 65 | 66 | # Calculate similarity matrix 67 | similarity_matrix = self._calculate_similarity_matrix(unique_embeddings) 68 | 69 | # Get indices of questions to keep 70 | keep_indices = self._filter_similar_questions(similarity_matrix) 71 | 72 | # Filter questions 73 | filtered_questions = [unique_questions[i] for i in keep_indices] 74 | 75 | logger.info( 76 | f"Deduplication: {len(questions)} -> {len(unique_questions)} -> " 77 | f"{len(filtered_questions)} questions after filtering" 78 | ) 79 | return filtered_questions -------------------------------------------------------------------------------- /src/synth_dataset/question_generator.py: -------------------------------------------------------------------------------- 1 | import threading 2 | from threading import Lock, RLock 3 | from typing import Dict, List, Optional 4 | from collections import deque 5 | import json 6 | import uuid 7 | import logging 8 | import backoff 9 | from litellm import completion, embedding 10 | from concurrent.futures import ThreadPoolExecutor, as_completed 11 | from tqdm import tqdm 12 | from .deduplicator import QuestionDeduplicator 13 | from .rate_limiter import RateLimiter 14 | 15 | logger = logging.getLogger(__name__) 16 | logging.getLogger("LiteLLM").setLevel(logging.WARNING) 17 | 18 | class QuestionGenerator: 19 | def __init__( 20 | self, 21 | prompt_path: str, 22 
| api_key: str = None, 23 | llm_model: str = "openai/gpt-4o-mini", 24 | embedding_model: str = "text-embedding-3-large", 25 | dedup_enabled: bool = True, 26 | similarity_threshold: float = 0.85, 27 | max_workers: int = 20, 28 | model_api_base: str = None, 29 | embedding_api_base: str = None, 30 | embedding_batch_size: int = 500, 31 | llm_calls_per_minute: int = 15, 32 | embedding_calls_per_minute: int = 15 33 | ): 34 | # Initialize basic attributes 35 | self.api_key = api_key 36 | self.llm_model = llm_model 37 | self.embedding_model = embedding_model 38 | self.prompt = self._load_prompt(prompt_path) 39 | self.dedup_enabled = dedup_enabled 40 | self.max_workers = max_workers 41 | self.model_api_base = model_api_base 42 | self.embedding_api_base = embedding_api_base 43 | self.embedding_batch_size = embedding_batch_size 44 | 45 | # Thread safety mechanisms 46 | self._question_cache: Dict[str, List[Dict]] = {} 47 | self._cache_lock = RLock() # Use RLock for recursive locking capability 48 | self._embedding_lock = Lock() # Separate lock for embedding operations 49 | 50 | # Initialize deduplicator if enabled 51 | self.deduplicator = QuestionDeduplicator(similarity_threshold) if dedup_enabled else None 52 | 53 | # Initialize rate limiters with their own internal locks 54 | self.llm_rate_limiter = RateLimiter(llm_calls_per_minute, name="LLM") if llm_calls_per_minute is not None else None 55 | self.embedding_rate_limiter = RateLimiter(embedding_calls_per_minute, name="Embedding") if embedding_calls_per_minute is not None else None 56 | 57 | def _load_prompt(self, path: str) -> str: 58 | """Load the prompt template from a file.""" 59 | try: 60 | with open(path, 'r', encoding='utf-8') as f: 61 | return f.read() 62 | except FileNotFoundError: 63 | logger.error(f"Prompt template file not found: {path}") 64 | raise 65 | except IOError as e: 66 | logger.error(f"Error reading prompt template: {str(e)}") 67 | raise 68 | 69 | def _get_from_cache(self, chunk: str) -> 
Optional[List[Dict]]: 70 | """Thread-safe cache retrieval.""" 71 | with self._cache_lock: 72 | return self._question_cache.get(chunk) 73 | 74 | def _add_to_cache(self, chunk: str, questions: List[Dict]) -> None: 75 | """Thread-safe cache addition.""" 76 | with self._cache_lock: 77 | self._question_cache[chunk] = questions 78 | 79 | @backoff.on_exception( 80 | backoff.expo, 81 | Exception, 82 | max_tries=3, 83 | max_time=30 # Maximum total time to try 84 | ) 85 | def _generate(self, chunk: str) -> str: 86 | """Generate questions for a single chunk with rate limiting.""" 87 | # Wait for rate limit if needed 88 | if self.llm_rate_limiter: 89 | self.llm_rate_limiter.wait_if_needed() 90 | 91 | completion_kwargs = { 92 | "model": self.llm_model, 93 | "messages": [ 94 | {"role": "system", "content": self.prompt}, 95 | {"role": "user", "content": f"Text: {chunk}"} 96 | ], 97 | "temperature": 0.7, 98 | "api_key": self.api_key, 99 | "timeout": 10 # Add timeout for API calls 100 | } 101 | 102 | if self.model_api_base: 103 | completion_kwargs["api_base"] = self.model_api_base 104 | 105 | response = completion(**completion_kwargs) 106 | return response.choices[0].message.content 107 | 108 | def generate_for_chunk(self, chunk: str) -> List[Dict]: 109 | """Generate questions for a single chunk with caching.""" 110 | # Check cache first 111 | cached_questions = self._get_from_cache(chunk) 112 | if cached_questions is not None: 113 | return cached_questions 114 | 115 | try: 116 | response = self._generate(chunk) 117 | questions = json.loads(response)["questions"] 118 | 119 | # Process questions 120 | processed_questions = [] 121 | for q in questions: 122 | q.update({ 123 | "id": str(uuid.uuid4()), 124 | "chunk_text": chunk, 125 | }) 126 | q.pop("explanation", None) 127 | processed_questions.append(q) 128 | 129 | # Add to cache 130 | self._add_to_cache(chunk, processed_questions) 131 | return processed_questions 132 | except Exception as e: 133 | logger.error(f"Error generating 
questions: {str(e)}") 134 | return [] 135 | 136 | def _process_embeddings_batch(self, questions_batch: List[str]) -> List[List[float]]: 137 | """Process a batch of questions for embeddings with thread safety.""" 138 | with self._embedding_lock: 139 | if self.embedding_rate_limiter: 140 | self.embedding_rate_limiter.wait_if_needed() 141 | 142 | embedding_kwargs = { 143 | "model": self.embedding_model, 144 | "input": questions_batch, 145 | "api_key": self.api_key, 146 | "timeout": 10 147 | } 148 | 149 | if self.embedding_api_base: 150 | embedding_kwargs["api_base"] = self.embedding_api_base 151 | 152 | response = embedding(**embedding_kwargs) 153 | return [data["embedding"] for data in response.data] 154 | 155 | def generate_for_chunks(self, chunks: List[str]) -> List[Dict]: 156 | """Generate questions for multiple chunks with thread safety.""" 157 | results = [] 158 | uncached_chunks = [] 159 | 160 | # Check cache first 161 | for chunk in chunks: 162 | cached_results = self._get_from_cache(chunk) 163 | if cached_results is not None: 164 | results.extend(cached_results) 165 | else: 166 | uncached_chunks.append(chunk) 167 | 168 | if uncached_chunks: 169 | with ThreadPoolExecutor(max_workers=self.max_workers) as executor: 170 | future_to_chunk = { 171 | executor.submit(self.generate_for_chunk, chunk): chunk 172 | for chunk in uncached_chunks 173 | } 174 | 175 | with tqdm(total=len(uncached_chunks), desc="Generating questions") as pbar: 176 | for future in as_completed(future_to_chunk): 177 | try: 178 | questions = future.result() 179 | results.extend(questions) 180 | except Exception as e: 181 | chunk = future_to_chunk[future] 182 | logger.error(f"Error processing chunk {chunk[:50]}...: {str(e)}") 183 | pbar.update(1) 184 | 185 | if self.dedup_enabled and self.deduplicator and results: 186 | # Process embeddings in batches 187 | questions_text = [q["question"] for q in results] 188 | all_embeddings = [] 189 | 190 | # Calculate total number of batches for the progress 
bar 191 | num_batches = (len(questions_text) + self.embedding_batch_size - 1) // self.embedding_batch_size 192 | 193 | # Add progress bar for embedding batches 194 | with tqdm(total=num_batches, desc="Processing embeddings", unit="batch") as pbar: 195 | for i in range(0, len(questions_text), self.embedding_batch_size): 196 | batch = questions_text[i:i + self.embedding_batch_size] 197 | try: 198 | batch_embeddings = self._process_embeddings_batch(batch) 199 | all_embeddings.extend(batch_embeddings) 200 | pbar.update(1) 201 | 202 | # Optional: add batch statistics to progress bar 203 | pbar.set_postfix({ 204 | 'batch_size': len(batch), 205 | 'total_embedded': len(all_embeddings) 206 | }) 207 | 208 | except Exception as e: 209 | logger.error(f"Error during embedding batch {i//self.embedding_batch_size}: {str(e)}") 210 | return results # Return un-deduplicated results on error 211 | 212 | # Deduplicate with thread safety 213 | with self._cache_lock: 214 | results = self.deduplicator.deduplicate(results, all_embeddings) 215 | 216 | return results -------------------------------------------------------------------------------- /src/synth_dataset/rate_limiter.py: -------------------------------------------------------------------------------- 1 | from threading import Lock 2 | from time import time, sleep 3 | from collections import deque 4 | import logging 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | from typing import Optional 9 | 10 | class RateLimiter: 11 | """Thread-safe rate limiter using a rolling window.""" 12 | 13 | def __init__(self, calls_per_minute: int, name: str = "default"): 14 | """ 15 | Initialize rate limiter. 
16 | 17 | Args: 18 | calls_per_minute: Maximum number of calls allowed per minute 19 | name: Name for this rate limiter instance (for logging) 20 | """ 21 | self.calls_per_minute = calls_per_minute 22 | self.name = name 23 | self.window_size = 60 # 1 minute in seconds 24 | self.timestamps = deque(maxlen=calls_per_minute) 25 | self.lock = Lock() 26 | 27 | def _clean_old_timestamps(self, current_time: float) -> None: 28 | """Remove timestamps older than the window size.""" 29 | cutoff_time = current_time - self.window_size 30 | while self.timestamps and self.timestamps[0] < cutoff_time: 31 | self.timestamps.popleft() 32 | 33 | def wait_if_needed(self) -> None: 34 | """ 35 | Check if rate limit is reached and wait if necessary. 36 | Thread-safe implementation. 37 | """ 38 | with self.lock: 39 | current_time = time() 40 | self._clean_old_timestamps(current_time) 41 | 42 | if len(self.timestamps) >= self.calls_per_minute: 43 | # Calculate required wait time 44 | oldest_timestamp = self.timestamps[0] 45 | wait_time = oldest_timestamp + self.window_size - current_time 46 | 47 | if wait_time > 0: 48 | logger.debug(f"{self.name} rate limiter: Waiting {wait_time:.2f} seconds") 49 | sleep(wait_time) 50 | current_time = time() # Update current time after sleep 51 | 52 | # Add new timestamp 53 | self.timestamps.append(current_time) 54 | 55 | def acquire(self) -> None: 56 | """Alias for wait_if_needed for more intuitive API.""" 57 | self.wait_if_needed() 58 | 59 | def remaining_calls(self) -> int: 60 | """Return number of remaining calls allowed in current window.""" 61 | with self.lock: 62 | self._clean_old_timestamps(time()) 63 | return self.calls_per_minute - len(self.timestamps) -------------------------------------------------------------------------------- /src/training/__init__.py: -------------------------------------------------------------------------------- 
# https://raw.githubusercontent.com/ALucek/QuicKB/288c62bd5024520388d05ece4a841900abb6d93d/src/training/__init__.py
# --------------------------------------------------------------------------------
# /src/training/train.py:
# --------------------------------------------------------------------------------
import os
import json
import logging
import torch
import yaml
from pathlib import Path
from typing import Dict, List, Any, Optional, Set
from datasets import load_dataset, Dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerModelCardData,
    SentenceTransformerTrainingArguments,
    SentenceTransformerTrainer,
)
from sentence_transformers.losses import MultipleNegativesRankingLoss, MatryoshkaLoss
from sentence_transformers.evaluation import InformationRetrievalEvaluator, SequentialEvaluator
from sentence_transformers.util import cos_sim
from sentence_transformers.training_args import BatchSamplers
from huggingface_hub import login

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logging.getLogger("sentence_transformers").setLevel(logging.WARNING)
logging.getLogger("transformers").setLevel(logging.WARNING)

# Metrics shown in the console tables and the saved comparison file. Each is
# looked up in evaluator results under the key "dim_{dim}_cosine_{metric}".
DEFAULT_METRICS = (
    "ndcg@10", "mrr@10", "map@100",
    "accuracy@1", "accuracy@3", "accuracy@5", "accuracy@10",
    "precision@1", "precision@3", "precision@5", "precision@10",
    "recall@1", "recall@3", "recall@5", "recall@10",
)


def load_main_config(config_path_or_obj):
    """Load configuration from path or use provided config object.

    A str/bytes/PathLike argument is read as YAML with ``yaml.safe_load``.
    Anything else is assumed to be a pydantic-style object exposing
    ``model_dump()``.
    """
    if isinstance(config_path_or_obj, (str, bytes, os.PathLike)):
        with open(config_path_or_obj, "r", encoding="utf-8") as f:
            return yaml.safe_load(f)
    return config_path_or_obj.model_dump()


def validate_dataset_consistency(kb_dataset: List[Dict[str, Any]], train_dataset: List[Dict[str, Any]]) -> bool:
    """Validate that all chunk IDs referenced in training data exist in knowledgebase.

    IDs are compared as strings, so an int id and its string form match.

    Returns:
        True when every ``chunk_id`` in ``train_dataset`` matches some ``id`` in
        ``kb_dataset``; False otherwise (the mismatch is logged).
    """
    logger.info("Validating dataset consistency...")

    kb_ids = {str(item['id']) for item in kb_dataset}
    logger.info(f"Found {len(kb_ids)} unique chunks in knowledgebase")

    train_chunk_ids = {str(item['chunk_id']) for item in train_dataset}
    logger.info(f"Found {len(train_chunk_ids)} unique chunk references in training data")

    missing_chunks = train_chunk_ids - kb_ids
    if missing_chunks:
        logger.error(f"Found {len(missing_chunks)} chunk IDs in training data that don't exist in knowledgebase")
        logger.error(f"First few missing IDs: {list(missing_chunks)[:5]}")
        return False

    logger.info("Dataset consistency validation passed!")
    return True


def build_evaluation_structures(kb_dataset, test_dataset, kb_id_field="id", kb_text_field="text"):
    """Build the corpus / queries / relevant-docs mappings for retrieval evaluation.

    All ids are normalised to ``str``. This mirrors
    ``validate_dataset_consistency``: if the knowledgebase stores ids as ints
    and the training data as strings (or vice versa), validation passes and
    evaluation still finds the relevant chunks.

    Args:
        kb_dataset: Iterable of knowledgebase rows.
        test_dataset: Iterable of held-out rows with ``id``, ``anchor`` and
            ``global_chunk_id``. Rows missing ``global_chunk_id`` are logged and
            skipped.
        kb_id_field: Knowledgebase column holding the chunk id.
        kb_text_field: Knowledgebase column holding the chunk text.

    Returns:
        ``(corpus, queries, relevant_docs)``:
        - ``corpus`` maps chunk id to chunk text.
        - ``queries`` maps query id to question text.
        - ``relevant_docs`` maps query id to the set of relevant chunk ids.
    """
    corpus = {str(row[kb_id_field]): row[kb_text_field] for row in kb_dataset}
    queries = {str(row["id"]): row["anchor"] for row in test_dataset}
    relevant_docs: Dict[str, Set[str]] = {}
    for row in test_dataset:
        if "global_chunk_id" not in row:
            logger.warning(f"Missing 'global_chunk_id': {row}")
            continue
        relevant_docs.setdefault(str(row["id"]), set()).add(str(row["global_chunk_id"]))
    return corpus, queries, relevant_docs


def format_evaluation_results(title: str, results: dict, dim_list: list, metrics: Optional[list] = None) -> str:
    """Render evaluator results as a fixed-width text table.

    Args:
        title: Heading placed above the table.
        results: Evaluator output keyed ``dim_{dim}_cosine_{metric}``; missing
            keys are shown as 0.000.
        dim_list: Matryoshka dimensions, one column each.
        metrics: Metric names, one row each (defaults to ``DEFAULT_METRICS``).

    Returns:
        The table as one newline-joined string (beginning with a newline).
    """
    if metrics is None:
        metrics = list(DEFAULT_METRICS)

    # A value column must fit both its dimension header and a '0.000' value.
    max_dim_length = max((len(str(dim)) for dim in dim_list), default=0)
    value_format_length = 5  # '0.000' is 5 characters
    column_width = max(max_dim_length, value_format_length)
    metric_width = max((len(metric) for metric in metrics), default=10)

    # Dimension headers are left-aligned; values below are right-aligned.
    dim_header = " ".join(f"{str(dim):<{column_width}}" for dim in dim_list)
    header_line = f"{'Metric':{metric_width}} {dim_header}"
    separator = "-" * len(header_line)

    output = [f"\n{title}", separator, header_line, separator]
    for metric in metrics:
        values = []
        for dim in dim_list:
            val = results.get(f"dim_{dim}_cosine_{metric}", 0.0)
            values.append(f"{val:>{column_width}.3f}")
        output.append(f"{metric:{metric_width}} {' '.join(values)}")
    output.append(separator)

    return "\n".join(output)


def _run_eval(model, evaluator, dim_list, title):
    """Run ``evaluator`` on ``model``, print the formatted table, and return the raw results."""
    results = evaluator(model)
    print(format_evaluation_results(title, results, dim_list))
    return results


def run_baseline_eval(model, evaluator, dim_list):
    """Evaluate the untrained model and print the baseline table."""
    return _run_eval(model, evaluator, dim_list, "Before Training (Baseline) Results")


def run_final_eval(model, evaluator, dim_list):
    """Evaluate the fine-tuned model and print the final table."""
    return _run_eval(model, evaluator, dim_list, "After Training (Fine-Tuned) Results")


def save_metrics_to_file(before: dict, after: dict, dim_list: list, path="metrics_comparison.txt"):
    """Write baseline, fine-tuned, absolute-delta and percent-change tables to ``path``.

    A metric whose baseline value is 0 gets a percent change of 0.0%.
    The parent directory of ``path`` is created if it does not exist.
    """
    metrics = DEFAULT_METRICS

    def format_dimensions(dims):
        # Thousands separators in headers, e.g. 1024 -> "1,024".
        return [f"{dim:,}" for dim in dims]

    def write_table(f, title, headers, rows, value_formatter):
        # Each column is as wide as its header or its widest value.
        col_widths = [max(len(headers[0]), max(len(row[0]) for row in rows))]
        for col_idx in range(1, len(headers)):
            widest_value = max(len(str(row[col_idx])) for row in rows)
            col_widths.append(max(widest_value, len(headers[col_idx])))

        header = headers[0].ljust(col_widths[0])
        for i, h in enumerate(headers[1:], 1):
            header += " │ " + h.center(col_widths[i])

        separator = "─" * col_widths[0]
        for w in col_widths[1:]:
            separator += "─┼─" + "─" * w

        f.write(f"\n{title}\n")
        f.write("-" * len(title) + "\n")
        f.write(f"{header}\n{separator}\n")

        # Every value (numbers, +/- deltas, percentages) is right-aligned.
        for row in rows:
            line = row[0].ljust(col_widths[0])
            for i, val in enumerate(row[1:], 1):
                line += " │ " + value_formatter(val, col_widths[i])
            f.write(line + "\n")

    baseline_rows = []
    finetuned_rows = []
    delta_rows = []
    pct_change_rows = []

    for metric in metrics:
        bl_vals = []
        ft_vals = []
        dt_vals = []
        pc_vals = []

        for dim in dim_list:
            key = f"dim_{dim}_cosine_{metric}"
            b_val = before.get(key, 0.0)
            a_val = after.get(key, 0.0)
            delta = a_val - b_val
            pct_change = (delta / b_val * 100) if b_val != 0 else 0.0

            bl_vals.append(f"{b_val:.3f}")
            ft_vals.append(f"{a_val:.3f}")
            dt_vals.append(f"{delta:+.3f}")
            pc_vals.append(f"{pct_change:+.1f}%")

        baseline_rows.append([metric] + bl_vals)
        finetuned_rows.append([metric] + ft_vals)
        delta_rows.append([metric] + dt_vals)
        pct_change_rows.append([metric] + pc_vals)

    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)

    with open(path, "w", encoding="utf-8") as f:
        f.write("Model Performance Metrics Comparison\n")
        f.write("====================================\n")

        headers = ["Metric"] + format_dimensions(dim_list)
        num_formatter = lambda val, w: f"{val:>{w}}"

        write_table(f, "Baseline Performance", headers, baseline_rows, num_formatter)
        write_table(f, "Fine-Tuned Performance", headers, finetuned_rows, num_formatter)
        write_table(f, "Absolute Changes (Δ)", headers, delta_rows, num_formatter)
        write_table(f, "Percentage Changes", headers, pct_change_rows, num_formatter)


def select_device(preferred_device: Optional[str] = None) -> str:
    """Choose the torch device string to use ("cuda", "mps" or "cpu").

    An explicit, available ``preferred_device`` wins. Otherwise, or if the
    preferred device is unavailable, auto-detect in the order CUDA, MPS, CPU.
    """
    if preferred_device:
        if preferred_device == "cuda" and torch.cuda.is_available():
            logger.info("Using CUDA GPU as requested")
            return "cuda"
        elif preferred_device == "mps" and hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            logger.info("Using Apple Silicon MPS as requested")
            return "mps"
        elif preferred_device == "cpu":
            logger.info("Using CPU as requested")
            return "cpu"
        else:
            logger.warning(f"Requested device '{preferred_device}' not available, falling back to auto-detection")

    if torch.cuda.is_available():
        logger.info("CUDA GPU detected and selected")
        return "cuda"
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        logger.info("Apple Silicon MPS detected and selected")
        return "mps"
    else:
        logger.info("No accelerator detected, using CPU")
        return "cpu"


def main(config, train_dataset: List[Dict[str, Any]], kb_dataset: List[Dict[str, Any]]):
    """Main training function.

    Pipeline:
    1. Check that every training ``chunk_id`` exists in the knowledgebase.
    2. Hold out 10% of the question/chunk pairs for retrieval evaluation.
    3. Evaluate the base model.
    4. Fine-tune with ``MultipleNegativesRankingLoss`` wrapped in ``MatryoshkaLoss``.
    5. Re-evaluate the fine-tuned model.
    6. Write a before/after metrics file into the output directory.
    7. Optionally push the model to the Hugging Face Hub.

    Raises:
        ValueError: If ``config.training`` is missing, or the training data
            references chunks absent from the knowledgebase.
    """
    if not hasattr(config, 'training') or not config.training:
        raise ValueError("Training configuration is required but not provided")

    # Validate dataset consistency before proceeding
    if not validate_dataset_consistency(kb_dataset, train_dataset):
        raise ValueError("Dataset consistency check failed - chunks and training data are mismatched. "
                         "This usually happens when chunks have been regenerated after question generation.")

    model_settings = config.training.model_settings
    training_args = config.training.training_arguments
    output_path = training_args.output_path

    # Convert lists to dataset objects; ensure an "id" per question and
    # expose the chunk reference under the name the evaluator helpers read.
    train_dataset_full = Dataset.from_list(train_dataset)
    if "id" not in train_dataset_full.column_names:
        train_dataset_full = train_dataset_full.add_column("id", list(range(len(train_dataset_full))))
    if "chunk_id" in train_dataset_full.column_names:
        train_dataset_full = train_dataset_full.rename_column("chunk_id", "global_chunk_id")
    train_dataset_full = train_dataset_full.shuffle()
    dataset_split = train_dataset_full.train_test_split(test_size=0.1)
    train_split = dataset_split["train"]
    test_split = dataset_split["test"]
    logger.info(f"Train size: {len(train_split)} | Test size: {len(test_split)}")

    corpus, queries, relevant_docs = build_evaluation_structures(
        kb_dataset=kb_dataset,
        test_dataset=test_split,
        kb_id_field="id",
        kb_text_field="text"
    )

    # One IR evaluator per Matryoshka dimension, run in sequence.
    dim_list = model_settings.matryoshka_dimensions
    evaluator = SequentialEvaluator([
        InformationRetrievalEvaluator(
            queries=queries,
            corpus=corpus,
            relevant_docs=relevant_docs,
            name=f"dim_{d}",
            score_functions={"cosine": cos_sim},
            truncate_dim=d
        )
        for d in dim_list
    ])

    device = select_device(training_args.device)
    base_model = SentenceTransformer(model_settings.model_id, device=device, trust_remote_code=model_settings.trust_remote_code)
    base_model.max_seq_length = model_settings.max_seq_length
    baseline_results = run_baseline_eval(base_model, evaluator, dim_list)

    # The baseline model is only needed for the evaluation above. Release it
    # so two copies of the weights are not held on the accelerator during training.
    del base_model
    if device == "cuda":
        torch.cuda.empty_cache()

    logger.info("Re-initializing for training.")
    model = SentenceTransformer(
        model_settings.model_id,
        device=device,
        trust_remote_code=model_settings.trust_remote_code,
        model_kwargs={"attn_implementation": "sdpa"},
        model_card_data=SentenceTransformerModelCardData(
            language="en",
            license="apache-2.0",
            model_name="Fine-tuned with [QuicKB](https://github.com/ALucek/QuicKB)",
        ),
    )
    model.max_seq_length = model_settings.max_seq_length

    # In-batch-negatives loss applied at every Matryoshka truncation size.
    base_loss = MultipleNegativesRankingLoss(model)
    train_loss = MatryoshkaLoss(
        model=model,
        loss=base_loss,
        matryoshka_dims=dim_list
    )

    args = SentenceTransformerTrainingArguments(
        output_dir=output_path,
        num_train_epochs=training_args.epochs,
        per_device_train_batch_size=training_args.batch_size,
        gradient_accumulation_steps=training_args.gradient_accumulation_steps,
        learning_rate=training_args.learning_rate,
        warmup_ratio=training_args.warmup_ratio,
        lr_scheduler_type=training_args.lr_scheduler_type,
        optim=training_args.optim,
        tf32=training_args.tf32,
        bf16=training_args.bf16,
        batch_sampler=training_args.batch_sampler.value,
        eval_strategy=training_args.eval_strategy,
        save_strategy=training_args.save_strategy,
        logging_steps=training_args.logging_steps,
        save_total_limit=training_args.save_total_limit,
        load_best_model_at_end=training_args.load_best_model_at_end,
        metric_for_best_model=model_settings.metric_for_best_model,
        report_to=training_args.report_to,
    )

    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_split.select_columns(["anchor", "positive"]),
        loss=train_loss,
        evaluator=evaluator,
    )

    logger.info("Starting training...")
    trainer.train()
    trainer.save_model()

    # Reload from disk so the final numbers reflect exactly what was saved.
    fine_tuned_model = SentenceTransformer(output_path, device=device, trust_remote_code=model_settings.trust_remote_code)
    fine_tuned_model.max_seq_length = model_settings.max_seq_length
    final_results = run_final_eval(fine_tuned_model, evaluator, dim_list)

    save_metrics_to_file(
        baseline_results,
        final_results,
        dim_list,
        path=os.path.join(output_path, "metrics_comparison.txt")
    )

    upload_config = config.training.upload_config
    if upload_config and upload_config.push_to_hub:
        # Prefer the configured token, then HF_TOKEN. With neither, push_to_hub
        # falls back to whatever credentials huggingface_hub has cached.
        hub_token = getattr(config, "hub_token", None) or os.getenv("HF_TOKEN")
        if not hub_token:
            logger.warning("No HF_TOKEN in env or config, attempting login anyway.")

        if not upload_config.hub_model_id:
            logger.warning("No hub_model_id specified, skipping upload")
        else:
            logger.info(f"Pushing model to HF Hub: {upload_config.hub_model_id}")
            trainer.model.push_to_hub(
                upload_config.hub_model_id,
                token=hub_token,
                exist_ok=True,
                private=upload_config.hub_private
            )
            logger.info("Upload complete!")

    logger.info("Training pipeline finished.")
# --------------------------------------------------------------------------------