├── .gitignore
├── LICENSE
├── README.md
├── config.yaml
├── examples
│   └── chromadb_integration.ipynb
├── qkb_logo.png
├── requirements.txt
├── setup.py
└── src
    ├── chunking
    │   ├── __init__.py
    │   ├── base_chunker.py
    │   ├── cluster_semantic_chunker.py
    │   ├── fixed_token_chunker.py
    │   ├── kamradt_modified_chunker.py
    │   ├── llm_semantic_chunker.py
    │   ├── recursive_token_chunker.py
    │   ├── registry.py
    │   └── utils.py
    ├── hub_upload
    │   ├── card_generator.py
    │   ├── dataset_pusher.py
    │   └── template.md
    ├── main.py
    ├── prompts
    │   └── question_generation.txt
    ├── synth_dataset
    │   ├── deduplicator.py
    │   ├── question_generator.py
    │   └── rate_limiter.py
    └── training
        ├── __init__.py
        └── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | quickb_venv/
3 | knowledgebase/
4 | quickb-env/
5 | kb_env/
6 | testing/
7 | __pycache__/
8 | *.py[cod]
9 | output/
10 | *.egg-info/
11 | .env
12 | .pyc
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Adam Łucek
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # QuicKB
2 |
3 |
4 |
5 | Optimize Document Retrieval with Fine-Tuned KnowledgeBases
6 |
7 | ## Overview
8 |
9 | QuicKB optimizes document retrieval by creating fine-tuned knowledgebases through an end-to-end machine learning pipeline that handles document chunking, synthetic data generation, and embedding model training.
10 |
11 | ## Key Features
12 |
13 | ### Document Chunking
14 | Implement state-of-the-art chunking strategies based on [ChromaDB's research](https://research.trychroma.com/evaluating-chunking):
15 | - **Semantic Approaches**:
16 | - LLM-guided splits for natural language breakpoints
17 | - Content-aware clustering for thematic coherence
18 | - Hybrid semantic-token methods for balanced chunking
19 | - **Token/Character-Based Methods**:
20 | - Recursive chunking with custom separator hierarchies
21 | - Precise length-based splitting
22 |
23 | ### Training Data Generation
24 | Automatically create domain-specific training datasets:
25 | - Generate synthetic question-answer pairs from your content
26 | - Intelligent deduplication using semantic similarity (sketched below)
27 | - Parallel processing for large-scale document sets
28 | - Support for local and cloud LLMs via [LiteLLM](https://docs.litellm.ai/docs/)
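
The deduplication step drops near-duplicate synthetic questions by comparing their embeddings against a similarity threshold. QuicKB's actual `deduplicator.py` is not shown in this excerpt, so the following is only a minimal sketch of the general greedy approach (the function name and structure are hypothetical):

```python
import numpy as np

def dedup_questions(questions: list[str], embeddings: list[list[float]], threshold: float = 0.85) -> list[str]:
    """Greedy near-duplicate removal: keep a question only if it is not too similar to any kept one."""
    emb = np.asarray(embeddings, dtype=float)
    emb /= np.linalg.norm(emb, axis=1, keepdims=True)  # normalize so dot product = cosine similarity
    keep: list[int] = []
    for i in range(len(questions)):
        sims = emb[keep] @ emb[i] if keep else np.array([])
        if sims.size == 0 or sims.max() < threshold:
            keep.append(i)
    return [questions[i] for i in keep]
```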
29 |
30 | ### Embedding Optimization
31 | Fine-tune embedding models for your specific domain:
32 | - Custom training using [Sentence Transformers](https://sbert.net/)
33 | - Cross-platform training support (CUDA, MPS, CPU)
34 | - Dimension reduction techniques with Matryoshka Representation Learning (see the sketch below)
35 | - Comprehensive evaluation across multiple metrics
36 | - Detailed performance comparisons with baseline models
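
Matryoshka training wraps a base loss so that the leading dimensions of each embedding remain useful on their own. QuicKB's `train.py` is not included in this excerpt; the snippet below is only a minimal sketch of the standard Sentence Transformers setup using the dimensions from the default config:

```python
from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss

model = SentenceTransformer("nomic-ai/modernbert-embed-base")
base_loss = MultipleNegativesRankingLoss(model)  # consumes (anchor, positive) pairs from the generated dataset
loss = MatryoshkaLoss(model, base_loss, matryoshka_dims=[768, 512, 256, 128, 64])
# `loss` is then passed to the usual Sentence Transformers training loop.
```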
37 |
38 | ## Installation
39 |
40 | ```bash
41 | git clone https://github.com/ALucek/QuicKB.git
42 | cd QuicKB
43 |
44 | python -m venv quickb-env
45 | source quickb-env/bin/activate # Windows: quickb-env\Scripts\activate
46 |
47 | pip install -e .
48 | ```
49 |
50 | ## Usage
51 |
52 | 1. Prepare your text documents in a directory
53 | 2. Configure the pipeline in `config.yaml`
54 | 3. Run:
55 | ```bash
56 | python src/main.py
57 | ```
58 | 4. Enjoy!
59 |
60 | ## Configuration Guide
61 |
62 | The pipeline is controlled through a single `config.yaml` file. Here's a complete configuration example:
63 |
64 | ```yaml
65 | # ========== Pipeline Control ==========
66 | pipeline:
67 | from_stage: "CHUNK" # Options: "CHUNK", "GENERATE", "TRAIN"
68 | to_stage: "TRAIN" # Run pipeline from from_stage to to_stage
69 |
70 | # ========== Global Settings ==========
71 | path_to_knowledgebase: "./knowledgebase" # Directory containing source documents
72 | hub_username: "AdamLucek" # Hugging Face username
73 | hub_token: null # Optional: Use HF_TOKEN env variable instead
74 |
75 | # ========== Document Chunking ==========
76 | chunker_config:
77 | output_path: "./output/knowledgebase.json"
78 |
79 | # Chunker Config:
80 | chunker: "RecursiveTokenChunker"
81 |
82 | chunker_arguments:
83 | chunk_size: 400
84 | chunk_overlap: 0
85 | length_type: "character"
86 | separators: ["\n\n", "\n", ".", "?", "!", " ", ""]
87 | keep_separator: true
88 | is_separator_regex: false
89 |
90 | # Optional: Push chunks to Hugging Face Hub
91 | upload_config:
92 | push_to_hub: true
93 | hub_private: false
94 | hub_dataset_id: "AdamLucek/quickb-kb"
95 |
96 | # ========== Question Generation ==========
97 | question_generation:
98 | output_path: "./output/train_data.json"
99 |
100 | # LLM/Embedding Configuration
101 | litellm_config:
102 | model: "openai/gpt-4o-mini"
103 | model_api_base: null # Optional: Custom API endpoint
104 |
105 | embedding_model: "text-embedding-3-large"
106 | embedding_api_base: null # Optional: Custom embedding endpoint
107 |
108 | # Input dataset settings
109 | input_dataset_config:
110 | dataset_source: "local" # Options: "local", "hub"
111 | local_knowledgebase_path: "./output/knowledgebase.json"
112 | # Hub alternative:
113 | # knowledgebase_dataset_id: "username/quickb-kb"
114 |
115 | # Performance settings
116 | max_workers: 150 # Parallel question generation
117 | llm_calls_per_minute: null # null = no limit
118 | embedding_calls_per_minute: null # null = no limit
119 |
120 | # Question deduplication
121 | deduplication_enabled: true
122 | dedup_embedding_batch_size: 2048 # Batch size for embedding calculation
123 | similarity_threshold: 0.85 # Semantic Similarity Threshold
124 |
125 | # Optional: Push training data to Hub
126 | upload_config:
127 | push_to_hub: true
128 | hub_private: false
129 | hub_dataset_id: "AdamLucek/quickb-qa"
130 |
131 | # ========== Model Training ==========
132 | training:
133 | # Model configuration
134 | model_settings:
135 | # Base model:
136 | model_id: "nomic-ai/modernbert-embed-base"
137 |
138 | # Matryoshka dimensions (must be descending)
139 | matryoshka_dimensions: [768, 512, 256, 128, 64]
140 | metric_for_best_model: "eval_dim_128_cosine_ndcg@10"
141 | max_seq_length: 1024
142 | trust_remote_code: false
143 |
144 | # Training data configuration
145 | train_dataset_config:
146 | dataset_source: "local" # Options: "local", "hub"
147 | local_train_path: "./output/train_data.json"
148 | local_knowledgebase_path: "./output/knowledgebase.json"
149 | # Hub alternatives:
150 | # train_dataset_id: "AdamLucek/quickb-qa"
151 | # knowledgebase_dataset_id: "AdamLucek/quickb-kb"
152 |
153 | # Training hyperparameters
154 | training_arguments:
155 | output_path: "./output/modernbert_quickb"
156 | device: "cuda" # Options: "cuda", "mps", "cpu"
157 | epochs: 4
158 | batch_size: 32
159 | gradient_accumulation_steps: 16
160 | learning_rate: 2.0e-5
161 | warmup_ratio: 0.1
162 | lr_scheduler_type: "cosine"
163 | optim: "adamw_torch_fused"
164 | tf32: true
165 | bf16: true
166 | batch_sampler: "no_duplicates" # Options: "batch_sampler", "no_duplicates", "group_by_label"
167 | eval_strategy: "epoch"
168 | save_strategy: "epoch"
169 | logging_steps: 10
170 | save_total_limit: 3
171 | load_best_model_at_end: true
172 | report_to: "none"
173 |
174 | # Optional: Push trained model to Hub
175 | upload_config:
176 | push_to_hub: true
177 | hub_private: false
178 | hub_model_id: "AdamLucek/modernbert-embed-quickb"
179 | ```
180 |
181 | ### Alternative Chunker Configurations
182 |
183 | 1. **Fixed Token Chunker**
184 | ```yaml
185 | chunker_config:
186 | output_path: "./output/fixed_token_chunks.json"
187 | chunker: "FixedTokenChunker"
188 | chunker_arguments:
189 | encoding_name: "cl100k_base"
190 | chunk_size: 400
191 | chunk_overlap: 50
192 | length_type: "token"
193 | ```
194 |
195 | 2. **Cluster Semantic Chunker**
196 | ```yaml
197 | chunker_config:
198 | output_path: "./output/semantic_clusters.json"
199 | chunker: "ClusterSemanticChunker"
200 | chunker_arguments:
201 | max_chunk_size: 400 # Max tokens after clustering
202 | min_chunk_size: 50 # Initial split size
203 | length_type: "token"
204 | litellm_config:
205 | embedding_model: "text-embedding-3-large" # Required for semantic clustering
206 | ```
207 |
208 | 3. **LLM Semantic Chunker**
209 | ```yaml
210 | chunker_config:
211 | output_path: "./output/llm_semantic_chunks.json"
212 | chunker: "LLMSemanticChunker"
213 | chunker_arguments:
214 | length_type: "token"
215 | litellm_config:
216 | model: "openai/gpt-4o" # LLM for split decisions
217 | ```
218 |
219 | 4. **Kamradt Modified Semantic Chunker**
220 | ```yaml
221 | chunker_config:
222 | output_path: "./output/kamradt_chunks.json"
223 | chunker: "KamradtModifiedChunker"
224 | chunker_arguments:
225 | avg_chunk_size: 400 # Target average size
226 | min_chunk_size: 50 # Minimum initial split
227 | length_type: "token"
228 | litellm_config:
229 | embedding_model: "text-embedding-3-large" # For similarity calculations
230 | ```
231 | ### Device Support
232 |
233 | QuicKB supports multiple compute devices for model training:
234 |
235 | - **CUDA**: NVIDIA GPUs (default and recommended for best performance)
236 | - **MPS**: Apple Silicon on macOS
237 | - **CPU**: Available on all systems, but significantly slower for training
238 |
239 | To specify your preferred device, use the `device` parameter in your training configuration:
240 |
241 | ```yaml
242 | training:
243 | training_arguments:
244 | device: "cuda" # Options: "cuda", "mps", "cpu"
245 | ```
246 |
247 | **Note**: The default configuration and hyperparameters are optimized for CUDA GPUs. When using MPS or CPU, you may need to install a different PyTorch build or adjust the following parameters for optimal performance:
248 |
249 | - Reduce `batch_size` (e.g., 8-16 for MPS, 4-8 for CPU)
250 | - Reduce `gradient_accumulation_steps` for CPU training
251 | - Set `bf16: false` and `tf32: false` for CPU training
252 | - Use a different optimizer such as `adamw_torch`
253 | - Consider using smaller base models
254 |
255 | Device selection is automatic if you don't specify a device (prioritizing CUDA > MPS > CPU), but explicit configuration is recommended for reproducibility.
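
The automatic selection follows the same CUDA > MPS > CPU priority; a typical detection routine looks roughly like the sketch below (illustrative only, QuicKB's own selection code is not shown in this excerpt):

```python
import torch

def pick_device() -> str:
    """Return the best available device in CUDA > MPS > CPU order (sketch)."""
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"
```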
256 |
257 | ### LiteLLM Integration
258 |
259 | QuicKB uses [LiteLLM](https://docs.litellm.ai/docs/) for flexible model provider integration, allowing you to use any supported LLM or embedding provider for question generation and semantic chunking. This enables both cloud-based and local model deployment.
260 |
261 | The LiteLLM configuration is managed through the `litellm_config` section in both the chunking and question generation configurations:
262 |
263 | ```yaml
264 | litellm_config:
265 | model: "openai/gpt-4o" # LLM model identifier
266 | model_api_base: null # Optional API base URL for LLM
267 | embedding_model: "text-embedding-3-large" # Embedding model identifier
268 | embedding_api_base: null # Optional API base URL for embeddings
269 | ```
270 |
271 | **Using Local Models**:
272 |
273 | 1. Set up an OpenAI API compatible endpoint (e.g., Ollama, vLLM)
274 | 2. Configure the `model_api_base` or `embedding_api_base` in your config
275 | 3. Use the appropriate model identifier format
276 |
277 | Example local setup:
278 |
279 | ```yaml
280 | # For question generation
281 | question_generation:
282 | litellm_config:
283 | model: "local/llama-7b"
284 | model_api_base: "http://localhost:8000"
285 | embedding_model: "local/bge-small"
286 | embedding_api_base: "http://localhost:8000"
287 |
288 | # For semantic chunkers
289 | chunker_config:
290 | chunker: "ClusterSemanticChunker" # or other semantic chunkers
291 | chunker_arguments:
292 | litellm_config:
293 | model: "local/llama-7b"
294 | model_api_base: "http://localhost:8000"
295 | embedding_model: "local/bge-small"
296 | embedding_api_base: "http://localhost:8000"
297 | ```
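
Internally, the semantic chunkers pass these settings straight to LiteLLM's `completion` and `embedding` functions (see `cluster_semantic_chunker.py` and `llm_semantic_chunker.py`), roughly as follows:

```python
from litellm import completion, embedding

litellm_config = {
    "model": "openai/gpt-4o",
    "model_api_base": None,                      # e.g. "http://localhost:8000" for a local server
    "embedding_model": "text-embedding-3-large",
    "embedding_api_base": None,
}

# LLM call (as in llm_semantic_chunker.py)
response = completion(
    model=litellm_config["model"],
    messages=[{"role": "user", "content": "Hello"}],
    api_base=litellm_config["model_api_base"],
)

# Embedding call (as in cluster_semantic_chunker.py / kamradt_modified_chunker.py)
emb = embedding(
    model=litellm_config["embedding_model"],
    input=["some chunk text"],
    api_base=litellm_config["embedding_api_base"],
)
```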
298 |
299 | For more details on setting up local models and supported providers, refer to the [LiteLLM documentation](https://docs.litellm.ai/docs/providers).
300 |
301 | ### Hugging Face Hub Integration
302 |
303 | QuicKB integrates directly with [Hugging Face](https://huggingface.co/) for storing and loading datasets and models. Each pipeline stage can optionally push its outputs to the Hub, and subsequent stages can load data directly from there.
304 |
305 | The Hub integration is configured through `upload_config` sections and dataset source settings:
306 |
307 | ```yaml
308 | # Example Hub configuration for chunking
309 | chunker_config:
310 | upload_config:
311 | push_to_hub: true
312 | hub_private: false
313 | hub_dataset_id: "username/quickb-kb"
314 |
315 | # Loading data from Hub for question generation
316 | question_generation:
317 | input_dataset_config:
318 | dataset_source: "hub"
319 | knowledgebase_dataset_id: "username/quickb-kb"
320 |
321 | upload_config:
322 | push_to_hub: true
323 | hub_private: false
324 | hub_dataset_id: "username/quickb-qa"
325 |
326 | # Loading from Hub for training
327 | training:
328 | train_dataset_config:
329 | dataset_source: "hub"
330 | train_dataset_id: "username/quickb-qa"
331 | knowledgebase_dataset_id: "username/quickb-kb"
332 |
333 | upload_config:
334 | push_to_hub: true
335 | hub_private: false
336 | hub_model_id: "username/modernbert-embed-quickb"
337 | ```
338 | **Authentication**
339 | - Set your Hugging Face token using the `HF_TOKEN` environment variable or specify it in the config using `hub_token`
340 |
341 | ## Output Format
342 |
343 | ### Knowledgebase Dataset
344 | ```json
345 | {
346 | "id": "3fa85f64-5717-4562-b3fc-2c963f66afa6",
347 | "text": "Section 12.1: Termination clauses...",
348 | "source": "docs/contracts/2024/Q1-agreement.txt"
349 | }
350 | ```
351 |
352 | ### Training Dataset
353 | ```json
354 | {
355 | "anchor": "What are the termination notice requirements?",
356 | "positive": "Section 12.1: Either party may terminate...",
357 | "question_id": "a3b8c7d0-e83a-4b5c-b12d-3f7a8d4c9e1b",
358 | "chunk_id": "3fa85f64-5717-4562-b3fc-2c963f66afa6"
359 | }
360 | ```
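
A minimal sketch for loading this file into a Hugging Face `Dataset` for training, assuming `train_data.json` is a JSON array of records like the one above:

```python
import json
from datasets import Dataset

with open("output/train_data.json") as f:
    records = json.load(f)

# (anchor, positive) pairs are what the embedding fine-tuning consumes
train_ds = Dataset.from_list(
    [{"anchor": r["anchor"], "positive": r["positive"]} for r in records]
)
print(train_ds[0])
```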
361 |
362 | ## Environment Variables
363 |
364 | - `OPENAI_API_KEY` (or the API key variable for your configured LiteLLM provider): Required for the LLM and embedding calls used in question generation and semantic chunking
365 | - `HF_TOKEN`: Required for Hugging Face Hub uploads and downloads
366 |
367 | ## Citations
368 |
369 | QuicKB builds upon these foundational works:
370 |
371 | ChromaDB: [Evaluating Chunking Strategies for Retrieval](https://research.trychroma.com/evaluating-chunking)
372 | ```bibtex
373 | @techreport{smith2024evaluating,
374 | title = {Evaluating Chunking Strategies for Retrieval},
375 | author = {Smith, Brandon and Troynikov, Anton},
376 | year = {2024},
377 | month = {July},
378 | institution = {Chroma},
379 | url = {https://research.trychroma.com/evaluating-chunking},
380 | }
381 | ```
382 |
383 | Sentence Transformers
384 | ```bibtex
385 | @inproceedings{reimers-2019-sentence-bert,
386 | title = "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks",
387 | author = "Reimers, Nils and Gurevych, Iryna",
388 | booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing",
389 | month = "11",
390 | year = "2019",
391 | publisher = "Association for Computational Linguistics",
392 | url = "https://arxiv.org/abs/1908.10084",
393 | }
394 | ```
395 |
396 | ## Contributing
397 |
398 | Contributions welcome! Please feel free to submit a Pull Request.
399 |
400 | Todo List:
401 |
402 | - Cleaner handling of config arguments and validation at pipeline stages
403 | - pydantic v2 fields warning
404 | - Custom Model Card (Using base from SBERT currently)
405 |
406 | ## License
407 |
408 | MIT License - See [LICENSE](LICENSE)
409 |
410 | [MIT License](https://opensource.org/licenses/MIT)
411 |
--------------------------------------------------------------------------------
/config.yaml:
--------------------------------------------------------------------------------
1 | # ========== Pipeline Control ==========
2 | pipeline:
3 | from_stage: "CHUNK" # Options: "CHUNK", "GENERATE", "TRAIN"
4 | to_stage: "TRAIN" # Run pipeline from from_stage to to_stage
5 |
6 | # ========== Global Settings ==========
7 | path_to_knowledgebase: "./knowledgebase" # Directory containing source documents
8 | hub_username: "AdamLucek" # Hugging Face username
9 | hub_token: null # Optional: Use HF_TOKEN env variable instead
10 |
11 | # ========== Document Chunking ==========
12 | chunker_config:
13 | output_path: "./output/knowledgebase.json"
14 |
15 | # Chunker Config:
16 | chunker: "RecursiveTokenChunker"
17 |
18 | chunker_arguments:
19 | chunk_size: 400
20 | chunk_overlap: 0
21 | length_type: "character"
22 | separators: ["\n\n", "\n", ".", "?", "!", " ", ""]
23 | keep_separator: true
24 | is_separator_regex: false
25 |
26 | # Optional: Push chunks to Hugging Face Hub
27 | upload_config:
28 | push_to_hub: true
29 | hub_private: false
30 | hub_dataset_id: "AdamLucek/quickb-kb"
31 |
32 | # ========== Question Generation ==========
33 | question_generation:
34 | output_path: "./output/train_data.json"
35 |
36 | # LLM/Embedding Configuration
37 | litellm_config:
38 | model: "openai/gpt-4o-mini"
39 | model_api_base: null # Optional: Custom API endpoint
40 |
41 | embedding_model: "text-embedding-3-large"
42 | embedding_api_base: null # Optional: Custom embedding endpoint
43 |
44 | # Input dataset settings
45 | input_dataset_config:
46 | dataset_source: "local" # Options: "local", "hub"
47 | local_knowledgebase_path: "./output/knowledgebase.json"
48 | # Hub alternative:
49 | # knowledgebase_dataset_id: "AdamLucek/quickb-kb"
50 |
51 | # Performance settings
52 | max_workers: 150 # Parallel question generation
53 | llm_calls_per_minute: null # null = no limit
54 | embedding_calls_per_minute: null # null = no limit
55 |
56 | # Question deduplication
57 | deduplication_enabled: true
58 | dedup_embedding_batch_size: 2048 # Batch size for embedding calculation
59 | similarity_threshold: 0.85 # Semantic Similarity Threshold
60 |
61 | # Optional: Push training data to Hub
62 | upload_config:
63 | push_to_hub: true
64 | hub_private: false
65 | hub_dataset_id: "AdamLucek/quickb-qa"
66 |
67 | # ========== Model Training ==========
68 | training:
69 | # Model configuration
70 | model_settings:
71 | # Base model:
72 | model_id: "nomic-ai/modernbert-embed-base"
73 |
74 | # Matryoshka dimensions (must be descending)
75 | matryoshka_dimensions: [768, 512, 256, 128, 64]
76 | metric_for_best_model: "eval_dim_128_cosine_ndcg@10"
77 | max_seq_length: 1024
78 | trust_remote_code: false
79 |
80 | # Training data configuration
81 | train_dataset_config:
82 | dataset_source: "local" # Options: "local", "hub"
83 | local_train_path: "./output/train_data.json"
84 | local_knowledgebase_path: "./output/knowledgebase.json"
85 | # Hub alternatives:
86 | # train_dataset_id: "AdamLucek/quickb-qa"
87 | # knowledgebase_dataset_id: "AdamLucek/quickb-kb"
88 |
89 | # Training hyperparameters
90 | training_arguments:
91 | output_path: "./output/modernbert_quickb"
92 | device: "cuda" # Options: "cuda", "mps", "cpu"
93 | epochs: 4
94 | batch_size: 32
95 | gradient_accumulation_steps: 16
96 | learning_rate: 2.0e-5
97 | warmup_ratio: 0.1
98 | lr_scheduler_type: "cosine"
99 | optim: "adamw_torch_fused"
100 | tf32: true
101 | bf16: true
102 | batch_sampler: "no_duplicates" # Options: "batch_sampler", "no_duplicates", "group_by_label"
103 | eval_strategy: "epoch"
104 | save_strategy: "epoch"
105 | logging_steps: 10
106 | save_total_limit: 3
107 | load_best_model_at_end: true
108 | report_to: "none"
109 |
110 | # Optional: Push trained model to Hub
111 | upload_config:
112 | push_to_hub: true
113 | hub_private: false
114 | hub_model_id: "AdamLucek/modernbert-embed-quickb"
--------------------------------------------------------------------------------
/examples/chromadb_integration.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "2abfd742-3602-4d97-894b-28f5f1edfd5a",
6 | "metadata": {},
7 | "source": [
8 | "# QuicKB Integration - ChromaDB\n",
9 | "\n",
10 | "This example notebook shows you how to implement your knowledgebase and fine-tuned model with ChromaDB"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": 1,
16 | "id": "6dae01d8-72f9-47a8-af45-a57785beec6b",
17 | "metadata": {},
18 | "outputs": [],
19 | "source": [
20 | "# Install required packages if needed:\n",
21 | "# !pip install chromadb datasets sentence-transformers\n",
22 | "\n",
23 | "import chromadb\n",
24 | "from chromadb.utils import embedding_functions\n",
25 | "from datasets import load_dataset\n",
26 | "\n",
27 | "from sentence_transformers import SentenceTransformer"
28 | ]
29 | },
30 | {
31 | "cell_type": "code",
32 | "execution_count": 2,
33 | "id": "230a0175-978f-4fe0-9aa5-b971f9047a6e",
34 | "metadata": {},
35 | "outputs": [
36 | {
37 | "data": {
38 | "application/vnd.jupyter.widget-view+json": {
39 | "model_id": "baa45dc1fbf046aa95efef6c123a6109",
40 | "version_major": 2,
41 | "version_minor": 0
42 | },
43 | "text/plain": [
44 | "modules.json: 0%| | 0.00/349 [00:00, ?B/s]"
45 | ]
46 | },
47 | "metadata": {},
48 | "output_type": "display_data"
49 | },
50 | {
51 | "data": {
52 | "application/vnd.jupyter.widget-view+json": {
53 | "model_id": "4ad23ee7a7c449b8b8d5654432bae8b8",
54 | "version_major": 2,
55 | "version_minor": 0
56 | },
57 | "text/plain": [
58 | "config_sentence_transformers.json: 0%| | 0.00/205 [00:00, ?B/s]"
59 | ]
60 | },
61 | "metadata": {},
62 | "output_type": "display_data"
63 | },
64 | {
65 | "data": {
66 | "application/vnd.jupyter.widget-view+json": {
67 | "model_id": "5027993452f04034b96e47158ae6cb51",
68 | "version_major": 2,
69 | "version_minor": 0
70 | },
71 | "text/plain": [
72 | "README.md: 0%| | 0.00/30.9k [00:00, ?B/s]"
73 | ]
74 | },
75 | "metadata": {},
76 | "output_type": "display_data"
77 | },
78 | {
79 | "data": {
80 | "application/vnd.jupyter.widget-view+json": {
81 | "model_id": "cdf30bcf43fc4ab195e1a98f741e944a",
82 | "version_major": 2,
83 | "version_minor": 0
84 | },
85 | "text/plain": [
86 | "sentence_bert_config.json: 0%| | 0.00/54.0 [00:00, ?B/s]"
87 | ]
88 | },
89 | "metadata": {},
90 | "output_type": "display_data"
91 | },
92 | {
93 | "data": {
94 | "application/vnd.jupyter.widget-view+json": {
95 | "model_id": "860cd9757831417b8434e6a2972c31cd",
96 | "version_major": 2,
97 | "version_minor": 0
98 | },
99 | "text/plain": [
100 | "config.json: 0%| | 0.00/1.30k [00:00, ?B/s]"
101 | ]
102 | },
103 | "metadata": {},
104 | "output_type": "display_data"
105 | },
106 | {
107 | "data": {
108 | "application/vnd.jupyter.widget-view+json": {
109 | "model_id": "083b6c9b7d6b435b9e236e4de8cd5ad2",
110 | "version_major": 2,
111 | "version_minor": 0
112 | },
113 | "text/plain": [
114 | "model.safetensors: 0%| | 0.00/596M [00:00, ?B/s]"
115 | ]
116 | },
117 | "metadata": {},
118 | "output_type": "display_data"
119 | },
120 | {
121 | "data": {
122 | "application/vnd.jupyter.widget-view+json": {
123 | "model_id": "7f1c066ddf0344b493d1ea07e4e54a59",
124 | "version_major": 2,
125 | "version_minor": 0
126 | },
127 | "text/plain": [
128 | "tokenizer_config.json: 0%| | 0.00/20.8k [00:00, ?B/s]"
129 | ]
130 | },
131 | "metadata": {},
132 | "output_type": "display_data"
133 | },
134 | {
135 | "data": {
136 | "application/vnd.jupyter.widget-view+json": {
137 | "model_id": "1272308a822644639e60b5dd191d000d",
138 | "version_major": 2,
139 | "version_minor": 0
140 | },
141 | "text/plain": [
142 | "tokenizer.json: 0%| | 0.00/3.58M [00:00, ?B/s]"
143 | ]
144 | },
145 | "metadata": {},
146 | "output_type": "display_data"
147 | },
148 | {
149 | "data": {
150 | "application/vnd.jupyter.widget-view+json": {
151 | "model_id": "86c76ce0d6a94821b0972a6a34d79c42",
152 | "version_major": 2,
153 | "version_minor": 0
154 | },
155 | "text/plain": [
156 | "special_tokens_map.json: 0%| | 0.00/694 [00:00, ?B/s]"
157 | ]
158 | },
159 | "metadata": {},
160 | "output_type": "display_data"
161 | },
162 | {
163 | "data": {
164 | "application/vnd.jupyter.widget-view+json": {
165 | "model_id": "1f29f0acabad4633a0b8546dfe2f16c3",
166 | "version_major": 2,
167 | "version_minor": 0
168 | },
169 | "text/plain": [
170 | "1_Pooling%2Fconfig.json: 0%| | 0.00/296 [00:00, ?B/s]"
171 | ]
172 | },
173 | "metadata": {},
174 | "output_type": "display_data"
175 | }
176 | ],
177 | "source": [
178 | "# Load model from Hugging Face\n",
179 | "model_id = \"AdamLucek/modernbert-embed-quickb\" # Replace with your model ID\n",
180 | "model = SentenceTransformer(model_id)"
181 | ]
182 | },
183 | {
184 | "cell_type": "code",
185 | "execution_count": 3,
186 | "id": "80550160-cca5-4aea-a1c3-6909e1daaf45",
187 | "metadata": {},
188 | "outputs": [],
189 | "source": [
190 | "# Create embedding function\n",
191 | "ef = embedding_functions.SentenceTransformerEmbeddingFunction(\n",
192 | " model_name=model_id,\n",
193 | " device=\"cuda\" if model.device.type == \"cuda\" else \"cpu\"\n",
194 | ")"
195 | ]
196 | },
197 | {
198 | "cell_type": "code",
199 | "execution_count": 4,
200 | "id": "e24e6709-b8e6-49d4-82d5-019348738503",
201 | "metadata": {},
202 | "outputs": [],
203 | "source": [
204 | "# Initialize ChromaDB\n",
205 | "client = chromadb.PersistentClient(path=\"./chroma_quickb\")"
206 | ]
207 | },
208 | {
209 | "cell_type": "code",
210 | "execution_count": 5,
211 | "id": "31653903-2adb-4b8f-b8b4-447cabfd71e4",
212 | "metadata": {},
213 | "outputs": [],
214 | "source": [
215 | "# Create collection\n",
216 | "collection = client.get_or_create_collection(\n",
217 | " name=\"quickb_collection\",\n",
218 | " embedding_function=ef\n",
219 | ")"
220 | ]
221 | },
222 | {
223 | "cell_type": "code",
224 | "execution_count": 6,
225 | "id": "7f1f9a1a-d9dd-4981-99bf-f11a52ce01e5",
226 | "metadata": {},
227 | "outputs": [
228 | {
229 | "data": {
230 | "application/vnd.jupyter.widget-view+json": {
231 | "model_id": "c5af1a282aeb47b0ba8ecf48471e3ac2",
232 | "version_major": 2,
233 | "version_minor": 0
234 | },
235 | "text/plain": [
236 | "README.md: 0%| | 0.00/1.13k [00:00, ?B/s]"
237 | ]
238 | },
239 | "metadata": {},
240 | "output_type": "display_data"
241 | }
242 | ],
243 | "source": [
244 | "# Load dataset from Hugging Face\n",
245 | "dataset = load_dataset(\"AdamLucek/quickb-kb\") # Replace with your dataset ID\n",
246 | "chunks = dataset['train']"
247 | ]
248 | },
249 | {
250 | "cell_type": "code",
251 | "execution_count": 7,
252 | "id": "f2b585d9-f408-407c-bf24-43b86f9e1c21",
253 | "metadata": {},
254 | "outputs": [
255 | {
256 | "name": "stderr",
257 | "output_type": "stream",
258 | "text": [
259 | "Compiling the model with `torch.compile` and using a `torch.cpu` device is not supported. Falling back to non-compiled mode.\n"
260 | ]
261 | },
262 | {
263 | "name": "stdout",
264 | "output_type": "stream",
265 | "text": [
266 | "Added 2807 documents\n"
267 | ]
268 | }
269 | ],
270 | "source": [
271 | "# Add documents to ChromaDB\n",
272 | "batch_size = 500\n",
273 | "for i in range(0, len(chunks), batch_size):\n",
274 | " batch = chunks[i:i + batch_size]\n",
275 | " \n",
276 | " collection.add(\n",
277 | " documents=batch['text'],\n",
278 | " metadatas=[{'source': doc} for doc in batch['source']],\n",
279 | " ids=[str(id) for id in batch['id']]\n",
280 | " )\n",
281 | "\n",
282 | "print(f\"Added {collection.count()} documents\")"
283 | ]
284 | },
285 | {
286 | "cell_type": "code",
287 | "execution_count": 8,
288 | "id": "f20e2500-df33-4c9b-bedb-1bd4e71e9dcb",
289 | "metadata": {},
290 | "outputs": [
291 | {
292 | "name": "stdout",
293 | "output_type": "stream",
294 | "text": [
295 | "\n",
296 | "Result 1\n",
297 | "Distance: 0.3171\n",
298 | "Source: Al-Hamim_v_Star_2024-12-26.txt\n",
299 | "Text: self-represented litigants alike have relied on them to draft court filings\n",
300 | "\n",
301 | "Result 2\n",
302 | "Distance: 1.0947\n",
303 | "Source: Al-Hamim_v_Star_2024-12-26.txt\n",
304 | "Text: . Some self-represented litigants, including plaintiff, Alim Al-Hamim, have relied on GAI tools to draft court filings, only to discover later to their chagrin that their filings contained hallucinations. Al-Hamim’s opening brief in this appeal contained hallucinations, as well as bona fide legal citations\n",
305 | "\n",
306 | "Result 3\n",
307 | "Distance: 1.1034\n",
308 | "Source: Al-Hamim_v_Star_2024-12-26.txt\n",
309 | "Text: .) For these reasons, individuals using the current generation of general-purpose GAI tools to assist with legal research and drafting must be aware of the tools’ propensity to generate outputs 18 containing fictitious legal authorities and must ensure that such fictitious citations do not appear in any court filing\n"
310 | ]
311 | }
312 | ],
313 | "source": [
314 | "# Example query\n",
315 | "results = collection.query(\n",
316 | " query_texts=[\"Who has relied on them to draft court filings?\"],\n",
317 | " n_results=3\n",
318 | ")\n",
319 | "\n",
320 | "# Print results\n",
321 | "for i, (doc, distance, metadata) in enumerate(zip(\n",
322 | " results['documents'][0],\n",
323 | " results['distances'][0],\n",
324 | " results['metadatas'][0]\n",
325 | ")):\n",
326 | " print(f\"\\nResult {i+1}\")\n",
327 | " print(f\"Distance: {distance:.4f}\")\n",
328 | " print(f\"Source: {metadata['source']}\")\n",
329 | " print(f\"Text: {doc}\")"
330 | ]
331 | },
332 | {
333 | "cell_type": "code",
334 | "execution_count": null,
335 | "id": "502d6648-dd71-4e05-a0b5-bdf7cda1bde7",
336 | "metadata": {},
337 | "outputs": [],
338 | "source": []
339 | }
340 | ],
341 | "metadata": {
342 | "kernelspec": {
343 | "display_name": "kbembed",
344 | "language": "python",
345 | "name": "kbembed"
346 | },
347 | "language_info": {
348 | "codemirror_mode": {
349 | "name": "ipython",
350 | "version": 3
351 | },
352 | "file_extension": ".py",
353 | "mimetype": "text/x-python",
354 | "name": "python",
355 | "nbconvert_exporter": "python",
356 | "pygments_lexer": "ipython3",
357 | "version": "3.12.2"
358 | }
359 | },
360 | "nbformat": 4,
361 | "nbformat_minor": 5
362 | }
363 |
--------------------------------------------------------------------------------
/qkb_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ALucek/QuicKB/288c62bd5024520388d05ece4a841900abb6d93d/qkb_logo.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # requirements.txt
2 | numpy==2.2.2
3 | tiktoken==0.8.0
4 | backoff==2.2.1
5 | tqdm==4.67.1
6 | pyyaml==6.0.2
7 | attrs==24.3.0
8 | datasets==3.2.0
9 | sentence_transformers==3.4.0
10 | transformers==4.48.1
11 | torch==2.5.1
12 | accelerate==1.3.0
13 | litellm==1.59.6
14 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name="quickb",
5 | version="0.1",
6 | packages=find_packages(where="src"),
7 | package_dir={"": "src"},
8 | install_requires=[r.strip() for r in open("requirements.txt").read().splitlines() if r.strip() and not r.strip().startswith("#")],
9 | entry_points={
10 | 'console_scripts': [
11 | 'quickb=main:main',
12 | ],
13 | },
14 | python_requires='>=3.10',
15 | )
16 |
--------------------------------------------------------------------------------
/src/chunking/__init__.py:
--------------------------------------------------------------------------------
1 | from .registry import ChunkerRegistry
2 | from .fixed_token_chunker import FixedTokenChunker
3 | from .recursive_token_chunker import RecursiveTokenChunker
4 | from .cluster_semantic_chunker import ClusterSemanticChunker
5 | from .llm_semantic_chunker import LLMSemanticChunker
6 | from .kamradt_modified_chunker import KamradtModifiedChunker
7 | from .utils import get_length_function, get_token_count, get_character_count
8 |
9 | __all__ = [
10 | 'ClusterSemanticChunker',
11 | 'LLMSemanticChunker',
12 | 'FixedTokenChunker',
13 | 'RecursiveTokenChunker',
14 | 'KamradtModifiedChunker',
15 | 'ChunkerRegistry',
16 | 'get_length_function',
17 | 'get_token_count',
18 | 'get_character_count'
19 | ]
--------------------------------------------------------------------------------
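
The package exports the registry and all chunker classes. A usage sketch (hypothetical script, assuming the project is installed with `pip install -e .` so the `chunking` package under `src/` is importable, and that `ChunkerRegistry.get_chunker` returns the registered class):

```python
from chunking import ChunkerRegistry

chunker_cls = ChunkerRegistry.get_chunker("RecursiveTokenChunker")
chunker = chunker_cls(
    chunk_size=400,
    chunk_overlap=0,
    length_type="character",
    separators=["\n\n", "\n", ".", "?", "!", " ", ""],
)
chunks = chunker.split_text("Your document text here...")
print(len(chunks), "chunks")
```
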
/src/chunking/base_chunker.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import List, Any
3 | from .utils import get_length_function
4 |
5 | class BaseChunker(ABC):
6 | """Base class for all chunking implementations."""
7 |
8 | def __init__(self, *args, **kwargs):
9 | """
10 | Initialize the chunker with length function configuration.
11 |
12 | Args:
13 | *args: Variable length argument list
14 | **kwargs: Arbitrary keyword arguments
15 | """
16 | self._initialize_length_function(kwargs)
17 |
18 | def _initialize_length_function(self, kwargs):
19 | """
20 | Initialize the length function based on provided configuration.
21 |
22 | Args:
23 | kwargs: Keyword arguments that may contain length function configuration
24 | """
25 | # Get length function type and parameters
26 | length_type = kwargs.pop('length_type', 'token')
27 | encoding_name = kwargs.pop('encoding_name', 'cl100k_base')
28 |
29 | # Set up default length function
30 | self.length_function = get_length_function(
31 | length_type=length_type,
32 | encoding_name=encoding_name
33 | )
34 |
35 | # Override with custom length function if provided
36 | if 'length_function' in kwargs:
37 | self.length_function = kwargs.pop('length_function')
38 |
39 | @abstractmethod
40 | def split_text(self, text: str) -> List[str]:
41 | """
42 | Split input text into chunks.
43 |
44 | Args:
45 | text: The input text to split
46 |
47 | Returns:
48 | List of text chunks
49 | """
50 | pass
--------------------------------------------------------------------------------
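
Subclasses only need to implement `split_text`; a hypothetical custom chunker registered with the registry (not part of QuicKB, shown only to illustrate the base class) would look like:

```python
from typing import List

from chunking.base_chunker import BaseChunker
from chunking.registry import ChunkerRegistry

@ChunkerRegistry.register("ParagraphChunker")  # hypothetical chunker name
class ParagraphChunker(BaseChunker):
    """Toy example: split a document on blank lines."""

    def split_text(self, text: str) -> List[str]:
        return [p.strip() for p in text.split("\n\n") if p.strip()]
```
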
/src/chunking/cluster_semantic_chunker.py:
--------------------------------------------------------------------------------
1 |
2 | # This script is adapted from the chunking_evaluation package, developed by ChromaDB Research.
3 | # Original code can be found at: https://github.com/brandonstarxel/chunking_evaluation/blob/main/chunking_evaluation/chunking/cluster_semantic_chunker.py
4 | # License: MIT License
5 |
6 | from .base_chunker import BaseChunker
7 | from typing import List, Optional
8 | import numpy as np
9 | from litellm import embedding
10 | from .recursive_token_chunker import RecursiveTokenChunker
11 | from .registry import ChunkerRegistry
12 | from .utils import get_length_function
13 | import logging
14 |
15 | logger = logging.getLogger(__name__)
16 |
17 | @ChunkerRegistry.register("ClusterSemanticChunker")
18 | class ClusterSemanticChunker(BaseChunker):
19 | def __init__(
20 | self,
21 | max_chunk_size: int = 400,
22 | min_chunk_size: int = 50,
23 | length_type: str = 'token',
24 | litellm_config: Optional[dict] = None,
25 | **kwargs
26 | ):
27 | super().__init__(length_type=length_type, **kwargs)
28 |
29 | self.length_function = get_length_function(
30 | length_type=length_type,
31 | encoding_name=kwargs.get('encoding_name', 'cl100k_base'),
32 | model_name=kwargs.get('model_name')
33 | )
34 |
35 | self.splitter = RecursiveTokenChunker(
36 | chunk_size=min_chunk_size,
37 | chunk_overlap=0,
38 | length_function=self.length_function,
39 | separators=["\n\n", "\n", ".", "?", "!", " ", ""]
40 | )
41 |
42 | self._litellm_config = litellm_config or {}
43 | self.max_cluster = max_chunk_size // min_chunk_size
44 |
45 | def _get_embeddings(self, texts: List[str]) -> List[List[float]]:
46 | """Universal embedding response handler"""
47 | try:
48 | response = embedding(
49 | model=self._litellm_config.get('embedding_model', 'text-embedding-3-large'),
50 | input=texts,
51 | api_base=self._litellm_config.get('embedding_api_base')
52 | )
53 |
54 | # Handle all possible response formats
55 | if isinstance(response, dict):
56 | items = response.get('data', [])
57 | elif hasattr(response, 'data'):
58 | items = response.data
59 | else:
60 | items = response
61 |
62 | return [
63 | item['embedding'] if isinstance(item, dict) else item.embedding
64 | for item in items
65 | ]
66 |
67 | except Exception as e:
68 | logger.error(f"Embedding failed for batch: {str(e)}")
69 | raise RuntimeError(f"Embedding error: {str(e)}")
70 |
71 | def _calculate_similarity_matrix(self, sentences: List[str]) -> np.ndarray:
72 | """Batch processing with error logging"""
73 | if not sentences:
74 | return np.zeros((0, 0))
75 |
76 | embeddings = []
77 | for batch_idx in range(0, len(sentences), 500):
78 | try:
79 | batch = sentences[batch_idx:batch_idx+500]
80 | embeddings.extend(self._get_embeddings(batch))
81 | except Exception as e:
82 | logger.error(f"Failed processing batch {batch_idx//500}: {str(e)}")
83 | raise
84 |
85 | embedding_matrix = np.array(embeddings)
86 | return np.dot(embedding_matrix, embedding_matrix.T)
87 |
88 | def _optimal_segmentation(self, matrix: np.ndarray) -> List[tuple]:
89 | """Original Chroma algorithm implementation"""
90 | n = matrix.shape[0]
91 | if n < 1:
92 | return []
93 |
94 | # Calculate mean of off-diagonal elements
95 | triu = np.triu_indices(n, k=1)
96 | tril = np.tril_indices(n, k=-1)
97 | mean_value = (matrix[triu].sum() + matrix[tril].sum()) / (n * (n - 1)) if n > 1 else 0
98 |
99 | matrix = matrix - mean_value
100 | np.fill_diagonal(matrix, 0)
101 |
102 | dp = np.zeros(n)
103 | segmentation = np.zeros(n, dtype=int)
104 |
105 | for i in range(n):
106 | for size in range(1, min(self.max_cluster + 1, i + 2)):
107 | start_idx = i - size + 1
108 | if start_idx >= 0:
109 | current_reward = matrix[start_idx:i+1, start_idx:i+1].sum()
110 | if start_idx > 0:
111 | current_reward += dp[start_idx - 1]
112 | if current_reward > dp[i]:
113 | dp[i] = current_reward
114 | segmentation[i] = start_idx
115 |
116 | clusters = []
117 | i = n - 1
118 | while i >= 0:
119 | start = segmentation[i]
120 | clusters.append((start, i))
121 | i = start - 1
122 |
123 | return list(reversed(clusters))
124 |
125 | def split_text(self, text: str) -> List[str]:
126 | """Main processing pipeline"""
127 | if not text.strip():
128 | return []
129 |
130 | # First-stage splitting
131 | sentences = self.splitter.split_text(text)
132 | if len(sentences) < 2:
133 | return [text]
134 |
135 | # Semantic clustering
136 | similarity_matrix = self._calculate_similarity_matrix(sentences)
137 | clusters = self._optimal_segmentation(similarity_matrix)
138 |
139 | return [' '.join(sentences[start:end+1]) for start, end in clusters]
--------------------------------------------------------------------------------
/src/chunking/fixed_token_chunker.py:
--------------------------------------------------------------------------------
1 |
2 | # This script is adapted from the LangChain package, developed by LangChain AI.
3 | # Original code can be found at: https://github.com/langchain-ai/langchain/blob/master/libs/text-splitters/langchain_text_splitters/base.py
4 | # chunking_evaluation modification: https://github.com/brandonstarxel/chunking_evaluation/blob/main/chunking_evaluation/chunking/fixed_token_chunker.py
5 | # License: MIT License
6 |
7 | from abc import ABC, abstractmethod
8 | import logging
9 | from typing import (
10 | AbstractSet,
11 | Any,
12 | Callable,
13 | Collection,
14 | List,
15 | Literal,
16 | Optional,
17 | Type,
18 | TypeVar,
19 | Union,
20 | )
21 | from .base_chunker import BaseChunker
22 | from attr import dataclass
23 | from .registry import ChunkerRegistry
24 |
25 | logger = logging.getLogger(__name__)
26 |
27 | TS = TypeVar("TS", bound="TextSplitter")
28 |
29 | class TextSplitter(BaseChunker, ABC):
30 | """Interface for splitting text into chunks."""
31 |
32 | def __init__(
33 | self,
34 | chunk_size: int = 4000,
35 | chunk_overlap: int = 200,
36 | length_function: Callable[[str], int] = len,
37 | keep_separator: bool = False,
38 | add_start_index: bool = False,
39 | strip_whitespace: bool = True,
40 | **kwargs
41 | ) -> None:
42 | """
43 | Args:
44 | chunk_size: Maximum size of chunks to return
45 | chunk_overlap: Overlap in characters (tokens) between chunks
46 | length_function: Function that measures the length of given chunks
47 | keep_separator: Whether to keep the separator in the chunks
48 | add_start_index: If `True`, includes chunk's start index in metadata
49 | strip_whitespace: If `True`, strips whitespace from the start/end
50 | """
51 | super().__init__(**kwargs)
52 | if chunk_overlap > chunk_size:
53 | raise ValueError(
54 | f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
55 | f"({chunk_size}), should be smaller."
56 | )
57 | self._chunk_size = chunk_size
58 | self._chunk_overlap = chunk_overlap
59 | # If BaseChunker sets self.length_function, we override with the user's length_function param if given
60 | if length_function is not None:
61 | self._length_function = length_function
62 | else:
63 | self._length_function = self.length_function
64 |
65 | self._keep_separator = keep_separator
66 | self._add_start_index = add_start_index
67 | self._strip_whitespace = strip_whitespace
68 |
69 | @abstractmethod
70 | def split_text(self, text: str) -> List[str]:
71 | pass
72 |
73 | def _join_docs(self, docs: List[str], separator: str) -> Optional[str]:
74 | text = separator.join(docs)
75 | if self._strip_whitespace:
76 | text = text.strip()
77 | return text if text != "" else None
78 |
79 | def _merge_splits(self, splits: List[str], separator: str) -> List[str]:
80 | """Combine smaller pieces into medium chunks."""
81 | separator_len = self._length_function(separator)
82 |
83 | docs = []
84 | current_doc: List[str] = []
85 | total = 0
86 | for d in splits:
87 | _len = self._length_function(d)
88 | if total + _len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size:
89 | if total > self._chunk_size:
90 | logger.warning(
91 | f"Created a chunk of size {total}, "
92 | f"which is longer than the specified {self._chunk_size}"
93 | )
94 | if len(current_doc) > 0:
95 | doc = self._join_docs(current_doc, separator)
96 | if doc is not None:
97 | docs.append(doc)
98 | # Pop from the front while chunk > overlap
99 | while total > self._chunk_overlap or (
100 | total + _len + (separator_len if len(current_doc) > 0 else 0)
101 | > self._chunk_size
102 | and total > 0
103 | ):
104 | total -= self._length_function(current_doc[0]) + (
105 | separator_len if len(current_doc) > 1 else 0
106 | )
107 | current_doc = current_doc[1:]
108 | current_doc.append(d)
109 | total += _len + (separator_len if len(current_doc) > 1 else 0)
110 | doc = self._join_docs(current_doc, separator)
111 | if doc is not None:
112 | docs.append(doc)
113 | return docs
114 |
115 | @ChunkerRegistry.register("FixedTokenChunker")
116 | class FixedTokenChunker(TextSplitter):
117 | """Splitting text to tokens using model tokenizer."""
118 |
119 | def __init__(
120 | self,
121 | encoding_name: str = "cl100k_base",
122 | model_name: Optional[str] = None,
123 | chunk_size: int = 4000,
124 | chunk_overlap: int = 200,
125 | allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
126 | disallowed_special: Union[Literal["all"], Collection[str]] = "all",
127 | **kwargs: Any,
128 | ) -> None:
129 | super().__init__(chunk_size=chunk_size, chunk_overlap=chunk_overlap, **kwargs)
130 |
131 | try:
132 | import tiktoken
133 | except ImportError:
134 | raise ImportError(
135 | "Could not import tiktoken python package. "
136 | "This is needed for FixedTokenChunker. "
137 | "Please install it with `pip install tiktoken`."
138 | )
139 |
140 | if model_name is not None:
141 | enc = tiktoken.encoding_for_model(model_name)
142 | else:
143 | enc = tiktoken.get_encoding(encoding_name)
144 | self._tokenizer = enc
145 | self._allowed_special = allowed_special
146 | self._disallowed_special = disallowed_special
147 |
148 | def split_text(self, text: str) -> List[str]:
149 | def _encode(_text: str) -> List[int]:
150 | return self._tokenizer.encode(
151 | _text,
152 | allowed_special=self._allowed_special,
153 | disallowed_special=self._disallowed_special,
154 | )
155 |
156 | tokenizer = Tokenizer(
157 | chunk_overlap=self._chunk_overlap,
158 | tokens_per_chunk=self._chunk_size,
159 | decode=self._tokenizer.decode,
160 | encode=_encode,
161 | )
162 |
163 | return split_text_on_tokens(text=text, tokenizer=tokenizer)
164 |
165 |
166 | @dataclass(frozen=True)
167 | class Tokenizer:
168 | """Tokenizer data class."""
169 | chunk_overlap: int
170 | tokens_per_chunk: int
171 | decode: Callable[[List[int]], str]
172 | encode: Callable[[str], List[int]]
173 |
174 |
175 | def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> List[str]:
176 | """Split incoming text and return chunks using tokenizer."""
177 | splits: List[str] = []
178 | input_ids = tokenizer.encode(text)
179 | start_idx = 0
180 | cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
181 | chunk_ids = input_ids[start_idx:cur_idx]
182 | while start_idx < len(input_ids):
183 | splits.append(tokenizer.decode(chunk_ids))
184 | if cur_idx == len(input_ids):
185 | break
186 | start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
187 | cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
188 | chunk_ids = input_ids[start_idx:cur_idx]
189 | return splits
--------------------------------------------------------------------------------
/src/chunking/kamradt_modified_chunker.py:
--------------------------------------------------------------------------------
1 |
2 | # This script is adapted from the Greg Kamradt's notebook on chunking.
3 | # Original code can be found at: https://github.com/FullStackRetrieval-com/RetrievalTutorials/blob/main/tutorials/LevelsOfTextSplitting/5_Levels_Of_Text_Splitting.ipynb
4 | # chunking_evaluation modification: https://github.com/brandonstarxel/chunking_evaluation/blob/main/chunking_evaluation/chunking/kamradt_modified_chunker.py
5 |
6 | from typing import Optional, List, Any
7 | import numpy as np
8 | from .base_chunker import BaseChunker
9 | from .recursive_token_chunker import RecursiveTokenChunker
10 | from litellm import embedding
11 | from .registry import ChunkerRegistry
12 |
13 | @ChunkerRegistry.register("KamradtModifiedChunker")
14 | class KamradtModifiedChunker(BaseChunker):
15 | def __init__(
16 | self,
17 | avg_chunk_size: int = 400,
18 | min_chunk_size: int = 50,
19 | litellm_config: Optional[dict] = None,
20 | length_type: str = 'token',
21 | **kwargs
22 | ):
23 | super().__init__(length_type=length_type, **kwargs)
24 |
25 | self.splitter = RecursiveTokenChunker(
26 | chunk_size=min_chunk_size,
27 | chunk_overlap=0,
28 | length_function=self.length_function
29 | )
30 |
31 | self._litellm_config = litellm_config or {}
32 | self.avg_chunk_size = avg_chunk_size
33 |
34 | def _get_embeddings(self, texts: List[str]) -> List[List[float]]:
35 | """Get embeddings using LiteLLM."""
36 | response = embedding(
37 | model=self._litellm_config.get('embedding_model', 'text-embedding-3-large'),
38 | input=texts,
39 | api_base=self._litellm_config.get('embedding_api_base')
40 | )
41 | # Extract embeddings from the response data
42 | if hasattr(response, 'data'):
43 | return [item['embedding'] for item in response.data]
44 | elif isinstance(response, dict) and 'data' in response:
45 | return [item['embedding'] for item in response['data']]
46 | raise ValueError(f"Unexpected response format from LiteLLM: {response}")
47 |
48 | def combine_sentences(self, sentences: List[dict], buffer_size: int = 1) -> List[dict]:
49 | for i in range(len(sentences)):
50 | combined = []
51 | for j in range(max(0, i - buffer_size), min(len(sentences), i + buffer_size + 1)):
52 | combined.append(sentences[j]['sentence'])
53 | sentences[i]['combined_sentence'] = ' '.join(combined)
54 | return sentences
55 |
56 | def calculate_cosine_distances(self, sentences: List[dict]):
57 | embeddings = []
58 | for i in range(0, len(sentences), 500):
59 | batch = [s['combined_sentence'] for s in sentences[i:i + 500]]
60 | embeddings.extend(self._get_embeddings(batch))
61 |
62 | embedding_matrix = np.array(embeddings)
63 | norms = np.linalg.norm(embedding_matrix, axis=1, keepdims=True)
64 | embedding_matrix /= norms
65 | similarity_matrix = np.dot(embedding_matrix, embedding_matrix.T)
66 |
67 | distances = []
68 | for i in range(len(sentences) - 1):
69 | distance = 1 - similarity_matrix[i, i + 1]
70 | distances.append(distance)
71 | sentences[i]['distance_to_next'] = distance
72 | return distances, sentences
73 |
74 | def split_text(self, text: str) -> List[str]:
75 | s_list = self.splitter.split_text(text)
76 | sentences = [{'sentence': s, 'index': i} for i, s in enumerate(s_list)]
77 | if not sentences:
78 | return []
79 |
80 | sentences = self.combine_sentences(sentences, 3)
81 | distances, sentences = self.calculate_cosine_distances(sentences)
82 |
83 | total_tokens = sum(self.length_function(s['sentence']) for s in sentences)
84 | target_splits = total_tokens // self.avg_chunk_size if self.avg_chunk_size else 1
85 | distances = np.array(distances)
86 |
87 | low, high = 0.0, 1.0
88 | while high - low > 1e-6:
89 | mid = (low + high) / 2
90 | if (distances > mid).sum() > target_splits:
91 | low = mid
92 | else:
93 | high = mid
94 |
95 | split_indices = [i for i, d in enumerate(distances) if d > high]
96 | chunks = []
97 | start = 0
98 |
99 | for idx in split_indices:
100 | chunks.append(' '.join(s['sentence'] for s in sentences[start:idx + 1]))
101 | start = idx + 1
102 |
103 | if start < len(sentences):
104 | chunks.append(' '.join(s['sentence'] for s in sentences[start:]))
105 |
106 | return chunks
--------------------------------------------------------------------------------
/src/chunking/llm_semantic_chunker.py:
--------------------------------------------------------------------------------
1 |
2 | # This script is adapted from the chunking_evaluation package, developed by ChromaDB Research.
3 | # Original code can be found at: https://github.com/brandonstarxel/chunking_evaluation/blob/main/chunking_evaluation/chunking/llm_semantic_chunker.py
4 | # License: MIT License
5 |
6 | from .base_chunker import BaseChunker
7 | from .recursive_token_chunker import RecursiveTokenChunker
8 | import backoff
9 | from tqdm import tqdm
10 | from typing import List, Optional
11 | import re
12 | from .registry import ChunkerRegistry
13 | from litellm import completion
14 |
15 | @ChunkerRegistry.register("LLMSemanticChunker")
16 | class LLMSemanticChunker(BaseChunker):
17 | def __init__(
18 | self,
19 | litellm_config: Optional[dict] = None,
20 | length_type: str = 'token',
21 | **kwargs
22 | ):
23 | super().__init__(length_type=length_type, **kwargs)
24 |
25 | self._litellm_config = litellm_config or {}
26 |
27 | # Initialize the base splitter for initial text splitting
28 | self.splitter = RecursiveTokenChunker(
29 | chunk_size=50,
30 | chunk_overlap=0,
31 | length_function=self.length_function
32 | )
33 |
34 | def get_prompt(self, chunked_input, current_chunk=0, invalid_response=None):
35 | """Generate the prompt for the LLM."""
36 | base_prompt = (
37 | "You are an assistant specialized in splitting text into thematically consistent sections. "
38 | "The text has been divided into chunks, each marked with <|start_chunk_X|> and <|end_chunk_X|> tags, where X is the chunk number. "
39 | "Your task is to identify the points where splits should occur, such that consecutive chunks of similar themes stay together. "
40 | "Respond with a list of chunk IDs where you believe a split should be made. For example, if chunks 1 and 2 belong together but chunk 3 starts a new topic, you would suggest a split after chunk 2. THE CHUNKS MUST BE IN ASCENDING ORDER."
41 | "Your response should be in the form: 'split_after: 3, 5'."
42 | )
43 |
44 | user_content = (
45 | f"CHUNKED_TEXT: {chunked_input}\n\n"
46 | f"Respond with split points (ascending, ≥{current_chunk}). "
47 | "Respond only with the IDs of the chunks where you believe a split should occur. YOU MUST RESPOND WITH AT LEAST ONE SPLIT. THESE SPLITS MUST BE IN ASCENDING ORDER"
48 | )
49 | if invalid_response:
50 | user_content += (
51 | f"\n\\Previous invalid response: {invalid_response}. "
52 | "DO NOT REPEAT THIS ARRAY OF NUMBERS. Please try again."
53 | )
54 |
55 | return [
56 | {"role": "system", "content": base_prompt},
57 | {"role": "user", "content": user_content}
58 | ]
59 |
60 | @backoff.on_exception(backoff.expo, Exception, max_tries=3)
61 | def _get_llm_response(self, context: str, current: int) -> str:
62 | """Get chunking suggestions from LLM using LiteLLM."""
63 | try:
64 | response = completion(
65 | model=self._litellm_config.get('model', 'openai/gpt-4o'),
66 | messages=self.get_prompt(context, current),
67 | temperature=0.2,
68 | max_tokens=200,
69 | api_base=self._litellm_config.get('model_api_base')
70 | )
71 | return response.choices[0].message.content
72 | except Exception as e:
73 | print(f"LLM API error: {str(e)}")
74 | return ""
75 |
76 | def _parse_response(self, response: str, current_chunk: int) -> List[int]:
77 | numbers = []
78 | if 'split_after:' in response:
79 | numbers = list(map(int, re.findall(r'\d+', response.split('split_after:')[1])))
80 | return sorted(n for n in numbers if n > current_chunk) # Ensure 1-based > current 0-based
81 |
82 | def _merge_chunks(self, chunks: List[str], indices: List[int]) -> List[str]:
83 | """Merge chunks based on split indices (indices are 1-based from LLM)"""
84 | merged = []
85 | current = []
86 | # Convert to 0-based indices and sort
87 | split_points = sorted([i-1 for i in indices if i > 0])
88 |
89 | for i, chunk in enumerate(chunks):
90 | current.append(chunk)
91 | if i in split_points:
92 | merged.append(" ".join(current).strip())
93 | current = []
94 | if current:
95 | merged.append(" ".join(current).strip())
96 | return merged
97 |
98 | def split_text(self, text: str) -> List[str]:
99 | """Split input text into coherent chunks using LLM guidance."""
100 | chunks = self.splitter.split_text(text)
101 | split_indices = []
102 | current_chunk = 0
103 |
104 | with tqdm(total=len(chunks), desc="Processing chunks") as pbar:
105 | while current_chunk < len(chunks) - 4:
106 | context_window = []
107 | token_count = 0
108 |
109 | for i in range(current_chunk, len(chunks)):
110 | token_count += self.length_function(chunks[i])
111 | if token_count > 800:
112 | break
113 | context_window.append(f"<|start_chunk_{i+1}|>{chunks[i]}<|end_chunk_{i+1}|>")
114 |
115 | response = self._get_llm_response("\n".join(context_window), current_chunk)
116 | numbers = self._parse_response(response, current_chunk) # FIXED: Added current_chunk argument
117 |
118 | if numbers:
119 | split_indices.extend(numbers)
120 | current_chunk = numbers[-1]
121 | pbar.update(current_chunk - pbar.n)
122 | else:
123 | break
124 |
125 | return self._merge_chunks(chunks, split_indices)
--------------------------------------------------------------------------------
/src/chunking/recursive_token_chunker.py:
--------------------------------------------------------------------------------
1 |
2 | # This script is adapted from the LangChain package, developed by LangChain AI.
3 | # Original code can be found at: https://github.com/langchain-ai/langchain/blob/master/libs/text-splitters/langchain_text_splitters/character.py
4 | # chunking_evaluation modification: https://github.com/brandonstarxel/chunking_evaluation/blob/main/chunking_evaluation/chunking/recursive_token_chunker.py
5 | # License: MIT License
6 |
7 | from typing import Any, List, Optional
8 | from .utils import Language
9 | from .fixed_token_chunker import TextSplitter
10 | import re
11 | from .registry import ChunkerRegistry
12 |
13 | def _split_text_with_regex(text: str, separator: str, keep_separator: bool) -> List[str]:
14 | if separator:
15 | if keep_separator:
16 | # The parentheses in the pattern keep the delimiters in the result.
17 | _splits = re.split(f"({separator})", text)
18 | splits = [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)]
19 | if len(_splits) % 2 == 0:
20 | splits += _splits[-1:]
21 | splits = [_splits[0]] + splits
22 | else:
23 | splits = re.split(separator, text)
24 | else:
25 | splits = list(text)
26 | return [s for s in splits if s != ""]
27 |
28 | @ChunkerRegistry.register("RecursiveTokenChunker")
29 | class RecursiveTokenChunker(TextSplitter):
30 | """Splitting text by recursively looking at characters / tokens."""
31 |
32 | def __init__(
33 | self,
34 | chunk_size: int = 4000,
35 | chunk_overlap: int = 200,
36 | separators: Optional[List[str]] = None,
37 | keep_separator: bool = True,
38 | is_separator_regex: bool = False,
39 | length_type: str = 'token',
40 | **kwargs: Any,
41 | ) -> None:
42 | super().__init__(
43 | chunk_size=chunk_size,
44 | chunk_overlap=chunk_overlap,
45 | keep_separator=keep_separator,
46 | length_type=length_type,
47 | **kwargs
48 | )
49 | self._separators = separators or ["\n\n", "\n", ".", "?", "!", " ", ""]
50 | self._is_separator_regex = is_separator_regex
51 |
52 | def _split_text(self, text: str, separators: List[str]) -> List[str]:
53 | final_chunks = []
54 | separator = separators[-1]
55 | new_separators = []
56 |
57 | for i, _s in enumerate(separators):
58 | _separator = _s if self._is_separator_regex else re.escape(_s)
59 | if _s == "":
60 | separator = _s
61 | break
62 | if re.search(_separator, text):
63 | separator = _s
64 | new_separators = separators[i + 1:]
65 | break
66 |
67 | _separator = separator if self._is_separator_regex else re.escape(separator)
68 | splits = _split_text_with_regex(text, _separator, self._keep_separator)
69 |
70 | _good_splits = []
71 | actual_separator = "" if self._keep_separator else separator
72 |
73 | for s in splits:
74 | if self.length_function(s) < self._chunk_size:
75 | _good_splits.append(s)
76 | else:
77 | if _good_splits:
78 | merged_text = self._merge_splits(_good_splits, actual_separator)
79 | final_chunks.extend(merged_text)
80 | _good_splits = []
81 | if not new_separators:
82 | final_chunks.append(s)
83 | else:
84 | other_info = self._split_text(s, new_separators)
85 | final_chunks.extend(other_info)
86 | if _good_splits:
87 | merged_text = self._merge_splits(_good_splits, actual_separator)
88 | final_chunks.extend(merged_text)
89 |
90 | return final_chunks
91 |
92 | def split_text(self, text: str) -> List[str]:
93 | return self._split_text(text, self._separators)
94 |
95 | @staticmethod
96 | def get_separators_for_language(language: Language) -> List[str]:
97 | if language == Language.PYTHON:
98 | return [
99 | "\nclass ",
100 | "\ndef ",
101 | "\n\tdef ",
102 | "\n\n",
103 | "\n",
104 | " ",
105 | "",
106 | ]
107 | raise ValueError(
108 | f"Language {language} is not supported! "
109 | f"Please choose from {list(Language)}"
110 | )
--------------------------------------------------------------------------------
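A quick usage sketch for the chunker above (assumes the project's src/ directory is on PYTHONPATH; the input path is a placeholder):

from chunking.recursive_token_chunker import RecursiveTokenChunker

chunker = RecursiveTokenChunker(
    chunk_size=400,            # measured by the configured length function
    chunk_overlap=0,
    length_type="character"    # character counting avoids a tiktoken download for this sketch
)
with open("docs/example.txt", encoding="utf-8") as f:
    chunks = chunker.split_text(f.read())
print(f"{len(chunks)} chunks, first: {chunks[0][:80]!r}")
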
/src/chunking/registry.py:
--------------------------------------------------------------------------------
1 | class ChunkerRegistry:
2 | _chunkers = {}
3 |
4 | @classmethod
5 | def register(cls, name: str):
6 | def decorator(chunker_class):
7 | cls._chunkers[name] = chunker_class
8 | return chunker_class
9 | return decorator
10 |
11 | @classmethod
12 | def get_chunker(cls, name: str):
13 | if name not in cls._chunkers:
14 | available = list(cls._chunkers.keys())
15 | raise ValueError(f"Unknown chunker: {name}. Available chunkers: {available}")
16 | return cls._chunkers[name]
--------------------------------------------------------------------------------
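This registry is how main.py resolves the `chunker:` name from config.yaml into a class. A small sketch of the register/lookup pattern with a hypothetical chunker (assumes src/ is on PYTHONPATH):

from chunking.registry import ChunkerRegistry

@ChunkerRegistry.register("EchoChunker")  # hypothetical chunker, for illustration only
class EchoChunker:
    def __init__(self, **kwargs):
        self.kwargs = kwargs

    def split_text(self, text: str):
        return [text]  # returns the whole document as a single chunk

chunker_cls = ChunkerRegistry.get_chunker("EchoChunker")
print(chunker_cls().split_text("hello world"))  # -> ['hello world']
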
/src/chunking/utils.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | import tiktoken
3 | from typing import Callable, Optional
4 |
5 | class Language(str, Enum):
6 | """Supported languages for language-specific chunking."""
7 | CPP = "cpp"
8 | GO = "go"
9 | JAVA = "java"
10 | KOTLIN = "kotlin"
11 | JS = "js"
12 | TS = "ts"
13 | PHP = "php"
14 | PROTO = "proto"
15 | PYTHON = "python"
16 | RST = "rst"
17 | RUBY = "ruby"
18 | RUST = "rust"
19 | SCALA = "scala"
20 | SWIFT = "swift"
21 | MARKDOWN = "markdown"
22 | LATEX = "latex"
23 | HTML = "html"
24 | SOL = "sol"
25 | CSHARP = "csharp"
26 | COBOL = "cobol"
27 | C = "c"
28 | LUA = "lua"
29 | PERL = "perl"
30 |
31 | def get_token_count(
32 | string: str,
33 | encoding_name: str = "cl100k_base",
34 | model_name: Optional[str] = None,
35 | **kwargs
36 | ) -> int:
37 | """
38 | Count the number of tokens in a string using tiktoken.
39 |
40 | Args:
41 | string: The text to count tokens for
42 | encoding_name: The name of the tiktoken encoding to use
43 | model_name: Optional model name to use specific encoding
44 | **kwargs: Additional arguments passed to tiktoken encoder
45 |
46 | Returns:
47 | Number of tokens in the string
48 | """
49 | try:
50 | if model_name:
51 | enc = tiktoken.encoding_for_model(model_name)
52 | else:
53 | enc = tiktoken.get_encoding(encoding_name)
54 |
55 | allowed_special = kwargs.get('allowed_special', set())
56 | disallowed_special = kwargs.get('disallowed_special', 'all')
57 |
58 | return len(enc.encode(
59 | string,
60 | allowed_special=allowed_special,
61 | disallowed_special=disallowed_special
62 | ))
63 | except Exception as e:
64 | raise ValueError(f"Error counting tokens: {str(e)}")
65 |
66 | def get_character_count(text: str) -> int:
67 | """
68 | Count the number of characters in a string.
69 |
70 | Args:
71 | text: The text to count characters for
72 |
73 | Returns:
74 | Number of characters in the string
75 | """
76 | return len(text)
77 |
78 | def get_length_function(length_type: str = "token", **kwargs) -> Callable[[str], int]:
79 | """
80 | Get a length function based on the specified type.
81 |
82 | Args:
83 | length_type: Type of length function ('token' or 'character')
84 | **kwargs: Additional arguments passed to token counter
85 |
86 | Returns:
87 | A callable that takes a string and returns its length
88 | """
89 | if length_type == "token":
90 | return lambda x: get_token_count(x, **kwargs)
91 | elif length_type == "character":
92 | return get_character_count
93 | else:
94 | raise ValueError(
95 | f"Unknown length type: {length_type}. "
96 | "Choose 'token' or 'character'"
97 | )
--------------------------------------------------------------------------------
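The length functions above are what the chunkers use to measure chunk size. A short sketch (token counting downloads the tiktoken encoding on first use):

from chunking.utils import get_length_function

count_chars = get_length_function("character")
count_tokens = get_length_function("token", encoding_name="cl100k_base")

sample = "QuicKB chunks documents before generating synthetic training data."
print(count_chars(sample), count_tokens(sample))
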
/src/hub_upload/card_generator.py:
--------------------------------------------------------------------------------
1 | import json
2 | from pathlib import Path
3 | import os
4 | from typing import Dict, Any, Optional, List
5 |
6 | class DatasetCardGenerator:
7 | """Handles dataset card generation for quickb datasets."""
8 |
9 | def __init__(self, template_path: str = "src/hub_upload/template.md"):
10 | """Initialize with path to card template."""
11 | self.template_path = template_path
12 | with open(template_path, 'r', encoding='utf-8') as f:
13 | self.template = f.read()
14 |
15 | def _get_size_category(self, num_entries: Optional[int]) -> str:
16 | """Determine the size category based on number of entries."""
17 | if not num_entries:
18 | return "unknown"
19 | elif num_entries < 1000:
20 | return "n<1K"
21 | elif num_entries < 10000:
22 |             return "1K<n<10K"
23 |         elif num_entries < 100000:
24 |             return "10K<n<100K"
25 |         elif num_entries < 1000000:
26 |             return "100K<n<1M"
27 |         else:
28 |             return "n>1M"
29 |
30 | def _format_chunker_params(self, params: Dict[str, Any]) -> str:
31 | """Simple Markdown-safe parameter formatting"""
32 | return "\n ".join(
33 | f"- **{key}**: `{repr(value)}`"
34 | for key, value in params.items()
35 | if value is not None and not key.startswith('_')
36 | )
37 |
38 | def _format_question_generation(self,
39 | model_name: str,
40 | similarity_threshold: float,
41 | num_questions: int,
42 | num_deduped: int
43 | ) -> str:
44 | """Format question generation section if enabled."""
45 | return f"""### Question Generation
46 | - **Model**: {model_name}
47 | - **Deduplication threshold**: {similarity_threshold}
48 | - **Results**:
49 | - Total questions generated: {num_questions}
50 | - Questions after deduplication: {num_deduped}"""
51 |
52 | def _format_dataset_structure(self, has_chunks: bool, has_questions: bool) -> str:
53 | """Format dataset structure section based on available configurations."""
54 | if has_questions:
55 | return """### Dataset Structure
56 | - `anchor`: The generated question
57 | - `positive`: The text chunk containing the answer
58 | - `question_id`: Unique identifier for the question
59 | - `chunk_id`: Reference to the source chunk"""
60 | else:
61 | return """### Dataset Structure
62 | This dataset contains the following fields:
63 |
64 | - `text`: The content of each text chunk
65 | - `source`: The source file path for the chunk
66 | - `id`: Unique identifier for each chunk"""
67 |
68 | def _format_chunking_section(self,
69 | chunker_name: str,
70 | chunker_params: Dict[str, Any],
71 | num_chunks: int,
72 | avg_chunk_size: float,
73 | num_files: int
74 | ) -> str:
75 | """Format chunking section if enabled."""
76 | return f"""### Chunking Configuration
77 | - **Chunker**: {chunker_name}
78 | - **Parameters**:
79 | {self._format_chunker_params(chunker_params)}
80 |
81 | ### Dataset Statistics
82 | - Total chunks: {num_chunks:,}
83 | - Average chunk size: {avg_chunk_size:.1f} words
84 | - Source files: {num_files}"""
85 |
86 | def generate_card(self,
87 | dataset_name: str,
88 | chunker_name: Optional[str] = None,
89 | chunker_params: Optional[Dict[str, Any]] = None,
90 | num_chunks: Optional[int] = None,
91 | avg_chunk_size: Optional[float] = None,
92 | num_files: Optional[int] = None,
93 | question_generation: Optional[Dict[str, Any]] = None
94 | ) -> str:
95 | """Generate a dataset card with the provided information."""
96 |
97 |         # Determine size category from the number of chunks
98 | size_category = self._get_size_category(num_chunks)
99 |
100 | # Determine components and tags
101 | has_chunks = all(x is not None for x in [chunker_name, chunker_params, num_chunks, avg_chunk_size, num_files])
102 | has_questions = question_generation is not None
103 |
104 | gen_tag = "\n- question-generation" if has_questions else ""
105 |
106 | # Format sections based on available data
107 | chunking_section = ""
108 | if has_chunks:
109 | chunking_section = self._format_chunking_section(
110 | chunker_name=chunker_name,
111 | chunker_params=chunker_params,
112 | num_chunks=num_chunks,
113 | avg_chunk_size=avg_chunk_size,
114 | num_files=num_files
115 | )
116 |
117 | qg_section = ""
118 | if has_questions:
119 | qg_section = self._format_question_generation(
120 | model_name=question_generation["model_name"],
121 | similarity_threshold=question_generation["similarity_threshold"],
122 | num_questions=question_generation["num_questions"],
123 | num_deduped=question_generation["num_deduped"]
124 | )
125 |
126 | # Generate dataset structure section
127 | dataset_structure = self._format_dataset_structure(has_chunks, has_questions)
128 |
129 | # Fill template
130 | return self.template.format(
131 | dataset_name=dataset_name,
132 | gen_tag=gen_tag,
133 | size_category=size_category,
134 | chunker_section=chunking_section,
135 | question_generation=qg_section,
136 | dataset_structure=dataset_structure
137 | )
--------------------------------------------------------------------------------
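A short sketch of generating a card offline (run from the repository root so the default template path resolves; every value below is a placeholder):

from hub_upload.card_generator import DatasetCardGenerator

generator = DatasetCardGenerator()  # reads src/hub_upload/template.md
card = generator.generate_card(
    dataset_name="my-knowledgebase",
    chunker_name="RecursiveTokenChunker",
    chunker_params={"chunk_size": 400, "chunk_overlap": 0},
    num_chunks=1250,
    avg_chunk_size=87.3,
    num_files=12,
    question_generation=None,   # omit the question-generation section
)
print(card[:200])
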
/src/hub_upload/dataset_pusher.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import os
4 | from pathlib import Path
5 | from typing import Optional, Dict, Any
6 | from datasets import Dataset
7 | from huggingface_hub import create_repo, upload_file, repo_exists
8 | from .card_generator import DatasetCardGenerator
9 |
10 | logger = logging.getLogger(__name__)
11 |
12 | class DatasetPusher:
13 | """Handles uploading datasets to the Hugging Face Hub."""
14 |
15 | def __init__(self, username: str, token: Optional[str] = None):
16 | """Initialize with HF credentials."""
17 | self.username = username
18 | self.token = token or os.getenv("HF_TOKEN")
19 |
20 | if not self.token:
21 | raise ValueError("No Hugging Face token provided or found in environment")
22 |
23 | self.card_generator = DatasetCardGenerator()
24 |
25 | def _repository_exists(self, repo_id: str) -> bool:
26 | """Check if repository already exists."""
27 | try:
28 | return repo_exists(repo_id, repo_type="dataset")
29 | except Exception as e:
30 | logger.error(f"Error checking repository: {str(e)}")
31 | return False
32 |
33 | def _load_json_file(self, file_path: str) -> list:
34 | """Load JSON file and ensure it's a list of records."""
35 | try:
36 | with open(file_path, 'r', encoding='utf-8') as f:
37 | data = json.load(f)
38 | if not isinstance(data, list):
39 | raise ValueError(f"Expected JSON array in {file_path}")
40 | return data
41 | except Exception as e:
42 | logger.error(f"Error loading {file_path}: {str(e)}")
43 | raise
44 |
45 | def _calculate_dataset_stats(self, knowledgebase_data: list) -> Dict[str, Any]:
46 | """Calculate statistics for dataset card."""
47 | text_lengths = [len(item['text'].split()) for item in knowledgebase_data]
48 | unique_sources = len(set(item['source'] for item in knowledgebase_data))
49 |
50 | return {
51 | 'num_chunks': len(knowledgebase_data),
52 | 'avg_chunk_size': sum(text_lengths) / len(text_lengths) if text_lengths else 0,
53 | 'num_files': unique_sources
54 | }
55 |
56 | def push_dataset(
57 | self,
58 | hub_dataset_id: str,
59 | knowledgebase_path: Optional[str] = None,
60 | chunker_info: Optional[Dict[str, Any]] = None,
61 | train_path: Optional[str] = None,
62 | question_gen_info: Optional[Dict[str, Any]] = None,
63 | private: bool = True
64 | ) -> None:
65 | """Push dataset to the Hugging Face Hub, overwriting existing data."""
66 | try:
67 | # Create repository if it doesn't exist
68 | if not self._repository_exists(hub_dataset_id):
69 | create_repo(
70 | hub_dataset_id,
71 | repo_type="dataset",
72 | private=private,
73 | token=self.token
74 | )
75 | logger.info(f"Created new dataset repository: {hub_dataset_id}")
76 | else:
77 | logger.info(f"Dataset repository exists: {hub_dataset_id}")
78 |
79 | # Load and push knowledgebase if provided
80 | kb_data = None
81 | if knowledgebase_path:
82 | kb_data = self._load_json_file(knowledgebase_path)
83 | kb_dataset = Dataset.from_list(kb_data)
84 | kb_dataset.push_to_hub(
85 | hub_dataset_id,
86 | token=self.token,
87 | private=private
88 | )
89 | logger.info(f"Pushed knowledgebase to {hub_dataset_id}")
90 |
91 | # Load and push training data if provided
92 | if train_path:
93 | train_data = self._load_json_file(train_path)
94 | train_dataset = Dataset.from_list(train_data)
95 | train_dataset.push_to_hub(
96 | hub_dataset_id,
97 | token=self.token,
98 | private=private
99 | )
100 | logger.info(f"Pushed training data to {hub_dataset_id}")
101 |
102 | # Generate and upload README
103 | repository_name = hub_dataset_id.split('/')[-1]
104 | card_content = self.card_generator.generate_card(
105 | dataset_name=repository_name,
106 | chunker_name=chunker_info.get('chunker_name') if chunker_info else None,
107 | chunker_params=chunker_info.get('chunker_params') if chunker_info else None,
108 | num_chunks=self._calculate_dataset_stats(kb_data)['num_chunks'] if kb_data else None,
109 | avg_chunk_size=self._calculate_dataset_stats(kb_data)['avg_chunk_size'] if kb_data else None,
110 | num_files=self._calculate_dataset_stats(kb_data)['num_files'] if kb_data else None,
111 | question_generation=question_gen_info
112 | )
113 |
114 | upload_file(
115 | path_or_fileobj=card_content.encode('utf-8'),
116 | path_in_repo="README.md",
117 | repo_id=hub_dataset_id,
118 | repo_type="dataset",
119 | token=self.token
120 | )
121 | logger.info(f"Uploaded README.md to {hub_dataset_id}")
122 |
123 | except Exception as e:
124 | logger.error(f"Error pushing dataset to Hub: {str(e)}")
125 | raise
--------------------------------------------------------------------------------
/src/hub_upload/template.md:
--------------------------------------------------------------------------------
1 | ---
2 | language:
3 | - en
4 | pretty_name: "{dataset_name}"
5 | tags:
6 | - quickb
7 | - text-chunking{gen_tag}
8 | - {size_category}
9 | task_categories:
10 | - text-generation
11 | - text-retrieval
12 | task_ids:
13 | - document-retrieval
14 | library_name: quickb
15 | ---
16 |
17 | # {dataset_name}
18 |
19 | Generated using [QuicKB](https://github.com/AdamLucek/quickb), a tool developed by [Adam Lucek](https://huggingface.co/AdamLucek).
20 |
21 | QuicKB optimizes document retrieval by creating fine-tuned knowledge bases through an end-to-end pipeline that handles document chunking, training data generation, and embedding model optimization.
22 |
23 | {chunker_section}
24 |
25 | {question_generation}
26 |
27 | {dataset_structure}
--------------------------------------------------------------------------------
/src/main.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import json
4 | import uuid
5 | from enum import Enum, auto
6 | from pathlib import Path
7 | from typing import List, Dict, Optional, Any, Literal
8 |
9 | import yaml
10 | from pydantic import BaseModel, ConfigDict, field_validator
11 | from datasets import load_dataset, Dataset
12 |
13 | from chunking import ChunkerRegistry
14 | from hub_upload.dataset_pusher import DatasetPusher
15 | from synth_dataset.question_generator import QuestionGenerator
16 | from training.train import main as train_main
17 |
18 | logger = logging.getLogger(__name__)
19 | logging.getLogger("openai").setLevel(logging.WARNING)
20 | logging.getLogger("httpcore").setLevel(logging.WARNING)
21 | logging.getLogger("httpx").setLevel(logging.WARNING)
22 |
23 | class PipelineStage(Enum):
24 | CHUNK = auto()
25 | GENERATE = auto()
26 | TRAIN = auto()
27 |
28 | class BatchSamplers(str, Enum):
29 | BATCH_SAMPLER = "batch_sampler"
30 | NO_DUPLICATES = "no_duplicates"
31 | GROUP_BY_LABEL = "group_by_label"
32 |
33 | class LiteLLMConfig(BaseModel):
34 | """Configuration for LiteLLM model and embedding settings."""
35 | model_config = ConfigDict(extra='forbid', validate_default=True)
36 |
37 | model: Optional[str] = "openai/gpt-4o"
38 | model_api_base: Optional[str] = None
39 | embedding_model: Optional[str] = "openai/text-embedding-3-large"
40 | embedding_api_base: Optional[str] = None
41 |
42 | class QuestionGenInputConfig(BaseModel):
43 | """Configuration for question generation input dataset."""
44 | model_config = ConfigDict(extra='forbid', validate_default=True)
45 |
46 | dataset_source: Literal["local", "hub"] = "local"
47 | knowledgebase_dataset_id: Optional[str] = None
48 | local_knowledgebase_path: Optional[str] = None
49 |
50 | class TrainInputConfig(BaseModel):
51 | """Configuration for training input datasets."""
52 | model_config = ConfigDict(extra='forbid', validate_default=True)
53 |
54 | dataset_source: Literal["local", "hub"] = "local"
55 | train_dataset_id: Optional[str] = None
56 | knowledgebase_dataset_id: Optional[str] = None
57 | local_train_path: Optional[str] = None
58 | local_knowledgebase_path: Optional[str] = None
59 |
60 | class UploadConfig(BaseModel):
61 | """Configuration for Hugging Face Hub uploads."""
62 | model_config = ConfigDict(extra='forbid', validate_default=True)
63 |
64 | push_to_hub: bool = False
65 | hub_private: bool = False
66 | hub_dataset_id: Optional[str] = None
67 | hub_model_id: Optional[str] = None
68 |
69 | class ChunkerConfig(BaseModel):
70 | """Configuration for text chunking."""
71 | model_config = ConfigDict(extra='forbid', validate_default=True)
72 |
73 | chunker: str
74 | chunker_arguments: Dict[str, Any]
75 | output_path: str
76 | upload_config: Optional[UploadConfig] = None
77 |
78 | @property
79 | def litellm_config(self) -> Optional[LiteLLMConfig]:
80 |         """Extract LiteLLMConfig from chunker_arguments if present."""
81 |         if "litellm_config" in self.chunker_arguments:
82 | return LiteLLMConfig.model_validate(self.chunker_arguments["litellm_config"])
83 | return None
84 |
85 | class QuestionGeneratorConfig(BaseModel):
86 | """Configuration for question generation."""
87 | model_config = ConfigDict(extra='forbid', validate_default=True)
88 |
89 | output_path: str
90 | input_dataset_config: QuestionGenInputConfig
91 | litellm_config: Optional[LiteLLMConfig]
92 | max_workers: Optional[int] = 20
93 | deduplication_enabled: Optional[bool] = True
94 | dedup_embedding_batch_size: Optional[int] = 500
95 | similarity_threshold: Optional[float] = 0.85
96 | upload_config: Optional[UploadConfig] = None
97 | llm_calls_per_minute: Optional[int] = 15
98 | embedding_calls_per_minute: Optional[int] = 15
99 |
100 | class ModelSettings(BaseModel):
101 | """Settings for the embedding model training."""
102 | model_config = ConfigDict(extra='forbid', validate_default=True)
103 |
104 | model_id: str
105 | matryoshka_dimensions: List[int] = [768, 512, 256, 128, 64]
106 | metric_for_best_model: str = "eval_dim_128_cosine_ndcg@10"
107 | max_seq_length: Optional[int] = 768
108 | trust_remote_code: Optional[bool] = False
109 |
110 | class TrainingArguments(BaseModel):
111 | """Arguments for model training."""
112 | model_config = ConfigDict(extra='forbid', validate_default=True)
113 |
114 | # Required parameters
115 | output_path: str
116 | device: Optional[str] = "cuda"
117 |
118 | # Basic training parameters
119 | epochs: Optional[int] = 4
120 | batch_size: Optional[int] = 32
121 | gradient_accumulation_steps: Optional[int] = 16
122 | learning_rate: Optional[float] = 2.0e-5
123 |
124 | # Learning rate scheduler settings
125 | warmup_ratio: Optional[float] = 0.1
126 | lr_scheduler_type: Optional[str] = "cosine"
127 |
128 | # Optimizer settings
129 | optim: Optional[str] = "adamw_torch_fused"
130 |
131 | # Hardware optimization flags
132 | tf32: Optional[bool] = True
133 | bf16: Optional[bool] = True
134 |
135 | # Batch sampling strategy
136 | batch_sampler: Optional[BatchSamplers] = BatchSamplers.NO_DUPLICATES
137 |
138 | # Training and evaluation strategy
139 | eval_strategy: Optional[str] = "epoch"
140 | save_strategy: Optional[str] = "epoch"
141 | logging_steps: Optional[int] = 10
142 | save_total_limit: Optional[int] = 3
143 | load_best_model_at_end: Optional[bool] = True
144 |
145 | # Reporting
146 | report_to: Optional[str] = "none"
147 |
148 | class TrainingConfig(BaseModel):
149 | """Configuration for model training."""
150 | model_config = ConfigDict(extra='forbid', validate_default=True)
151 |
152 | model_settings: ModelSettings
153 | training_arguments: TrainingArguments
154 | train_dataset_config: TrainInputConfig
155 | upload_config: Optional[UploadConfig] = None
156 |
157 | class PipelineConfig(BaseModel):
158 | """Main configuration for the QuicKB pipeline."""
159 | model_config = ConfigDict(extra='forbid', validate_default=True)
160 |
161 | pipeline: Dict[str, str]
162 | hub_username: Optional[str] = None
163 | hub_token: Optional[str] = None
164 |     path_to_knowledgebase: Optional[str] = None
165 | chunker_config: Optional[ChunkerConfig] = None
166 | question_generation: Optional[QuestionGeneratorConfig] = None
167 | training: Optional[TrainingConfig] = None
168 |
169 | def load_dataset_from_local(file_path: str) -> List[Dict[str, Any]]:
170 | """Load dataset from a local JSON file."""
171 | try:
172 | with open(file_path, "r", encoding="utf-8") as f:
173 | data = json.load(f)
174 | if not isinstance(data, list):
175 | raise ValueError(f"Expected JSON array in {file_path}")
176 | return data
177 | except FileNotFoundError:
178 | logger.error(f"Dataset file not found: {file_path}")
179 | raise
180 | except json.JSONDecodeError:
181 | logger.error(f"Error decoding JSON in {file_path}")
182 | raise
183 |
184 | def load_dataset_from_hub(hub_dataset_id: str) -> List[Dict[str, Any]]:
185 | """Load dataset from Hugging Face Hub using default config."""
186 | try:
187 | logger.info(f"Loading dataset from Hub: {hub_dataset_id}")
188 | dataset = load_dataset(hub_dataset_id, split="train")
189 | if dataset:
190 | return dataset.to_list()
191 | else:
192 | logger.error(f"No data found in dataset: {hub_dataset_id}")
193 | return []
194 | except Exception as e:
195 | logger.error(f"Error loading dataset from Hub: {hub_dataset_id}. Error: {e}")
196 | raise
197 |
198 |
199 | def load_pipeline_config(config_path: str | Path = "config.yaml") -> PipelineConfig:
200 | """Load and validate pipeline configuration."""
201 |
202 | config_path = Path(config_path)
203 | if not config_path.exists():
204 | raise FileNotFoundError(f"Config file not found: {config_path}")
205 |
206 | try:
207 | with open(config_path, 'r', encoding='utf-8') as f:
208 | config_data = yaml.safe_load(f)
209 |
210 | return PipelineConfig.model_validate(config_data)
211 | except Exception as e:
212 | logger.error(f"Error loading config from {config_path}: {str(e)}")
213 | raise
214 |
215 | def process_chunks(config: PipelineConfig) -> List[Dict[str, Any]]:
216 | """Process documents into chunks and optionally upload to Hub."""
217 |
218 | # Get chunker class
219 | chunker_class = ChunkerRegistry.get_chunker(config.chunker_config.chunker)
220 | args = config.chunker_config.chunker_arguments.copy()
221 | chunker = chunker_class(**args)
222 |
223 | logger.info(f"Initialized Chunker: {config.chunker_config.chunker}")
224 |
225 | # Process files
226 | base_path = Path(config.path_to_knowledgebase)
227 | results = []
228 | total_chunks = 0
229 |
230 | for file_path in base_path.rglob('*.txt'):
231 | try:
232 | with open(file_path, 'r', encoding='utf-8') as f:
233 | text = f.read()
234 | chunks = chunker.split_text(text)
235 | source_path = str(file_path.relative_to(base_path))
236 |
237 | for chunk in chunks:
238 | results.append({
239 | "id": str(uuid.uuid4()),
240 | "text": chunk,
241 | "source": source_path
242 | })
243 |
244 | logger.info(f"Created {len(chunks)} chunks from {file_path}")
245 | total_chunks += len(chunks)
246 |
247 | except Exception as e:
248 | logger.error(f"Error processing {file_path}: {str(e)}")
249 | continue
250 |
251 | logger.info(f"Created {total_chunks} chunks in total")
252 |
253 | # Save results
254 | output_path = Path(config.chunker_config.output_path)
255 | output_path.parent.mkdir(parents=True, exist_ok=True)
256 |
257 | with open(output_path, 'w', encoding='utf-8') as f:
258 | json.dump(results, f, indent=2, ensure_ascii=False)
259 |
260 | # Handle upload if configured
261 | if (config.hub_username and
262 | config.chunker_config.upload_config and
263 | config.chunker_config.upload_config.push_to_hub):
264 | try:
265 | pusher = DatasetPusher(
266 | username=config.hub_username,
267 | token=config.hub_token
268 | )
269 |
270 | repository_id = (config.chunker_config.upload_config.hub_dataset_id
271 | or f"{config.hub_username}/{Path(config.chunker_config.output_path).stem}")
272 |
273 | chunker_info = {
274 | 'chunker_name': config.chunker_config.chunker,
275 | 'chunker_params': config.chunker_config.chunker_arguments
276 | }
277 |
278 | pusher.push_dataset(
279 | hub_dataset_id=repository_id,
280 | knowledgebase_path=config.chunker_config.output_path,
281 | chunker_info=chunker_info,
282 | private=config.chunker_config.upload_config.hub_private
283 | )
284 | logger.info(f"Successfully uploaded chunks to Hub: {repository_id}")
285 | except Exception as e:
286 | logger.error(f"Failed to upload chunks to Hub: {str(e)}")
287 |
288 | return results
289 |
290 | def generate_questions(
291 | config: PipelineConfig,
292 | kb_dataset: List[Dict[str, Any]]
293 | ) -> tuple[List[Dict[str, Any]], Dict[str, int]]:
294 | """Generate questions and optionally upload to Hub."""
295 |
296 | if not config.question_generation:
297 | raise ValueError("Question generation config is required but not provided")
298 |
299 | generator = QuestionGenerator(
300 | prompt_path="src/prompts/question_generation.txt",
301 | llm_model=config.question_generation.litellm_config.model,
302 | embedding_model=config.question_generation.litellm_config.embedding_model,
303 | dedup_enabled=config.question_generation.deduplication_enabled,
304 | similarity_threshold=config.question_generation.similarity_threshold,
305 | max_workers=config.question_generation.max_workers,
306 | model_api_base=config.question_generation.litellm_config.model_api_base,
307 | embedding_api_base=config.question_generation.litellm_config.embedding_api_base,
308 | embedding_batch_size=config.question_generation.dedup_embedding_batch_size,
309 | llm_calls_per_minute=config.question_generation.llm_calls_per_minute,
310 | embedding_calls_per_minute=config.question_generation.embedding_calls_per_minute
311 | )
312 |
313 | # Get unique texts and build a text-to-id mapping
314 | text_to_chunk_map = {}
315 | for item in kb_dataset:
316 | text_val = item["text"]
317 | if text_val not in text_to_chunk_map:
318 | text_to_chunk_map[text_val] = []
319 | text_to_chunk_map[text_val].append(item["id"])
320 |
321 | unique_texts = list(text_to_chunk_map.keys())
322 | logger.info(f"Found {len(unique_texts)} unique chunks")
323 |
324 | # Generate questions
325 | questions = generator.generate_for_chunks(unique_texts)
326 | logger.info(f"Generated {len(questions)} questions after deduplication")
327 |
328 | # Track metrics
329 | metrics = {
330 | "num_questions_original": sum(len(generator._question_cache[chunk]) for chunk in generator._question_cache),
331 | "num_questions_deduped": len(questions)
332 | }
333 | logger.info(f"Question generation metrics: {metrics}")
334 |
335 | # Create training records
336 | train_records = []
337 | skipped_questions = 0
338 | for q in questions:
339 | chunk_text = q.get("chunk_text")
340 | if not chunk_text:
341 | skipped_questions += 1
342 | continue
343 |
344 | chunk_ids = text_to_chunk_map.get(chunk_text, [])
345 | if not chunk_ids:
346 | skipped_questions += 1
347 | logger.warning(f"Could not find chunk_id for question: {q['question'][:100]}...")
348 | continue
349 |
350 | # Create a record for each matching chunk
351 | for chunk_id in chunk_ids:
352 | train_records.append({
353 | "anchor": q["question"],
354 | "positive": chunk_text,
355 | "question_id": q["id"],
356 | "chunk_id": chunk_id
357 | })
358 |
359 | logger.info(f"Created {len(train_records)} training records (skipped {skipped_questions} questions)")
360 |
361 | # Save results
362 | if config.question_generation.output_path:
363 | output_path = Path(config.question_generation.output_path)
364 | output_path.parent.mkdir(parents=True, exist_ok=True)
365 |
366 | try:
367 | with open(output_path, 'w', encoding='utf-8') as f:
368 | json.dump(train_records, f, indent=2, ensure_ascii=False)
369 | logger.info(f"Saved training records to {output_path}")
370 | except Exception as e:
371 | logger.error(f"Failed to save training records: {str(e)}")
372 |
373 | # Handle upload if configured
374 | if (config.hub_username and
375 | config.question_generation.upload_config and
376 | config.question_generation.upload_config.push_to_hub):
377 | try:
378 | pusher = DatasetPusher(
379 | username=config.hub_username,
380 | token=config.hub_token
381 | )
382 |
383 | repository_id = config.question_generation.upload_config.hub_dataset_id
384 |
385 | question_gen_info = {
386 | 'model_name': config.question_generation.litellm_config.model,
387 | 'similarity_threshold': config.question_generation.similarity_threshold,
388 | 'num_questions': metrics['num_questions_original'],
389 | 'num_deduped': metrics['num_questions_deduped']
390 | }
391 |
392 | pusher.push_dataset(
393 | hub_dataset_id=repository_id,
394 | train_path=config.question_generation.output_path,
395 | question_gen_info=question_gen_info,
396 | private=config.question_generation.upload_config.hub_private
397 | )
398 | logger.info(f"Successfully uploaded train dataset to Hub: {repository_id}")
399 | except Exception as e:
400 | logger.error(f"Failed to upload train dataset to Hub: {str(e)}")
401 |
402 | return train_records, metrics
403 |
404 | def train_model(config: PipelineConfig, kb_dataset: List[Dict[str, Any]], train_dataset: List[Dict[str, Any]]):
405 | """Train the embedding model."""
406 | train_main(config, train_dataset=train_dataset, kb_dataset=kb_dataset)
407 |
408 | def upload_to_hub(
409 | config: PipelineConfig,
410 | kb_dataset: List[Dict[str, Any]],
411 | train_dataset: Optional[List[Dict[str, Any]]] = None,
412 | question_metrics: Optional[Dict[str, int]] = None
413 | ):
414 | """Upload datasets to Hugging Face Hub."""
415 |
416 | if not config.hub_username:
417 | logger.warning("No 'hub_username' specified, skipping upload.")
418 | return
419 |
420 | try:
421 | # Initialize pusher
422 | pusher = DatasetPusher(
423 | username=config.hub_username,
424 | token=config.hub_token
425 | )
426 |
427 |         # Get repository id from the chunker output path
428 |         hub_dataset_id = f"{config.hub_username}/{Path(config.chunker_config.output_path).stem}"
429 | 
430 |         # Collect chunker info
431 |         chunker_info = {
432 |             'chunker_name': config.chunker_config.chunker,
433 |             'chunker_params': config.chunker_config.chunker_arguments
434 |         }
435 | 
436 |         # Collect question generation info if enabled
437 |         question_gen_info = None
438 |         if config.question_generation and train_dataset:
439 |             question_gen_info = {
440 |                 'model_name': config.question_generation.litellm_config.model,
441 |                 'similarity_threshold': config.question_generation.similarity_threshold,
442 |                 'num_questions': question_metrics['num_questions_original'] if question_metrics else len(train_dataset),
443 |                 'num_deduped': question_metrics['num_questions_deduped'] if question_metrics else len(train_dataset)
444 |             }
445 |         # Push dataset
446 |         pusher.push_dataset(
447 |             hub_dataset_id=hub_dataset_id,
448 |             knowledgebase_path=config.chunker_config.output_path,
449 |             chunker_info=chunker_info,
450 |             train_path=config.question_generation.output_path if train_dataset else None,
451 |             question_gen_info=question_gen_info,
452 |             private=config.chunker_config.upload_config.hub_private if config.chunker_config.upload_config else True
453 |         )
454 | except Exception as e:
455 | logger.error(f"Failed to upload to Hugging Face Hub: {str(e)}")
456 |
457 | def run_pipeline(config: PipelineConfig):
458 | """Run the QuicKB pipeline."""
459 | from_stage = PipelineStage[config.pipeline["from_stage"]]
460 | to_stage = PipelineStage[config.pipeline["to_stage"]]
461 |
462 | kb_dataset = None
463 | train_dataset = None
464 | question_metrics = None
465 |
466 | # 1. CHUNK
467 | if from_stage.value <= PipelineStage.CHUNK.value <= to_stage.value:
468 | logger.info("Running CHUNK stage.")
469 | kb_dataset = process_chunks(config)
470 | else:
471 | logger.info("Skipping CHUNK stage.")
472 |
473 | # 2. GENERATE
474 | if from_stage.value <= PipelineStage.GENERATE.value <= to_stage.value:
475 | # Load knowledgebase dataset if needed for GENERATE
476 | if not kb_dataset:
477 | input_config = config.question_generation.input_dataset_config
478 | if input_config.dataset_source == "hub":
479 | logger.info(f"Loading knowledgebase dataset from Hub: {input_config.knowledgebase_dataset_id}")
480 | kb_dataset = load_dataset_from_hub(input_config.knowledgebase_dataset_id)
481 | elif input_config.dataset_source == "local":
482 | local_kb_path = input_config.local_knowledgebase_path or config.chunker_config.output_path
483 | logger.info(f"Loading knowledgebase dataset from local path: {local_kb_path}")
484 | kb_dataset = load_dataset_from_local(local_kb_path)
485 |
486 | logger.info("Running GENERATE stage.")
487 | train_dataset, question_metrics = generate_questions(config, kb_dataset)
488 |
489 | # 3. TRAIN
490 | if from_stage.value <= PipelineStage.TRAIN.value <= to_stage.value:
491 | logger.info("Running TRAIN stage.")
492 | if not config.training:
493 | raise ValueError("No training config found, cannot run TRAIN stage.")
494 |
495 | train_config = config.training.train_dataset_config
496 |
497 | # Load datasets for training if needed
498 | if train_config.dataset_source == "hub":
499 | logger.info("Loading datasets from Hub for training...")
500 | if not train_dataset:
501 | logger.info(f"Loading training dataset from Hub: {train_config.train_dataset_id}")
502 | train_dataset = load_dataset_from_hub(train_config.train_dataset_id)
503 | if not kb_dataset:
504 | logger.info(f"Loading knowledgebase dataset from Hub: {train_config.knowledgebase_dataset_id}")
505 | kb_dataset = load_dataset_from_hub(train_config.knowledgebase_dataset_id)
506 | else: # local
507 | logger.info("Loading datasets from local files for training...")
508 | if not train_dataset:
509 | train_path = train_config.local_train_path or config.question_generation.output_path
510 | logger.info(f"Loading training dataset from local path: {train_path}")
511 | train_dataset = load_dataset_from_local(train_path)
512 | if not kb_dataset:
513 | kb_path = train_config.local_knowledgebase_path or config.chunker_config.output_path
514 | logger.info(f"Loading knowledgebase dataset from local path: {kb_path}")
515 | kb_dataset = load_dataset_from_local(kb_path)
516 |
517 | if not kb_dataset or not train_dataset:
518 | raise ValueError("Failed to load required datasets for training")
519 |
520 | train_model(config, kb_dataset, train_dataset)
521 |
522 | logger.info("Pipeline complete!")
523 |
524 | if __name__ == "__main__":
525 | logging.basicConfig(level=logging.INFO)
526 | try:
527 | config = load_pipeline_config("config.yaml")
528 | run_pipeline(config)
529 | except Exception as e:
530 | logger.error(f"Fatal error: {str(e)}")
531 | raise
--------------------------------------------------------------------------------
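The pipeline is driven entirely by config.yaml, validated against the Pydantic models above. A minimal, illustrative configuration for running only the CHUNK stage (field names come from PipelineConfig/ChunkerConfig; every value is a placeholder, and importing main assumes the project's dependencies are installed and src/ is on PYTHONPATH):

from main import PipelineConfig

example_config = {
    "pipeline": {"from_stage": "CHUNK", "to_stage": "CHUNK"},
    "path_to_knowledgebase": "./knowledgebase",          # directory of .txt files to chunk
    "chunker_config": {
        "chunker": "RecursiveTokenChunker",
        "chunker_arguments": {"chunk_size": 400, "chunk_overlap": 0},
        "output_path": "./output/knowledgebase.json",
    },
}
config = PipelineConfig.model_validate(example_config)
print(config.chunker_config.chunker)  # -> RecursiveTokenChunker

The same dictionary, written as YAML, is what load_pipeline_config expects to find in config.yaml.
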
/src/prompts/question_generation.txt:
--------------------------------------------------------------------------------
1 | You are a precise question generator that creates specific, factual questions from provided text. Each question must be answerable using only the explicit information in the text.
2 |
3 | Requirements for questions:
4 | 1. Must be answerable using ONLY information explicitly stated in the text
5 | 2. Must have a single, unambiguous answer found directly in the text
6 | 3. Must focus on concrete facts, not interpretation or inference
7 | 4. Must not require external knowledge or context
8 | 5. Must be specific to a single topic or point
9 | 6. Must avoid compound questions using "and" or "or"
10 |
11 | Output Format:
12 | {
13 | "questions": [
14 | {
15 | "id": number,
16 | "question": "text of the question",
17 | "answer_location": "exact quote from the text containing the answer",
18 | "explanation": "brief explanation of why this is a good question"
19 | }
20 | ]
21 | }
22 |
23 | Examples:
24 |
25 | Text: "The city council voted 7-2 to approve the new parking ordinance on March 15, 2024. The ordinance will increase parking meter rates from $2.00 to $3.50 per hour in the downtown district, effective July 1, 2024."
26 |
27 | Good Questions:
28 | {
29 | "questions": [
30 | {
31 | "id": 1,
32 | "question": "What was the vote count for the parking ordinance?",
33 | "answer_location": "The city council voted 7-2 to approve the new parking ordinance",
34 | "explanation": "Asks about a specific, explicitly stated numerical fact"
35 | },
36 | {
37 | "id": 2,
38 | "question": "What will be the new parking meter rate per hour?",
39 | "answer_location": "increase parking meter rates from $2.00 to $3.50 per hour",
40 | "explanation": "Focuses on a single, clearly stated numerical change"
41 | }
42 | ]
43 | }
44 |
45 | Bad Questions:
46 | - "Why did some council members vote against the ordinance?" (Requires interpretation)
47 | - "How will this affect local businesses?" (Not addressed in text)
48 | - "What are the current and new parking rates?" (Compound question)
49 |
50 | For the given user text, generate 4 questions following these requirements. Each question should focus on a different aspect of the text. Do not reference the excerpt.
--------------------------------------------------------------------------------
/src/synth_dataset/deduplicator.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from typing import List, Dict
3 | import logging
4 |
5 | logger = logging.getLogger(__name__)
6 |
7 | class QuestionDeduplicator:
8 | def __init__(self, similarity_threshold: float = 0.85):
9 | self.similarity_threshold = similarity_threshold
10 |
11 | def _calculate_similarity_matrix(self, embeddings: List[List[float]]) -> np.ndarray:
12 | """Calculate the cosine similarity matrix for the embeddings."""
13 | embeddings_matrix = np.array(embeddings)
14 | # Normalize the vectors
15 | norms = np.linalg.norm(embeddings_matrix, axis=1, keepdims=True)
16 | embeddings_matrix = embeddings_matrix / norms
17 | return np.dot(embeddings_matrix, embeddings_matrix.T)
18 |
19 | def _filter_similar_questions(self, similarity_matrix: np.ndarray) -> List[int]:
20 | """
21 | Filter out questions that are too similar to others.
22 | Returns indices of questions to keep.
23 | """
24 | n = similarity_matrix.shape[0]
25 | keep_mask = np.ones(n, dtype=bool)
26 |
27 | # For each question
28 | for i in range(n):
29 | if keep_mask[i]:
30 | # Find questions that are too similar and remove them
31 | similar_questions = np.where(similarity_matrix[i] > self.similarity_threshold)[0]
32 | # Only look at questions after the current one
33 | similar_questions = similar_questions[similar_questions > i]
34 | keep_mask[similar_questions] = False
35 |
36 | return np.where(keep_mask)[0]
37 |
38 | def deduplicate(self, questions: List[Dict], embeddings: List[List[float]]) -> List[Dict]:
39 | """
40 | Remove duplicate and similar questions based on embedding similarity.
41 |
42 | Args:
43 | questions: List of question dictionaries
44 | embeddings: List of embedding vectors for questions
45 |
46 | Returns:
47 | List of filtered question dictionaries
48 | """
49 | if not questions:
50 | return []
51 |
52 | # First remove exact duplicates based on question text
53 | seen_questions = set()
54 | unique_questions = []
55 | unique_embeddings = []
56 |
57 | for q, emb in zip(questions, embeddings):
58 | if q["question"] not in seen_questions:
59 | seen_questions.add(q["question"])
60 | unique_questions.append(q)
61 | unique_embeddings.append(emb)
62 |
63 | if not unique_questions:
64 | return []
65 |
66 | # Calculate similarity matrix
67 | similarity_matrix = self._calculate_similarity_matrix(unique_embeddings)
68 |
69 | # Get indices of questions to keep
70 | keep_indices = self._filter_similar_questions(similarity_matrix)
71 |
72 | # Filter questions
73 | filtered_questions = [unique_questions[i] for i in keep_indices]
74 |
75 | logger.info(
76 | f"Deduplication: {len(questions)} -> {len(unique_questions)} -> "
77 | f"{len(filtered_questions)} questions after filtering"
78 | )
79 | return filtered_questions
--------------------------------------------------------------------------------
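A toy sketch of the deduplicator with hand-made 2-d vectors (real usage passes embeddings from the configured embedding model):

from synth_dataset.deduplicator import QuestionDeduplicator

deduper = QuestionDeduplicator(similarity_threshold=0.95)
questions = [
    {"question": "What is the new parking rate?"},
    {"question": "What is the new parking rate?"},    # exact duplicate, removed by the text check
    {"question": "When does the ordinance take effect?"},
]
embeddings = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]]     # toy vectors standing in for real embeddings
print(deduper.deduplicate(questions, embeddings))
# -> the duplicate is dropped and the two distinct questions remain
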
/src/synth_dataset/question_generator.py:
--------------------------------------------------------------------------------
1 | import threading
2 | from threading import Lock, RLock
3 | from typing import Dict, List, Optional
4 | from collections import deque
5 | import json
6 | import uuid
7 | import logging
8 | import backoff
9 | from litellm import completion, embedding
10 | from concurrent.futures import ThreadPoolExecutor, as_completed
11 | from tqdm import tqdm
12 | from .deduplicator import QuestionDeduplicator
13 | from .rate_limiter import RateLimiter
14 |
15 | logger = logging.getLogger(__name__)
16 | logging.getLogger("LiteLLM").setLevel(logging.WARNING)
17 |
18 | class QuestionGenerator:
19 | def __init__(
20 | self,
21 | prompt_path: str,
22 | api_key: str = None,
23 | llm_model: str = "openai/gpt-4o-mini",
24 | embedding_model: str = "text-embedding-3-large",
25 | dedup_enabled: bool = True,
26 | similarity_threshold: float = 0.85,
27 | max_workers: int = 20,
28 | model_api_base: str = None,
29 | embedding_api_base: str = None,
30 | embedding_batch_size: int = 500,
31 | llm_calls_per_minute: int = 15,
32 | embedding_calls_per_minute: int = 15
33 | ):
34 | # Initialize basic attributes
35 | self.api_key = api_key
36 | self.llm_model = llm_model
37 | self.embedding_model = embedding_model
38 | self.prompt = self._load_prompt(prompt_path)
39 | self.dedup_enabled = dedup_enabled
40 | self.max_workers = max_workers
41 | self.model_api_base = model_api_base
42 | self.embedding_api_base = embedding_api_base
43 | self.embedding_batch_size = embedding_batch_size
44 |
45 | # Thread safety mechanisms
46 | self._question_cache: Dict[str, List[Dict]] = {}
47 | self._cache_lock = RLock() # Use RLock for recursive locking capability
48 | self._embedding_lock = Lock() # Separate lock for embedding operations
49 |
50 | # Initialize deduplicator if enabled
51 | self.deduplicator = QuestionDeduplicator(similarity_threshold) if dedup_enabled else None
52 |
53 | # Initialize rate limiters with their own internal locks
54 | self.llm_rate_limiter = RateLimiter(llm_calls_per_minute, name="LLM") if llm_calls_per_minute is not None else None
55 | self.embedding_rate_limiter = RateLimiter(embedding_calls_per_minute, name="Embedding") if embedding_calls_per_minute is not None else None
56 |
57 | def _load_prompt(self, path: str) -> str:
58 | """Load the prompt template from a file."""
59 | try:
60 | with open(path, 'r', encoding='utf-8') as f:
61 | return f.read()
62 | except FileNotFoundError:
63 | logger.error(f"Prompt template file not found: {path}")
64 | raise
65 | except IOError as e:
66 | logger.error(f"Error reading prompt template: {str(e)}")
67 | raise
68 |
69 | def _get_from_cache(self, chunk: str) -> Optional[List[Dict]]:
70 | """Thread-safe cache retrieval."""
71 | with self._cache_lock:
72 | return self._question_cache.get(chunk)
73 |
74 | def _add_to_cache(self, chunk: str, questions: List[Dict]) -> None:
75 | """Thread-safe cache addition."""
76 | with self._cache_lock:
77 | self._question_cache[chunk] = questions
78 |
79 | @backoff.on_exception(
80 | backoff.expo,
81 | Exception,
82 | max_tries=3,
83 | max_time=30 # Maximum total time to try
84 | )
85 | def _generate(self, chunk: str) -> str:
86 | """Generate questions for a single chunk with rate limiting."""
87 | # Wait for rate limit if needed
88 | if self.llm_rate_limiter:
89 | self.llm_rate_limiter.wait_if_needed()
90 |
91 | completion_kwargs = {
92 | "model": self.llm_model,
93 | "messages": [
94 | {"role": "system", "content": self.prompt},
95 | {"role": "user", "content": f"Text: {chunk}"}
96 | ],
97 | "temperature": 0.7,
98 | "api_key": self.api_key,
99 | "timeout": 10 # Add timeout for API calls
100 | }
101 |
102 | if self.model_api_base:
103 | completion_kwargs["api_base"] = self.model_api_base
104 |
105 | response = completion(**completion_kwargs)
106 | return response.choices[0].message.content
107 |
108 | def generate_for_chunk(self, chunk: str) -> List[Dict]:
109 | """Generate questions for a single chunk with caching."""
110 | # Check cache first
111 | cached_questions = self._get_from_cache(chunk)
112 | if cached_questions is not None:
113 | return cached_questions
114 |
115 | try:
116 | response = self._generate(chunk)
117 | questions = json.loads(response)["questions"]
118 |
119 | # Process questions
120 | processed_questions = []
121 | for q in questions:
122 | q.update({
123 | "id": str(uuid.uuid4()),
124 | "chunk_text": chunk,
125 | })
126 | q.pop("explanation", None)
127 | processed_questions.append(q)
128 |
129 | # Add to cache
130 | self._add_to_cache(chunk, processed_questions)
131 | return processed_questions
132 | except Exception as e:
133 | logger.error(f"Error generating questions: {str(e)}")
134 | return []
135 |
136 | def _process_embeddings_batch(self, questions_batch: List[str]) -> List[List[float]]:
137 | """Process a batch of questions for embeddings with thread safety."""
138 | with self._embedding_lock:
139 | if self.embedding_rate_limiter:
140 | self.embedding_rate_limiter.wait_if_needed()
141 |
142 | embedding_kwargs = {
143 | "model": self.embedding_model,
144 | "input": questions_batch,
145 | "api_key": self.api_key,
146 | "timeout": 10
147 | }
148 |
149 | if self.embedding_api_base:
150 | embedding_kwargs["api_base"] = self.embedding_api_base
151 |
152 | response = embedding(**embedding_kwargs)
153 | return [data["embedding"] for data in response.data]
154 |
155 | def generate_for_chunks(self, chunks: List[str]) -> List[Dict]:
156 | """Generate questions for multiple chunks with thread safety."""
157 | results = []
158 | uncached_chunks = []
159 |
160 | # Check cache first
161 | for chunk in chunks:
162 | cached_results = self._get_from_cache(chunk)
163 | if cached_results is not None:
164 | results.extend(cached_results)
165 | else:
166 | uncached_chunks.append(chunk)
167 |
168 | if uncached_chunks:
169 | with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
170 | future_to_chunk = {
171 | executor.submit(self.generate_for_chunk, chunk): chunk
172 | for chunk in uncached_chunks
173 | }
174 |
175 | with tqdm(total=len(uncached_chunks), desc="Generating questions") as pbar:
176 | for future in as_completed(future_to_chunk):
177 | try:
178 | questions = future.result()
179 | results.extend(questions)
180 | except Exception as e:
181 | chunk = future_to_chunk[future]
182 | logger.error(f"Error processing chunk {chunk[:50]}...: {str(e)}")
183 | pbar.update(1)
184 |
185 | if self.dedup_enabled and self.deduplicator and results:
186 | # Process embeddings in batches
187 | questions_text = [q["question"] for q in results]
188 | all_embeddings = []
189 |
190 | # Calculate total number of batches for the progress bar
191 | num_batches = (len(questions_text) + self.embedding_batch_size - 1) // self.embedding_batch_size
192 |
193 | # Add progress bar for embedding batches
194 | with tqdm(total=num_batches, desc="Processing embeddings", unit="batch") as pbar:
195 | for i in range(0, len(questions_text), self.embedding_batch_size):
196 | batch = questions_text[i:i + self.embedding_batch_size]
197 | try:
198 | batch_embeddings = self._process_embeddings_batch(batch)
199 | all_embeddings.extend(batch_embeddings)
200 | pbar.update(1)
201 |
202 | # Optional: add batch statistics to progress bar
203 | pbar.set_postfix({
204 | 'batch_size': len(batch),
205 | 'total_embedded': len(all_embeddings)
206 | })
207 |
208 | except Exception as e:
209 | logger.error(f"Error during embedding batch {i//self.embedding_batch_size}: {str(e)}")
210 | return results # Return un-deduplicated results on error
211 |
212 | # Deduplicate with thread safety
213 | with self._cache_lock:
214 | results = self.deduplicator.deduplicate(results, all_embeddings)
215 |
216 | return results
--------------------------------------------------------------------------------
/src/synth_dataset/rate_limiter.py:
--------------------------------------------------------------------------------
1 | from threading import Lock
2 | from time import time, sleep
3 | from collections import deque
4 | import logging
5 |
6 | logger = logging.getLogger(__name__)
7 |
8 | from typing import Optional
9 |
10 | class RateLimiter:
11 | """Thread-safe rate limiter using a rolling window."""
12 |
13 | def __init__(self, calls_per_minute: int, name: str = "default"):
14 | """
15 | Initialize rate limiter.
16 |
17 | Args:
18 | calls_per_minute: Maximum number of calls allowed per minute
19 | name: Name for this rate limiter instance (for logging)
20 | """
21 | self.calls_per_minute = calls_per_minute
22 | self.name = name
23 | self.window_size = 60 # 1 minute in seconds
24 | self.timestamps = deque(maxlen=calls_per_minute)
25 | self.lock = Lock()
26 |
27 | def _clean_old_timestamps(self, current_time: float) -> None:
28 | """Remove timestamps older than the window size."""
29 | cutoff_time = current_time - self.window_size
30 | while self.timestamps and self.timestamps[0] < cutoff_time:
31 | self.timestamps.popleft()
32 |
33 | def wait_if_needed(self) -> None:
34 | """
35 | Check if rate limit is reached and wait if necessary.
36 | Thread-safe implementation.
37 | """
38 | with self.lock:
39 | current_time = time()
40 | self._clean_old_timestamps(current_time)
41 |
42 | if len(self.timestamps) >= self.calls_per_minute:
43 | # Calculate required wait time
44 | oldest_timestamp = self.timestamps[0]
45 | wait_time = oldest_timestamp + self.window_size - current_time
46 |
47 | if wait_time > 0:
48 | logger.debug(f"{self.name} rate limiter: Waiting {wait_time:.2f} seconds")
49 | sleep(wait_time)
50 | current_time = time() # Update current time after sleep
51 |
52 | # Add new timestamp
53 | self.timestamps.append(current_time)
54 |
55 | def acquire(self) -> None:
56 | """Alias for wait_if_needed for more intuitive API."""
57 | self.wait_if_needed()
58 |
59 | def remaining_calls(self) -> int:
60 | """Return number of remaining calls allowed in current window."""
61 | with self.lock:
62 | self._clean_old_timestamps(time())
63 | return self.calls_per_minute - len(self.timestamps)
--------------------------------------------------------------------------------
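A quick sketch of the rolling-window behaviour; with calls_per_minute=2 the third call blocks for the remainder of the 60-second window, so expect the loop to pause:

from time import time
from synth_dataset.rate_limiter import RateLimiter

limiter = RateLimiter(calls_per_minute=2, name="demo")
start = time()
for i in range(3):
    limiter.wait_if_needed()
    print(f"call {i} at +{time() - start:.1f}s, remaining: {limiter.remaining_calls()}")
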
/src/training/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ALucek/QuicKB/288c62bd5024520388d05ece4a841900abb6d93d/src/training/__init__.py
--------------------------------------------------------------------------------
/src/training/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import logging
4 | import torch
5 | import yaml
6 | from pathlib import Path
7 | from typing import Dict, List, Any, Optional
8 | from datasets import load_dataset, Dataset
9 | from sentence_transformers import (
10 | SentenceTransformer,
11 | SentenceTransformerModelCardData,
12 | SentenceTransformerTrainingArguments,
13 | SentenceTransformerTrainer,
14 | )
15 | from sentence_transformers.losses import MultipleNegativesRankingLoss, MatryoshkaLoss
16 | from sentence_transformers.evaluation import InformationRetrievalEvaluator, SequentialEvaluator
17 | from sentence_transformers.util import cos_sim
18 | from sentence_transformers.training_args import BatchSamplers
19 | from huggingface_hub import login
20 |
21 | logging.basicConfig(level=logging.INFO)
22 | logger = logging.getLogger(__name__)
23 | logging.getLogger("sentence_transformers").setLevel(logging.WARNING)
24 | logging.getLogger("transformers").setLevel(logging.WARNING)
25 |
26 | def load_main_config(config_path_or_obj):
27 | """Load configuration from path or use provided config object."""
28 | if isinstance(config_path_or_obj, (str, bytes, os.PathLike)):
29 | with open(config_path_or_obj, "r", encoding="utf-8") as f:
30 | return yaml.safe_load(f)
31 | return config_path_or_obj.model_dump()
32 |
33 | def validate_dataset_consistency(kb_dataset: List[Dict[str, Any]], train_dataset: List[Dict[str, Any]]) -> bool:
34 | """Validate that all chunk IDs referenced in training data exist in knowledgebase."""
35 | logger.info("Validating dataset consistency...")
36 |
37 | # Get all chunk IDs from knowledgebase
38 | kb_ids = {str(item['id']) for item in kb_dataset}
39 | logger.info(f"Found {len(kb_ids)} unique chunks in knowledgebase")
40 |
41 | # Get all chunk IDs referenced in training data
42 | train_chunk_ids = {str(item['chunk_id']) for item in train_dataset}
43 | logger.info(f"Found {len(train_chunk_ids)} unique chunk references in training data")
44 |
45 | # Find missing chunks
46 | missing_chunks = train_chunk_ids - kb_ids
47 | if missing_chunks:
48 | logger.error(f"Found {len(missing_chunks)} chunk IDs in training data that don't exist in knowledgebase")
49 | logger.error(f"First few missing IDs: {list(missing_chunks)[:5]}")
50 | return False
51 |
52 | logger.info("Dataset consistency validation passed!")
53 | return True
54 |
55 | def build_evaluation_structures(kb_dataset, test_dataset, kb_id_field="id", kb_text_field="text"):
56 | corpus = {row[kb_id_field]: row[kb_text_field] for row in kb_dataset}
57 | queries = {row["id"]: row["anchor"] for row in test_dataset}
58 | relevant_docs = {}
59 | for row in test_dataset:
60 | q_id = row["id"]
61 | if "global_chunk_id" not in row:
62 | logger.warning(f"Missing 'global_chunk_id': {row}")
63 | continue
64 | if q_id not in relevant_docs:
65 | relevant_docs[q_id] = []
66 | relevant_docs[q_id].append(row["global_chunk_id"])
67 | return corpus, queries, relevant_docs
68 |
69 | def format_evaluation_results(title: str, results: dict, dim_list: list, metrics: list = None):
70 | if metrics is None:
71 | metrics = [
72 | "ndcg@10", "mrr@10", "map@100",
73 | "accuracy@1", "accuracy@3", "accuracy@5", "accuracy@10",
74 | "precision@1", "precision@3", "precision@5", "precision@10",
75 | "recall@1", "recall@3", "recall@5", "recall@10",
76 | ]
77 |
78 | # Calculate required widths for alignment
79 | max_dim_length = max(len(str(dim)) for dim in dim_list)
80 | value_format_length = 5 # '0.000' is 5 characters
81 | column_width = max(max_dim_length, value_format_length)
82 | metric_width = max(len(metric) for metric in metrics) if metrics else 10
83 |
84 | # Prepare dimension headers (left-aligned strings)
85 | dim_header = " ".join(f"{str(dim):<{column_width}}" for dim in dim_list)
86 |
87 | # Create header line and dynamically determine separator length
88 | header_line = f"{'Metric':{metric_width}} {dim_header}"
89 | separator = "-" * len(header_line)
90 |
91 | output = [
92 | f"\n{title}",
93 | separator,
94 | header_line,
95 | separator
96 | ]
97 |
98 | # Populate each metric row (values right-aligned)
99 | for metric in metrics:
100 | values = []
101 | for dim in dim_list:
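            # Result keys follow the evaluator naming scheme "<evaluator name>_cosine_<metric>", e.g. "dim_768_cosine_ndcg@10"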
102 | key = f"dim_{dim}_cosine_{metric}"
103 | val = results.get(key, 0.0)
104 | values.append(f"{val:>{column_width}.3f}") # Right-align values
105 | metric_line = f"{metric:{metric_width}} {' '.join(values)}"
106 | output.append(metric_line)
107 |
108 |     # Closing separator
109 | output.append(separator)
110 |
111 | return "\n".join(output)
112 |
113 | def run_baseline_eval(model, evaluator, dim_list):
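    """Evaluate the untrained base model and print baseline retrieval metrics."""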
114 | results = evaluator(model)
115 | print(format_evaluation_results("Before Training (Baseline) Results", results, dim_list))
116 | return results
117 |
118 | def run_final_eval(model, evaluator, dim_list):
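    """Evaluate the fine-tuned model and print post-training retrieval metrics."""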
119 | results = evaluator(model)
120 | print(format_evaluation_results("After Training (Fine-Tuned) Results", results, dim_list))
121 | return results
122 |
123 | def save_metrics_to_file(before: dict, after: dict, dim_list: list, path="metrics_comparison.txt"):
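    """Write baseline vs. fine-tuned metrics, plus absolute and percentage changes, to a text report."""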
124 | metrics = [
125 | "ndcg@10", "mrr@10", "map@100",
126 | "accuracy@1", "accuracy@3", "accuracy@5", "accuracy@10",
127 | "precision@1", "precision@3", "precision@5", "precision@10",
128 | "recall@1", "recall@3", "recall@5", "recall@10",
129 | ]
130 |
131 | def format_dimensions(dim_list):
132 | return [f"{dim:,}" for dim in dim_list]
133 |
134 | def write_table(f, title, headers, rows, value_formatter):
135 | # Calculate column widths
136 | is_percentage = "%" in rows[0][1]
137 |
138 | # For metric column
139 | metric_width = max(len(row[0]) for row in rows)
140 | metric_width = max(metric_width, len(headers[0]))
141 |
142 | # For value columns, ensure we account for the maximum width including +/- and %
143 | dim_widths = []
144 | for col_idx in range(1, len(headers)):
145 | col_values = [row[col_idx] for row in rows]
146 | max_width = max(len(str(val)) for val in col_values)
147 | header_width = len(headers[col_idx])
148 | dim_widths.append(max(max_width, header_width))
149 |
150 | col_widths = [metric_width] + dim_widths
151 |
152 | # Build header with consistent spacing
153 | header = headers[0].ljust(col_widths[0])
154 | for i, h in enumerate(headers[1:], 1):
155 | header += " │ " + h.center(col_widths[i])
156 |
157 | # Build separator
158 | separator = "─" * col_widths[0]
159 | for w in col_widths[1:]:
160 | separator += "─┼─" + "─" * w
161 |
162 | # Write table
163 | f.write(f"\n{title}\n")
164 | f.write("-" * len(title) + "\n")
165 | f.write(f"{header}\n{separator}\n")
166 |
167 | # Write rows with consistent spacing
168 | for row in rows:
169 | line = row[0].ljust(col_widths[0])
170 | for i, val in enumerate(row[1:], 1):
171 | if is_percentage:
172 | # Right-align percentage values
173 | line += " │ " + val.rjust(col_widths[i])
174 | else:
175 | line += " │ " + value_formatter(val, col_widths[i])
176 | f.write(line + "\n")
177 |
178 | with open(path, "w", encoding="utf-8") as f:
179 | # Main title
180 | f.write("Model Performance Metrics Comparison\n")
181 | f.write("====================================\n")
182 |
183 | # Format dimensions with thousands separators
184 | dim_strs = format_dimensions(dim_list)
185 |
186 | # Prepare data tables
187 | baseline_rows = []
188 | finetuned_rows = []
189 | delta_rows = []
190 | pct_change_rows = []
191 |
192 | for metric in metrics:
193 | bl_vals = []
194 | ft_vals = []
195 | dt_vals = []
196 | pc_vals = []
197 |
198 | for dim in dim_list:
199 | key = f"dim_{dim}_cosine_{metric}"
200 | b_val = before.get(key, 0.0)
201 | a_val = after.get(key, 0.0)
202 | delta = a_val - b_val
203 | pct_change = (delta / b_val * 100) if b_val != 0 else 0.0
204 |
205 | bl_vals.append(f"{b_val:.3f}")
206 | ft_vals.append(f"{a_val:.3f}")
207 | dt_vals.append(f"{delta:+.3f}")
208 | pc_vals.append(f"{pct_change:+.1f}%")
209 |
210 | baseline_rows.append([metric] + bl_vals)
211 | finetuned_rows.append([metric] + ft_vals)
212 | delta_rows.append([metric] + dt_vals)
213 | pct_change_rows.append([metric] + pc_vals)
214 |
215 | # Write tables
216 | headers = ["Metric"] + dim_strs
217 | num_formatter = lambda val, w: f"{val:>{w}}"
218 |
219 | write_table(f, "Baseline Performance", headers, baseline_rows, num_formatter)
220 | write_table(f, "Fine-Tuned Performance", headers, finetuned_rows, num_formatter)
221 | write_table(f, "Absolute Changes (Δ)", headers, delta_rows, num_formatter)
222 | write_table(f, "Percentage Changes", headers, pct_change_rows, num_formatter)
223 |
224 | def select_device(preferred_device: Optional[str] = None) -> str:
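    """Pick the torch device for training, honoring an explicit preference before falling back to auto-detection."""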
225 | # Handle explicit preference first
226 | if preferred_device:
227 | if preferred_device == "cuda" and torch.cuda.is_available():
228 | logger.info("Using CUDA GPU as requested")
229 | return "cuda"
230 | elif preferred_device == "mps" and hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
231 | logger.info("Using Apple Silicon MPS as requested")
232 | return "mps"
233 | elif preferred_device == "cpu":
234 | logger.info("Using CPU as requested")
235 | return "cpu"
236 | else:
237 | logger.warning(f"Requested device '{preferred_device}' not available, falling back to auto-detection")
238 |
239 | # Auto-detection (prioritize GPU > MPS > CPU)
240 | if torch.cuda.is_available():
241 | logger.info("CUDA GPU detected and selected")
242 | return "cuda"
243 | elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
244 | logger.info("Apple Silicon MPS detected and selected")
245 | return "mps"
246 | else:
247 | logger.info("No accelerator detected, using CPU")
248 | return "cpu"
249 |
250 | def main(config, train_dataset: List[Dict[str, Any]], kb_dataset: List[Dict[str, Any]]):
251 | """Main training function."""
252 | if not hasattr(config, 'training') or not config.training:
253 | raise ValueError("Training configuration is required but not provided")
254 |
255 | # Validate dataset consistency before proceeding
256 | if not validate_dataset_consistency(kb_dataset, train_dataset):
257 | raise ValueError("Dataset consistency check failed - chunks and training data are mismatched. "
258 | "This usually happens when chunks have been regenerated after question generation.")
259 |
260 | # Convert lists to dataset objects if needed
261 | train_dataset_full = Dataset.from_list(train_dataset)
262 |
263 | if "id" not in train_dataset_full.column_names:
264 | train_dataset_full = train_dataset_full.add_column("id", list(range(len(train_dataset_full))))
265 | if "chunk_id" in train_dataset_full.column_names:
266 | train_dataset_full = train_dataset_full.rename_column("chunk_id", "global_chunk_id")
267 | train_dataset_full = train_dataset_full.shuffle()
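    # Hold out 10% of the QA pairs; the held-out anchors become queries for retrieval evaluation.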
268 | dataset_split = train_dataset_full.train_test_split(test_size=0.1)
269 | train_dataset = dataset_split["train"]
270 | test_dataset = dataset_split["test"]
271 | logger.info(f"Train size: {len(train_dataset)} | Test size: {len(test_dataset)}")
272 |
273 | # Build evaluation structures
274 | corpus, queries, relevant_docs = build_evaluation_structures(
275 | kb_dataset=kb_dataset,
276 | test_dataset=test_dataset,
277 | kb_id_field="id",
278 | kb_text_field="text"
279 | )
280 |
281 | # Setup evaluators
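    # One InformationRetrievalEvaluator per Matryoshka dimension; each truncates embeddings to that size before scoring.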
282 | dim_list = config.training.model_settings.matryoshka_dimensions
283 | evaluators = []
284 | for d in dim_list:
285 | evaluators.append(
286 | InformationRetrievalEvaluator(
287 | queries=queries,
288 | corpus=corpus,
289 | relevant_docs=relevant_docs,
290 | name=f"dim_{d}",
291 | score_functions={"cosine": cos_sim},
292 | truncate_dim=d
293 | )
294 | )
295 | evaluator = SequentialEvaluator(evaluators)
296 |
297 | # Initialize base model and run baseline evaluation
298 | device = select_device(config.training.training_arguments.device)
299 | base_model = SentenceTransformer(config.training.model_settings.model_id, device=device, trust_remote_code=config.training.model_settings.trust_remote_code)
300 | base_model.max_seq_length = config.training.model_settings.max_seq_length
301 | baseline_results = run_baseline_eval(base_model, evaluator, dim_list)
302 |
303 |     logger.info("Re-initializing the model for training.")
304 | model = SentenceTransformer(
305 | config.training.model_settings.model_id,
306 | device=device,
307 | trust_remote_code=config.training.model_settings.trust_remote_code,
308 | model_kwargs={"attn_implementation": "sdpa"},
309 | model_card_data=SentenceTransformerModelCardData(
310 | language="en",
311 | license="apache-2.0",
312 | model_name="Fine-tuned with [QuicKB](https://github.com/ALucek/QuicKB)",
313 | ),
314 | )
315 | model.max_seq_length = config.training.model_settings.max_seq_length
316 |
317 | # Setup loss functions
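    # MatryoshkaLoss wraps the base loss so the model is also optimized for embeddings truncated to each dimension in dim_list.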
318 | base_loss = MultipleNegativesRankingLoss(model)
319 | train_loss = MatryoshkaLoss(
320 | model=model,
321 | loss=base_loss,
322 | matryoshka_dims=dim_list
323 | )
324 |
325 | # Setup training arguments
326 | args = SentenceTransformerTrainingArguments(
327 | output_dir=config.training.training_arguments.output_path,
328 | num_train_epochs=config.training.training_arguments.epochs,
329 | per_device_train_batch_size=config.training.training_arguments.batch_size,
330 | gradient_accumulation_steps=config.training.training_arguments.gradient_accumulation_steps,
331 | learning_rate=config.training.training_arguments.learning_rate,
332 | warmup_ratio=config.training.training_arguments.warmup_ratio,
333 | lr_scheduler_type=config.training.training_arguments.lr_scheduler_type,
334 | optim=config.training.training_arguments.optim,
335 | tf32=config.training.training_arguments.tf32,
336 | bf16=config.training.training_arguments.bf16,
337 | batch_sampler=config.training.training_arguments.batch_sampler.value,
338 | eval_strategy=config.training.training_arguments.eval_strategy,
339 | save_strategy=config.training.training_arguments.save_strategy,
340 | logging_steps=config.training.training_arguments.logging_steps,
341 | save_total_limit=config.training.training_arguments.save_total_limit,
342 | load_best_model_at_end=config.training.training_arguments.load_best_model_at_end,
343 | metric_for_best_model=config.training.model_settings.metric_for_best_model,
344 | report_to=config.training.training_arguments.report_to,
345 | )
346 |
347 | # Prepare final training dataset and trainer
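    # MultipleNegativesRankingLoss only needs (anchor, positive) pairs; other positives in the batch serve as in-batch negatives.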
348 | final_train_dataset = train_dataset.select_columns(["anchor", "positive"])
349 | trainer = SentenceTransformerTrainer(
350 | model=model,
351 | args=args,
352 | train_dataset=final_train_dataset,
353 | loss=train_loss,
354 | evaluator=evaluator,
355 | )
356 |
357 | # Train model
358 | logger.info("Starting training...")
359 | trainer.train()
360 | trainer.save_model()
361 |
362 | # Evaluate fine-tuned model
363 | fine_tuned_model = SentenceTransformer(config.training.training_arguments.output_path, device=device, trust_remote_code=config.training.model_settings.trust_remote_code)
364 | fine_tuned_model.max_seq_length = config.training.model_settings.max_seq_length
365 | final_results = run_final_eval(fine_tuned_model, evaluator, dim_list)
366 |
367 | # Save metrics
368 | save_metrics_to_file(
369 | baseline_results,
370 | final_results,
371 | dim_list,
372 | path=f"{config.training.training_arguments.output_path}/metrics_comparison.txt"
373 | )
374 |
375 | # Handle model upload if configured
376 | if (config.training.upload_config and
377 | config.training.upload_config.push_to_hub):
378 |
379 |         hub_token = config.hub_token or os.getenv("HF_TOKEN")
380 |         if hub_token:
            login(token=hub_token)
        else:
            logger.warning("No HF_TOKEN in env or config; relying on cached Hugging Face credentials for the upload.")
381 |
382 | if not config.training.upload_config.hub_model_id:
383 | logger.warning("No hub_model_id specified, skipping upload")
384 | else:
385 | logger.info(f"Pushing model to HF Hub: {config.training.upload_config.hub_model_id}")
386 | trainer.model.push_to_hub(
387 | config.training.upload_config.hub_model_id,
388 | exist_ok=True,
389 | private=config.training.upload_config.hub_private
390 | )
391 | logger.info("Upload complete!")
392 |
393 | logger.info("Training pipeline finished.")
394 |
--------------------------------------------------------------------------------