├── .gitignore ├── CHANGELOG.md ├── README.md ├── docs ├── README.md ├── RELEASE.md ├── api-reference.md ├── architecture.md ├── configuration.md ├── core-components.md ├── development-guide.md ├── getting-started.md ├── overview.md └── rate-limiting.md ├── embedding_results.json ├── examples ├── embedding_example.py ├── notebook_test.ipynb └── parallel_test.py ├── fastllm ├── __init__.py ├── cache.py ├── core.py └── providers │ ├── __init__.py │ ├── base.py │ └── openai.py ├── justfile ├── pyproject.toml └── tests ├── test_asyncio.py ├── test_cache.py ├── test_core.py ├── test_openai.py ├── test_progress_tracker.py ├── test_providers.py └── test_request_batch.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.so 6 | .Python 7 | build/ 8 | develop-eggs/ 9 | dist/ 10 | downloads/ 11 | eggs/ 12 | .eggs/ 13 | lib/ 14 | lib64/ 15 | parts/ 16 | sdist/ 17 | var/ 18 | wheels/ 19 | *.egg-info/ 20 | .installed.cfg 21 | *.egg 22 | 23 | # Virtual Environment 24 | .env 25 | .venv 26 | env/ 27 | venv/ 28 | ENV/ 29 | 30 | # IDE 31 | .idea/ 32 | .vscode/ 33 | *.swp 34 | *.swo 35 | .DS_Store 36 | 37 | # Project specific 38 | results.json 39 | parallel_test_results.json 40 | .coverage 41 | htmlcov/ 42 | .pytest_cache/ 43 | .ruff_cache/ 44 | .mypy_cache/ 45 | cache/ 46 | uv.lock 47 | *.lock -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | All notable changes to this project will be documented in this file. 4 | 5 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), 6 | and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 7 | 8 | ## [0.1.0] - 2024-03 9 | 10 | ### Added 11 | - Initial release with core functionality 12 | - Parallel request processing with configurable concurrency 13 | - In-memory and disk-based caching support 14 | - Multiple provider support (OpenAI, OpenRouter) 15 | - Request batching with OpenAI-style API 16 | - Progress tracking and statistics 17 | - Request deduplication and response ordering 18 | - Configurable retry mechanism 19 | - Rich progress bar with detailed statistics 20 | - Support for existing asyncio event loops 21 | - Jupyter notebook compatibility 22 | - Request ID (cache key) return from batch creation methods 23 | 24 | ### Changed 25 | - Optimized request processing for better performance 26 | - Improved error handling and reporting 27 | - Enhanced request ID handling and exposure 28 | - Added compatibility with existing asyncio event loops 29 | - Fixed asyncio loop handling in Jupyter notebooks 30 | - Made request IDs accessible to users for cache management 31 | 32 | ### Fixed 33 | - Request ID validation and string conversion 34 | - Cache persistence issues 35 | - Response ordering in parallel processing -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FastLLM 2 | 3 | High-performance parallel LLM API request tool with support for multiple providers and caching capabilities. 4 | 5 | ## Features 6 | 7 | - Parallel request processing with configurable concurrency 8 | - Allows you to process 20000+ prompt tokens per second and 1500+ output tokens per second even for extremely large LLMs, such as Deepseek-V3. 
9 | - Built-in caching support (in-memory and disk-based) 10 | - Progress tracking with token usage statistics 11 | - Support for multiple LLM providers (OpenAI, OpenRouter, etc.) 12 | - OpenAI-style API for request batching 13 | - Retry mechanism with configurable attempts and delays 14 | - Request deduplication and response ordering 15 | 16 | ## Installation 17 | 18 | Use pip: 19 | ```bash 20 | pip install fastllm-kit 21 | ``` 22 | 23 | Alternatively, use uv: 24 | ```bash 25 | uv pip install fastllm-kit 26 | ``` 27 | 28 | > **Important:** fastllm does not support yet libsqlite3.49.1, please use libsqlite3.49.0 or lower. See [this issue](https://github.com/grantjenks/python-diskcache/issues/343) for more details. This might be an issue for users with conda environments. 29 | 30 | For development: 31 | ```bash 32 | # Clone the repository 33 | git clone https://github.com/Rexhaif/fastllm.git 34 | cd fastllm 35 | 36 | # Create a virtual environment and install dependencies 37 | uv venv 38 | uv pip install -e ".[dev]" 39 | ``` 40 | 41 | ## Dependencies 42 | 43 | FastLLM requires Python 3.9 or later and depends on the following packages: 44 | 45 | - `httpx` (^0.27.2) - For async HTTP requests 46 | - `pydantic` (^2.10.6) - For data validation and settings management 47 | - `rich` (^13.9.4) - For beautiful terminal output and progress bars 48 | - `diskcache` (^5.6.3) - For persistent disk caching 49 | - `asyncio` (^3.4.3) - For asynchronous operations 50 | - `anyio` (^4.8.0) - For async I/O operations 51 | - `tqdm` (^4.67.1) - For progress tracking 52 | - `typing_extensions` (^4.12.2) - For enhanced type hints 53 | 54 | Development dependencies: 55 | - `ruff` (^0.3.7) - For linting and formatting 56 | - `pytest` (^8.3.4) - For testing 57 | - `pytest-asyncio` (^0.23.8) - For async tests 58 | - `pytest-cov` (^4.1.0) - For test coverage 59 | - `black` (^24.10.0) - For code formatting 60 | - `coverage` (^7.6.10) - For code coverage reporting 61 | 62 | ## Development 63 | 64 | The project uses [just](https://github.com/casey/just) for task automation and [uv](https://github.com/astral/uv) for dependency management. 65 | 66 | Common tasks: 67 | ```bash 68 | # Install dependencies 69 | just install 70 | 71 | # Run tests 72 | just test 73 | 74 | # Format code 75 | just format 76 | 77 | # Run linting 78 | just lint 79 | 80 | # Clean up cache files 81 | just clean 82 | ``` 83 | 84 | ## Quick Start 85 | 86 | ```python 87 | from fastllm import RequestBatch, RequestManager, OpenAIProvider, InMemoryCache 88 | 89 | # Create a provider 90 | provider = OpenAIProvider( 91 | api_key="your-api-key", 92 | # Optional: custom API base URL 93 | api_base="https://api.openai.com/v1", 94 | ) 95 | 96 | # Create a cache provider (optional) 97 | cache = InMemoryCache() # or DiskCache(directory="./cache") 98 | 99 | # Create a request manager 100 | manager = RequestManager( 101 | provider=provider, 102 | concurrency=50, # Number of concurrent requests 103 | show_progress=True, # Show progress bar 104 | caching_provider=cache, # Enable caching 105 | ) 106 | 107 | # Create a batch of requests 108 | request_ids = [] # Store request IDs for later use 109 | with RequestBatch() as batch: 110 | # Add requests to the batch 111 | for i in range(10): 112 | # create() returns the request ID (caching key) 113 | request_id = batch.chat.completions.create( 114 | model="gpt-3.5-turbo", 115 | messages=[{ 116 | "role": "user", 117 | "content": f"What is {i} + {i}?" 
118 | }], 119 | temperature=0.7, 120 | include_reasoning=True, # Optional: include model reasoning 121 | ) 122 | request_ids.append(request_id) 123 | 124 | # Process the batch 125 | responses = manager.process_batch(batch) 126 | 127 | # Process responses 128 | for request_id, response in zip(request_ids, responses): 129 | print(f"Request {request_id}: {response.response.choices[0].message.content}") 130 | 131 | # You can use request IDs to check cache status 132 | for request_id in request_ids: 133 | is_cached = await cache.exists(request_id) 134 | print(f"Request {request_id} is {'cached' if is_cached else 'not cached'}") 135 | ``` 136 | 137 | ## Advanced Usage 138 | 139 | ### Async Support 140 | 141 | FastLLM can be used both synchronously and asynchronously, and works seamlessly in regular Python environments, async applications, and Jupyter notebooks: 142 | 143 | ```python 144 | import asyncio 145 | from fastllm import RequestBatch, RequestManager, OpenAIProvider 146 | 147 | # Works in Jupyter notebooks 148 | provider = OpenAIProvider(api_key="your-api-key") 149 | manager = RequestManager(provider=provider) 150 | responses = manager.process_batch(batch) # Just works! 151 | 152 | # Works in async applications 153 | async def process_requests(): 154 | provider = OpenAIProvider(api_key="your-api-key") 155 | manager = RequestManager(provider=provider) 156 | 157 | with RequestBatch() as batch: 158 | batch.chat.completions.create( 159 | model="gpt-3.5-turbo", 160 | messages=[{"role": "user", "content": "Hello!"}] 161 | ) 162 | 163 | responses = manager.process_batch(batch) 164 | return responses 165 | 166 | # Run in existing event loop 167 | async def main(): 168 | responses = await process_requests() 169 | print(responses) 170 | 171 | asyncio.run(main()) 172 | ``` 173 | 174 | ### Caching Configuration 175 | 176 | FastLLM supports both in-memory and disk-based caching, with request IDs serving as cache keys: 177 | 178 | ```python 179 | from fastllm import InMemoryCache, DiskCache, RequestBatch 180 | 181 | # Create a batch and get request IDs 182 | with RequestBatch() as batch: 183 | request_id = batch.chat.completions.create( 184 | model="gpt-3.5-turbo", 185 | messages=[{"role": "user", "content": "Hello!"}] 186 | ) 187 | print(f"Request ID (cache key): {request_id}") 188 | 189 | # In-memory cache (faster, but cleared when process ends) 190 | cache = InMemoryCache() 191 | 192 | # Disk cache (persistent, with optional TTL and size limits) 193 | cache = DiskCache( 194 | directory="./cache", 195 | ttl=3600, # Cache TTL in seconds 196 | size_limit=int(2e9) # 2GB size limit 197 | ) 198 | 199 | # Check if a response is cached 200 | is_cached = await cache.exists(request_id) 201 | 202 | # Get cached response if available 203 | if is_cached: 204 | response = await cache.get(request_id) 205 | ``` 206 | 207 | ### Custom Providers 208 | 209 | Create your own provider by inheriting from the base `Provider` class: 210 | 211 | ```python 212 | from fastllm import Provider 213 | from typing import Any 214 | import httpx 215 | 216 | class CustomProvider(Provider[YourResponseType]): 217 | def get_request_headers(self) -> dict[str, str]: 218 | return { 219 | "Authorization": f"Bearer {self.api_key}", 220 | "Content-Type": "application/json", 221 | } 222 | 223 | async def make_request( 224 | self, 225 | client: httpx.AsyncClient, 226 | request: dict[str, Any], 227 | timeout: float, 228 | ) -> YourResponseType: 229 | # Implement your request logic here 230 | pass 231 | ``` 232 | 233 | ### Progress Tracking 234 | 
235 | The progress bar shows: 236 | - Request completion progress 237 | - Tokens per second (prompt and completion) 238 | - Cache hit/miss statistics 239 | - Estimated time remaining 240 | - Total elapsed time 241 | 242 | ## Contributing 243 | 244 | Contributions are welcome! Please feel free to submit a Pull Request. 245 | 246 | ## License 247 | 248 | This project is licensed under the MIT License - see the LICENSE file for details. 249 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # FastLLM Documentation 2 | 3 | FastLLM is a high-performance Python library for making parallel LLM API requests with built-in caching and multiple provider support. It's designed to efficiently handle batch processing of LLM requests while providing a familiar API interface similar to OpenAI's client library. 4 | 5 | ## Table of Contents 6 | 7 | - [Overview](overview.md) 8 | - [Getting Started](getting-started.md) 9 | - [Architecture](architecture.md) 10 | - [Core Components](core-components.md) 11 | - [API Reference](api-reference.md) 12 | - [Configuration](configuration.md) 13 | - [Development Guide](development-guide.md) -------------------------------------------------------------------------------- /docs/RELEASE.md: -------------------------------------------------------------------------------- 1 | # Release Procedure 2 | 3 | This document outlines the steps to release fastllm-kit to PyPI. 4 | 5 | ## Publishing to PyPI 6 | 7 | The library is published as `fastllm-kit` on PyPI, but maintains the import name `fastllm` for user convenience. 8 | 9 | ### Prerequisites 10 | 11 | - Ensure you have `uv` installed 12 | - Have PyPI credentials set up 13 | - Set the PyPI token with `UV_PUBLISH_TOKEN` environment variable or use `--token` flag 14 | - For TestPyPI, you'll also need TestPyPI credentials 15 | 16 | ### Release Steps 17 | 18 | 1. Update the version in `pyproject.toml` 19 | 2. Update `CHANGELOG.md` with the latest changes 20 | 3. Commit all changes 21 | 22 | 4. Clean previous builds: 23 | ``` 24 | just clean 25 | ``` 26 | 27 | 5. Build the package: 28 | ``` 29 | just build 30 | ``` 31 | 32 | 6. Publish to TestPyPI first to verify everything works: 33 | ``` 34 | just publish-test 35 | ``` 36 | Or with explicit token: 37 | ``` 38 | UV_PUBLISH_TOKEN=your_testpypi_token just publish-test 39 | ``` 40 | 41 | 7. Verify installation from TestPyPI: 42 | ``` 43 | pip uninstall -y fastllm-kit 44 | pip install --index-url https://test.pypi.org/simple/ fastllm-kit 45 | ``` 46 | 47 | 8. Verify imports work correctly: 48 | ```python 49 | from fastllm import RequestBatch, RequestManager 50 | ``` 51 | 52 | 9. If everything works, publish to PyPI: 53 | ``` 54 | just publish 55 | ``` 56 | Or with explicit token: 57 | ``` 58 | UV_PUBLISH_TOKEN=your_pypi_token just publish 59 | ``` 60 | 61 | ## Usage Instructions 62 | 63 | Users will install the package with: 64 | ``` 65 | pip install fastllm-kit 66 | ``` 67 | 68 | But will import it in their code as: 69 | ```python 70 | from fastllm import RequestBatch, RequestManager 71 | ``` 72 | 73 | This allows us to maintain a clean import interface despite the PyPI name being different from the import name. 
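As a final sanity check after installing from TestPyPI or PyPI, a short smoke test (sketch below; the `fastllm-kit` distribution name and the imports come from the steps above, the rest is standard library) confirms the name split works end to end:

```python
# Post-install smoke test: the distribution is published as "fastllm-kit",
# but the package is imported as "fastllm".
from importlib.metadata import version

from fastllm import RequestBatch, RequestManager  # import name stays "fastllm"

print("installed distribution:", version("fastllm-kit"))
print("imports resolved:", RequestBatch.__name__, RequestManager.__name__)
```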
-------------------------------------------------------------------------------- /docs/api-reference.md: -------------------------------------------------------------------------------- 1 | # API Reference 2 | 3 | This document provides detailed information about FastLLM's public APIs. 4 | 5 | ## RequestManager 6 | 7 | ### Class Definition 8 | 9 | ```python 10 | class RequestManager: 11 | def __init__( 12 | self, 13 | provider: Provider[ResponseT], 14 | concurrency: int = 100, 15 | timeout: float = 30.0, 16 | retry_attempts: int = 3, 17 | retry_delay: float = 1.0, 18 | show_progress: bool = True, 19 | caching_provider: Optional[CacheProvider] = None 20 | ) 21 | ``` 22 | 23 | #### Parameters 24 | 25 | - `provider`: LLM provider instance 26 | - `concurrency`: Maximum number of concurrent requests (default: 100) 27 | - `timeout`: Request timeout in seconds (default: 30.0) 28 | - `retry_attempts`: Number of retry attempts (default: 3) 29 | - `retry_delay`: Delay between retries in seconds (default: 1.0) 30 | - `show_progress`: Whether to show progress bar (default: True) 31 | - `caching_provider`: Optional cache provider instance 32 | 33 | ### Methods 34 | 35 | #### process_batch 36 | 37 | ```python 38 | def process_batch( 39 | self, 40 | batch: Union[list[dict[str, Any]], RequestBatch] 41 | ) -> list[ResponseT] 42 | ``` 43 | 44 | Process a batch of LLM requests in parallel. 45 | 46 | **Parameters:** 47 | - `batch`: Either a RequestBatch object or a list of request dictionaries 48 | 49 | **Returns:** 50 | - List of responses in the same order as the requests 51 | 52 | ## Provider 53 | 54 | ### Base Class 55 | 56 | ```python 57 | class Provider(Generic[ResponseT], ABC): 58 | def __init__( 59 | self, 60 | api_key: str, 61 | api_base: str, 62 | headers: Optional[dict[str, str]] = None, 63 | **kwargs: Any 64 | ) 65 | ``` 66 | 67 | #### Abstract Methods 68 | 69 | ```python 70 | @abstractmethod 71 | def get_request_headers(self) -> dict[str, str]: 72 | """Get headers for API requests.""" 73 | pass 74 | 75 | @abstractmethod 76 | async def make_request( 77 | self, 78 | client: httpx.AsyncClient, 79 | request: dict[str, Any], 80 | timeout: float 81 | ) -> ResponseT: 82 | """Make a request to the provider API.""" 83 | pass 84 | ``` 85 | 86 | ### OpenAI Provider 87 | 88 | ```python 89 | class OpenAIProvider(Provider[ChatCompletion]): 90 | def __init__( 91 | self, 92 | api_key: str, 93 | api_base: str = DEFAULT_API_BASE, 94 | organization: Optional[str] = None, 95 | headers: Optional[dict[str, str]] = None, 96 | **kwargs: Any 97 | ) 98 | ``` 99 | 100 | ## Cache System 101 | 102 | ### CacheProvider Interface 103 | 104 | ```python 105 | class CacheProvider: 106 | async def exists(self, key: str) -> bool: 107 | """Check if a key exists in the cache.""" 108 | pass 109 | 110 | async def get(self, key: str): 111 | """Get a value from the cache.""" 112 | pass 113 | 114 | async def put(self, key: str, value) -> None: 115 | """Put a value in the cache.""" 116 | pass 117 | 118 | async def clear(self) -> None: 119 | """Clear all items from the cache.""" 120 | pass 121 | 122 | async def close(self) -> None: 123 | """Close the cache when done.""" 124 | pass 125 | ``` 126 | 127 | ### InMemoryCache 128 | 129 | ```python 130 | class InMemoryCache(CacheProvider): 131 | def __init__(self) 132 | ``` 133 | 134 | Simple in-memory cache implementation using a dictionary. 
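A minimal usage sketch of the `CacheProvider` interface above, using `InMemoryCache` (the key and value here are placeholders; in practice the request IDs returned by `RequestBatch.chat.completions.create()` serve as cache keys):

```python
import asyncio

from fastllm import InMemoryCache


async def main() -> None:
    cache = InMemoryCache()
    await cache.put("example-request-id", {"answer": "4"})  # store a value under a key
    if await cache.exists("example-request-id"):            # membership check
        print(await cache.get("example-request-id"))        # -> {'answer': '4'}
    await cache.clear()                                      # drop all cached entries
    await cache.close()                                      # release resources


asyncio.run(main())
```

The same asynchronous methods apply to `DiskCache` below.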
135 | 136 | ### DiskCache 137 | 138 | ```python 139 | class DiskCache(CacheProvider): 140 | def __init__( 141 | self, 142 | directory: str, 143 | ttl: Optional[int] = None, 144 | **cache_options 145 | ) 146 | ``` 147 | 148 | **Parameters:** 149 | - `directory`: Directory where cache files will be stored 150 | - `ttl`: Time to live in seconds for cached items 151 | - `cache_options`: Additional options for diskcache.Cache 152 | 153 | ## Request Models 154 | 155 | ### Message 156 | 157 | ```python 158 | class Message(BaseModel): 159 | role: Literal["system", "user", "assistant", "function", "tool"] = "user" 160 | content: Optional[str] = None 161 | name: Optional[str] = None 162 | function_call: Optional[dict[str, Any]] = None 163 | tool_calls: Optional[list[dict[str, Any]]] = None 164 | ``` 165 | 166 | #### Class Methods 167 | 168 | ```python 169 | @classmethod 170 | def from_dict(cls, data: Union[str, dict[str, Any]]) -> Message: 171 | """Create a message from a string or dictionary.""" 172 | ``` 173 | 174 | ### LLMRequest 175 | 176 | ```python 177 | class LLMRequest(BaseModel): 178 | provider: str 179 | messages: list[Message] 180 | model: Optional[str] = None 181 | temperature: float = Field(default=0.7, ge=0.0, le=2.0) 182 | max_completion_tokens: Optional[int] = None 183 | top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0) 184 | presence_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0) 185 | frequency_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0) 186 | stop: Optional[list[str]] = None 187 | stream: bool = False 188 | ``` 189 | 190 | #### Class Methods 191 | 192 | ```python 193 | @classmethod 194 | def from_prompt( 195 | cls, 196 | provider: str, 197 | prompt: Union[str, dict[str, Any]], 198 | **kwargs 199 | ) -> LLMRequest: 200 | """Create a request from a single prompt.""" 201 | 202 | @classmethod 203 | def from_dict(cls, data: dict[str, Any]) -> LLMRequest: 204 | """Create a request from a dictionary.""" 205 | ``` 206 | 207 | ## RequestBatch 208 | 209 | ```python 210 | class RequestBatch(AbstractContextManager): 211 | def __init__(self) 212 | ``` 213 | 214 | ### Usage Example 215 | 216 | ```python 217 | with RequestBatch() as batch: 218 | batch.chat.completions.create( 219 | model="gpt-3.5-turbo", 220 | messages=[{"role": "user", "content": "Hello!"}] 221 | ) 222 | ``` 223 | 224 | ### Chat Completions API 225 | 226 | ```python 227 | def create( 228 | self, 229 | *, 230 | model: str, 231 | messages: list[dict[str, str]], 232 | temperature: Optional[float] = 0.7, 233 | top_p: Optional[float] = 1.0, 234 | n: Optional[int] = 1, 235 | stop: Optional[Union[str, list[str]]] = None, 236 | max_completion_tokens: Optional[int] = None, 237 | presence_penalty: Optional[float] = 0.0, 238 | frequency_penalty: Optional[float] = 0.0, 239 | logit_bias: Optional[dict[str, float]] = None, 240 | user: Optional[str] = None, 241 | response_format: Optional[dict[str, str]] = None, 242 | seed: Optional[int] = None, 243 | tools: Optional[list[dict[str, Any]]] = None, 244 | tool_choice: Optional[Union[str, dict[str, str]]] = None, 245 | **kwargs: Any 246 | ) -> str: 247 | """Add a chat completion request to the batch.""" 248 | ``` 249 | 250 | ## Progress Tracking 251 | 252 | ### TokenStats 253 | 254 | ```python 255 | @dataclass 256 | class TokenStats: 257 | prompt_tokens: int = 0 258 | completion_tokens: int = 0 259 | total_tokens: int = 0 260 | requests_completed: int = 0 261 | cache_hits: int = 0 262 | start_time: float = 0.0 263 | ``` 264 | 265 | ### ProgressTracker 
266 | 267 | ```python 268 | class ProgressTracker: 269 | def __init__( 270 | self, 271 | total_requests: int, 272 | show_progress: bool = True 273 | ) 274 | ``` 275 | 276 | ## Response Types 277 | 278 | ### ResponseWrapper 279 | 280 | ```python 281 | class ResponseWrapper(Generic[ResponseT]): 282 | def __init__( 283 | self, 284 | response: ResponseT, 285 | request_id: str, 286 | order_id: int 287 | ) -------------------------------------------------------------------------------- /docs/architecture.md: -------------------------------------------------------------------------------- 1 | # FastLLM Architecture 2 | 3 | This document provides an overview of the FastLLM architecture and design decisions. 4 | 5 | ## Project Structure 6 | 7 | ``` 8 | fastllm/ # Main package directory 9 | ├── __init__.py # Package exports 10 | ├── core.py # Core functionality (RequestBatch, RequestManager) 11 | ├── cache.py # Caching implementations 12 | └── providers/ # LLM provider implementations 13 | ├── __init__.py 14 | ├── base.py # Base Provider class 15 | └── openai.py # OpenAI API implementation 16 | ``` 17 | 18 | ## Package and Distribution 19 | 20 | The package follows these important design decisions: 21 | 22 | - **PyPI Package Name**: `fastllm-kit` (since `fastllm` was already taken on PyPI) 23 | - **Import Name**: `fastllm` (users import with `from fastllm import ...`) 24 | - **GitHub Repository**: Maintained at `github.com/rexhaif/fastllm` 25 | 26 | The import redirection is implemented using Hatchling's build configuration in `pyproject.toml`, which ensures that the package is published as `fastllm-kit` but exposes a top-level `fastllm` module for imports. 27 | 28 | ## System Architecture 29 | 30 | ``` 31 | ┌─────────────────────────────────────────────────────────┐ 32 | │ RequestManager │ 33 | ├─────────────────────────────────────────────────────────┤ 34 | │ - Manages parallel request processing │ 35 | │ - Handles concurrency and batching │ 36 | │ - Coordinates between components │ 37 | └───────────────────┬───────────────────┬─────────────────┘ 38 | │ │ 39 | ┌───────────────▼──────┐ ┌────────▼──────────┐ 40 | │ Provider │ │ CacheProvider │ 41 | ├──────────────────────┤ ├───────────────────┤ 42 | │ - API communication │ │ - Request caching │ 43 | │ - Response parsing │ │ - Cache management│ 44 | └──────────────────────┘ └───────────────────┘ 45 | ``` 46 | 47 | ## Core Components 48 | 49 | ### RequestBatch 50 | 51 | The `RequestBatch` class provides an OpenAI-compatible request interface that allows for batching multiple requests together. 52 | 53 | ### RequestManager 54 | 55 | The `RequestManager` handles parallel processing of requests with configurable concurrency, retries, and caching. 56 | 57 | ### Providers 58 | 59 | Providers implement the interface for different LLM services: 60 | 61 | - `OpenAIProvider`: Handles requests to OpenAI-compatible APIs 62 | 63 | ### Caching 64 | 65 | The caching system supports: 66 | 67 | - `InMemoryCache`: Fast, in-process caching 68 | - `DiskCache`: Persistent disk-based caching with TTL and size limits 69 | 70 | ## Design Decisions 71 | 72 | 1. **Parallel Processing**: Designed to maximize throughput by processing many requests in parallel 73 | 2. **Request Deduplication**: Automatically detects and deduplicates identical requests 74 | 3. **Response Ordering**: Maintains request ordering regardless of completion time 75 | 4. 
**Caching**: Optional caching with customizable providers 76 | 77 | ## Development and Testing 78 | 79 | - Tests are implemented using `pytest` and `pytest-asyncio` 80 | - Code formatting is handled by `ruff` and `black` 81 | - Task automation is handled by `just` 82 | - Dependency management uses `uv` 83 | 84 | ## Key Utilities 85 | 86 | - `compute_request_hash`: Consistent request hashing for caching and deduplication 87 | - `TokenStats`: Tracking token usage and rate limits 88 | - `ProgressTracker`: Visual progress bar with statistics 89 | 90 | ## Request Flow 91 | 92 | 1. Create a `RequestBatch` and add requests via APIs like `batch.chat.completions.create()` 93 | 2. Each request is directly stored in OpenAI Batch format with a `custom_id` that combines the request hash and order ID 94 | 3. Pass the batch to a `RequestManager` for parallel processing 95 | 4. The manager extracts necessary metadata from the `custom_id` and `url` fields 96 | 5. Responses are returned in the original request order 97 | 98 | ## Caching 99 | 100 | The library implements efficient caching: 101 | - Request hashing for consistent cache keys 102 | - Support for both in-memory and persistent caching 103 | - Cache hits bypass network requests 104 | 105 | ## Provider Interface 106 | 107 | The library's provider system is designed to work with the simplified OpenAI Batch format: 108 | - Providers receive only the request body and necessary metadata 109 | - No conversion between formats is required during processing 110 | - Both chat completions and embeddings use the same batch structure with different URLs 111 | - API endpoints are determined automatically based on request type within each provider implementation 112 | 113 | ## Extensions 114 | 115 | The library is designed to be easily extensible: 116 | - Support for multiple LLM providers 117 | - Custom cache implementations 118 | - Flexible request formatting 119 | 120 | ## Key Features 121 | 122 | ### 1. Parallel Processing 123 | - Async/await throughout 124 | - Efficient request batching 125 | - Concurrent API calls 126 | - Progress monitoring 127 | - Support for both chat completions and embeddings 128 | 129 | ### 2. Caching 130 | - Multiple cache backends 131 | - TTL support 132 | - Async operations 133 | - Thread-safe implementation 134 | 135 | ### 3. Rate Limiting 136 | - Token-based limits 137 | - Request frequency limits 138 | - Window-based tracking 139 | - Saturation monitoring 140 | 141 | ### 4. Error Handling 142 | - Consistent error types 143 | - Graceful degradation 144 | - Detailed error messages 145 | - Retry mechanisms 146 | 147 | ## Configuration Points 148 | 149 | The system can be configured through: 150 | 1. Provider settings 151 | - API keys and endpoints 152 | - Organization IDs 153 | - Custom headers 154 | 155 | 2. Cache settings 156 | - Backend selection 157 | - TTL configuration 158 | - Storage directory 159 | - Serialization options 160 | 161 | 3. Request parameters 162 | - Model selection 163 | - Temperature and sampling 164 | - Token limits 165 | - Response streaming 166 | 167 | 4. Rate limiting 168 | - Token rate limits 169 | - Request frequency limits 170 | - Window sizes 171 | - Saturation thresholds 172 | 173 | ## Best Practices 174 | 175 | 1. **Error Handling** 176 | - Use try-except blocks for cache operations 177 | - Handle API errors gracefully 178 | - Provide meaningful error messages 179 | - Implement proper cleanup 180 | 181 | 2. 
**Performance** 182 | - Use appropriate cache backend 183 | - Configure batch sizes 184 | - Monitor rate limits 185 | - Track token usage 186 | 187 | 3. **Security** 188 | - Secure API key handling 189 | - Safe cache storage 190 | - Input validation 191 | - Response sanitization 192 | 193 | 4. **Maintenance** 194 | - Regular cache cleanup 195 | - Monitor disk usage 196 | - Update API versions 197 | - Track deprecations 198 | 199 | ## Data Flow 200 | 201 | 1. **Request Initialization** 202 | ```python 203 | RequestBatch 204 | → Chat Completion or Embedding Request 205 | → Provider-specific Request 206 | ``` 207 | 208 | 2. **Request Processing** 209 | ```python 210 | RequestManager 211 | → Check Cache 212 | → Make API Request if needed 213 | → Update Progress 214 | → Store in Cache 215 | ``` 216 | 217 | 3. **Response Handling** 218 | ```python 219 | API Response 220 | → ResponseWrapper 221 | → Update Statistics 222 | → Return to User 223 | ``` 224 | 225 | ## Design Principles 226 | 227 | 1. **Modularity** 228 | - Clear separation of concerns 229 | - Extensible provider system 230 | - Pluggable cache providers 231 | 232 | 2. **Performance** 233 | - Efficient parallel processing 234 | - Smart resource management 235 | - Optimized caching 236 | 237 | 3. **Reliability** 238 | - Comprehensive error handling 239 | - Automatic retries 240 | - Progress tracking 241 | 242 | 4. **Developer Experience** 243 | - Familiar API patterns 244 | - Clear type hints 245 | - Comprehensive logging 246 | - Structured logging system 247 | 248 | ## Logging System 249 | 250 | The library uses Python's built-in `logging` module for structured logging: 251 | 252 | 1. **Core Components** 253 | - Each module has its own logger (`logging.getLogger(__name__)`) 254 | - Log levels used appropriately (DEBUG, INFO, WARNING, ERROR) 255 | - Critical operations and errors are logged 256 | 257 | 2. **Key Logging Areas** 258 | - Cache operations (read/write errors) 259 | - Request processing status 260 | - Rate limiting events 261 | - Error conditions and exceptions 262 | 263 | 3. **Best Practices** 264 | - Consistent log format 265 | - Meaningful context in log messages 266 | - Error traceability 267 | - Performance impact consideration 268 | 269 | ## Error Handling 270 | 271 | The system implements comprehensive error handling: 272 | - API errors 273 | - Rate limiting 274 | - Timeouts 275 | - Cache failures 276 | - Invalid requests 277 | 278 | Each component includes appropriate error handling and propagation to ensure system stability and reliability. 279 | 280 | ## Testing Strategy 281 | 282 | The test suite is organized by component: 283 | 284 | 1. **Core Tests** (`test_core.py`): 285 | - Request/Response model validation 286 | - RequestManager functionality 287 | - Token statistics tracking 288 | 289 | 2. **Cache Tests** (`test_cache.py`): 290 | - Cache implementation verification 291 | - Request hashing consistency 292 | - Concurrent access handling 293 | 294 | 3. **Provider Tests** (`test_providers.py`): 295 | - Provider interface compliance 296 | - API communication 297 | - Response parsing 298 | 299 | 4. **Integration Tests**: 300 | - End-to-end request flow 301 | - Rate limiting behavior 302 | - Error handling scenarios 303 | 304 | ## Supported APIs 305 | 306 | The library supports the following APIs: 307 | 308 | 1. **Chat Completions** 309 | - Multiple message support 310 | - Tool and function calls 311 | - Streaming responses 312 | - Temperature and top_p sampling 313 | 314 | 2. 
**Embeddings** 315 | - Single or batch text input 316 | - Dimension control 317 | - Format control (float/base64) 318 | - Efficient batch processing 319 | - Compatible with semantic search use cases 320 | -------------------------------------------------------------------------------- /docs/configuration.md: -------------------------------------------------------------------------------- 1 | # Configuration Guide 2 | 3 | This guide covers the various configuration options available in FastLLM and how to use them effectively. 4 | 5 | ## Provider Configuration 6 | 7 | ### OpenAI Provider 8 | 9 | ```python 10 | from fastllm.providers import OpenAIProvider 11 | 12 | provider = OpenAIProvider( 13 | api_key="your-api-key", 14 | api_base="https://api.openai.com/v1", # Optional: custom API endpoint 15 | organization="your-org-id", # Optional: OpenAI organization ID 16 | headers={ # Optional: custom headers 17 | "Custom-Header": "value" 18 | } 19 | ) 20 | ``` 21 | 22 | #### Configuration Options 23 | 24 | | Parameter | Type | Default | Description | 25 | |-----------|------|---------|-------------| 26 | | api_key | str | Required | Your API key | 27 | | api_base | str | OpenAI default | API endpoint URL | 28 | | organization | str | None | Organization ID | 29 | | headers | dict | None | Custom headers | 30 | 31 | ## Request Manager Configuration 32 | 33 | ```python 34 | from fastllm import RequestManager 35 | 36 | manager = RequestManager( 37 | provider=provider, 38 | concurrency=100, 39 | timeout=30.0, 40 | retry_attempts=3, 41 | retry_delay=1.0, 42 | show_progress=True, 43 | caching_provider=None 44 | ) 45 | ``` 46 | 47 | ### Performance Settings 48 | 49 | | Parameter | Type | Default | Description | 50 | |-----------|------|---------|-------------| 51 | | concurrency | int | 100 | Maximum concurrent requests | 52 | | timeout | float | 30.0 | Request timeout in seconds | 53 | | retry_attempts | int | 3 | Number of retry attempts | 54 | | retry_delay | float | 1.0 | Delay between retries | 55 | 56 | ### Progress Display 57 | 58 | ```python 59 | manager = RequestManager( 60 | provider=provider, 61 | show_progress=True # Enable rich progress display 62 | ) 63 | ``` 64 | 65 | The progress display shows: 66 | - Completion percentage 67 | - Request count 68 | - Token usage rates 69 | - Cache hit ratio 70 | - Estimated time remaining 71 | 72 | ## Cache Configuration 73 | 74 | ### In-Memory Cache 75 | 76 | ```python 77 | from fastllm.cache import InMemoryCache 78 | 79 | cache = InMemoryCache() 80 | ``` 81 | 82 | Best for: 83 | - Development and testing 84 | - Small-scale applications 85 | - Temporary caching needs 86 | 87 | ### Disk Cache 88 | 89 | ```python 90 | from fastllm.cache import DiskCache 91 | 92 | cache = DiskCache( 93 | directory="cache", # Cache directory path 94 | ttl=3600, # Cache TTL in seconds 95 | size_limit=1e9 # Cache size limit in bytes 96 | ) 97 | ``` 98 | 99 | #### Configuration Options 100 | 101 | | Parameter | Type | Default | Description | 102 | |-----------|------|---------|-------------| 103 | | directory | str | Required | Cache directory path | 104 | | ttl | int | None | Time-to-live in seconds | 105 | | size_limit | int | None | Maximum cache size | 106 | 107 | ### Cache Integration 108 | 109 | ```python 110 | manager = RequestManager( 111 | provider=provider, 112 | caching_provider=cache 113 | ) 114 | ``` 115 | 116 | ## Request Configuration 117 | 118 | ### Basic Request Settings 119 | 120 | ```python 121 | with RequestBatch() as batch: 122 | 
batch.chat.completions.create( 123 | model="gpt-3.5-turbo", 124 | messages=[{"role": "user", "content": "Hello"}], 125 | temperature=0.7, 126 | max_completion_tokens=100 127 | ) 128 | ``` 129 | 130 | ### Available Parameters 131 | 132 | | Parameter | Type | Default | Description | 133 | |-----------|------|---------|-------------| 134 | | model | str | Required | Model identifier | 135 | | messages | list | Required | Conversation messages | 136 | | temperature | float | 0.7 | Sampling temperature | 137 | | max_completion_tokens | int | None | Max tokens to generate | 138 | | top_p | float | 1.0 | Nucleus sampling parameter | 139 | | presence_penalty | float | 0.0 | Presence penalty | 140 | | frequency_penalty | float | 0.0 | Frequency penalty | 141 | | stop | list[str] | None | Stop sequences | 142 | 143 | ## Advanced Configuration 144 | 145 | ### Custom Chunk Size 146 | 147 | The RequestManager automatically calculates optimal chunk sizes, but you can influence this through the concurrency setting: 148 | 149 | ```python 150 | manager = RequestManager( 151 | provider=provider, 152 | concurrency=50 # Will affect chunk size calculation 153 | ) 154 | ``` 155 | 156 | Chunk size is calculated as: 157 | ```python 158 | chunk_size = min(concurrency * 2, 1000) 159 | ``` 160 | 161 | ### Error Handling Configuration 162 | 163 | ```python 164 | manager = RequestManager( 165 | provider=provider, 166 | retry_attempts=5, # Increase retry attempts 167 | retry_delay=2.0, # Increase delay between retries 168 | timeout=60.0 # Increase timeout 169 | ) 170 | ``` 171 | 172 | ### Custom Headers 173 | 174 | ```python 175 | provider = OpenAIProvider( 176 | api_key="your-api-key", 177 | headers={ 178 | "User-Agent": "CustomApp/1.0", 179 | "X-Custom-Header": "value" 180 | } 181 | ) 182 | ``` 183 | 184 | ## Environment Variables 185 | 186 | FastLLM respects the following environment variables: 187 | 188 | ```bash 189 | OPENAI_API_KEY=your-api-key 190 | OPENAI_ORG_ID=your-org-id 191 | ``` 192 | 193 | ## Best Practices 194 | 195 | ### Concurrency Settings 196 | 197 | - Start with lower concurrency (10-20) and adjust based on performance 198 | - Monitor token usage and API rate limits 199 | - Consider provider-specific rate limits 200 | 201 | ### Cache Configuration 202 | 203 | - Use disk cache for production environments 204 | - Set appropriate TTL based on data freshness requirements 205 | - Monitor cache size and hit ratios 206 | 207 | ### Error Handling 208 | 209 | - Configure retry attempts based on API stability 210 | - Use appropriate timeout values 211 | - Implement proper error handling in your code 212 | 213 | ### Resource Management 214 | 215 | - Close cache providers when done 216 | - Monitor memory usage 217 | - Use appropriate chunk sizes for your use case 218 | 219 | ## Performance Optimization 220 | 221 | ### Caching Strategy 222 | 223 | - Enable caching for repeated requests 224 | - Use appropriate TTL values 225 | - Monitor cache hit ratios 226 | 227 | ### Concurrency Tuning 228 | 229 | - Adjust concurrency based on: 230 | * API rate limits 231 | * System resources 232 | * Response times 233 | 234 | ### Memory Management 235 | 236 | - Use appropriate chunk sizes 237 | - Monitor memory usage 238 | - Clean up resources properly 239 | 240 | ## Monitoring and Debugging 241 | 242 | ### Progress Tracking 243 | 244 | ```python 245 | manager = RequestManager( 246 | provider=provider, 247 | show_progress=True 248 | ) 249 | ``` 250 | 251 | ### Token Usage Monitoring 252 | 253 | Track token usage through the 
TokenStats class: 254 | - Prompt tokens 255 | - Completion tokens 256 | - Token rates 257 | - Cache hit ratios -------------------------------------------------------------------------------- /docs/core-components.md: -------------------------------------------------------------------------------- 1 | # Core Components 2 | 3 | This document provides detailed information about the core components of FastLLM. 4 | 5 | ## RequestManager 6 | 7 | The RequestManager is the central component responsible for handling parallel LLM API requests. 8 | 9 | ```python 10 | RequestManager( 11 | provider: Provider[ResponseT], 12 | concurrency: int = 100, 13 | timeout: float = 30.0, 14 | retry_attempts: int = 3, 15 | retry_delay: float = 1.0, 16 | show_progress: bool = True, 17 | caching_provider: Optional[CacheProvider] = None 18 | ) 19 | ``` 20 | 21 | ### Features 22 | 23 | - **Parallel Processing**: Handles multiple requests concurrently using asyncio 24 | - **Chunked Processing**: Processes requests in optimal chunks based on concurrency 25 | - **Progress Tracking**: Real-time progress monitoring with rich statistics 26 | - **Cache Integration**: Seamless integration with caching providers 27 | - **Error Handling**: Comprehensive error handling with retries 28 | 29 | ### Key Methods 30 | 31 | - `process_batch()`: Main synchronous API for batch processing 32 | - `_process_batch_async()`: Internal async implementation 33 | - `_process_request_async()`: Individual request processing 34 | - `_calculate_chunk_size()`: Dynamic chunk size optimization 35 | 36 | ## Provider System 37 | 38 | The provider system enables integration with different LLM APIs. 39 | 40 | ### Base Provider 41 | 42 | ```python 43 | class Provider(Generic[ResponseT], ABC): 44 | def __init__( 45 | self, 46 | api_key: str, 47 | api_base: str, 48 | headers: Optional[dict[str, str]] = None, 49 | **kwargs: Any 50 | ) 51 | ``` 52 | 53 | #### Key Methods 54 | 55 | - `get_request_url()`: Constructs API endpoint URLs 56 | - `get_request_headers()`: Manages API request headers 57 | - `make_request()`: Handles API communication 58 | 59 | ### OpenAI Provider 60 | 61 | Implements OpenAI-specific functionality: 62 | 63 | ```python 64 | class OpenAIProvider(Provider[ChatCompletion]): 65 | def __init__( 66 | self, 67 | api_key: str, 68 | api_base: str = DEFAULT_API_BASE, 69 | organization: Optional[str] = None, 70 | headers: Optional[dict[str, str]] = None, 71 | **kwargs: Any 72 | ) 73 | ``` 74 | 75 | ## Caching System 76 | 77 | The caching system provides flexible caching solutions. 
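For example, a cache is attached to the `RequestManager` through its `caching_provider` argument (a minimal wiring sketch using only names documented in this guide; the API key and cache directory are placeholders):

```python
from fastllm import OpenAIProvider, RequestManager
from fastllm.cache import DiskCache

provider = OpenAIProvider(api_key="your-api-key")
cache = DiskCache(directory="./cache", ttl=3600)  # persistent cache with a 1-hour TTL

manager = RequestManager(
    provider=provider,
    concurrency=50,          # maximum number of concurrent requests
    caching_provider=cache,  # repeated requests are served from the cache
)
```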
78 | 79 | ### Cache Provider Interface 80 | 81 | ```python 82 | class CacheProvider: 83 | async def exists(self, key: str) -> bool 84 | async def get(self, key: str) 85 | async def put(self, key: str, value) -> None 86 | async def clear(self) -> None 87 | async def close(self) -> None 88 | ``` 89 | 90 | ### Implementations 91 | 92 | #### InMemoryCache 93 | 94 | Simple dictionary-based cache for non-persistent storage: 95 | 96 | ```python 97 | class InMemoryCache(CacheProvider): 98 | def __init__(self) 99 | ``` 100 | 101 | #### DiskCache 102 | 103 | Persistent disk-based cache with TTL support: 104 | 105 | ```python 106 | class DiskCache(CacheProvider): 107 | def __init__( 108 | self, 109 | directory: str, 110 | ttl: Optional[int] = None, 111 | **cache_options 112 | ) 113 | ``` 114 | 115 | ## Request/Response Models 116 | 117 | ### LLMRequest 118 | 119 | Base model for LLM requests: 120 | 121 | ```python 122 | class LLMRequest(BaseModel): 123 | provider: str 124 | messages: list[Message] 125 | model: Optional[str] = None 126 | temperature: float = Field(default=0.7, ge=0.0, le=2.0) 127 | max_completion_tokens: Optional[int] = None 128 | top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0) 129 | presence_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0) 130 | frequency_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0) 131 | stop: Optional[list[str]] = None 132 | stream: bool = False 133 | ``` 134 | 135 | ### Message 136 | 137 | Represents a single message in a conversation: 138 | 139 | ```python 140 | class Message(BaseModel): 141 | role: Literal["system", "user", "assistant", "function", "tool"] = "user" 142 | content: Optional[str] = None 143 | name: Optional[str] = None 144 | function_call: Optional[dict[str, Any]] = None 145 | tool_calls: Optional[list[dict[str, Any]]] = None 146 | ``` 147 | 148 | ### TokenStats 149 | 150 | Tracks token usage and performance metrics: 151 | 152 | ```python 153 | @dataclass 154 | class TokenStats: 155 | prompt_tokens: int = 0 156 | completion_tokens: int = 0 157 | total_tokens: int = 0 158 | requests_completed: int = 0 159 | cache_hits: int = 0 160 | start_time: float = 0.0 161 | ``` 162 | 163 | ### ProgressTracker 164 | 165 | Manages progress display and statistics: 166 | 167 | ```python 168 | class ProgressTracker: 169 | def __init__(self, total_requests: int, show_progress: bool = True) 170 | ``` 171 | 172 | Features: 173 | - Real-time progress bar 174 | - Token usage statistics 175 | - Cache hit ratio tracking 176 | - Performance metrics 177 | 178 | ## Request Batching 179 | 180 | ### RequestBatch 181 | 182 | Provides an OpenAI-like interface for batching requests: 183 | 184 | ```python 185 | class RequestBatch(AbstractContextManager): 186 | def __init__(self) 187 | ``` 188 | 189 | Usage: 190 | ```python 191 | with RequestBatch() as batch: 192 | batch.chat.completions.create( 193 | model="gpt-3.5-turbo", 194 | messages=[{"role": "user", "content": "Hello!"}] 195 | ) 196 | ``` 197 | 198 | ## Component Interactions 199 | 200 | 1. **Request Flow**: 201 | ``` 202 | RequestBatch → RequestManager → Cache Check → Provider → Response 203 | ``` 204 | 205 | 2. **Caching Flow**: 206 | ``` 207 | Request → Hash Computation → Cache Lookup → (Cache Hit/Miss) → Response 208 | ``` 209 | 210 | 3. **Progress Tracking**: 211 | ``` 212 | Request Processing → Stats Update → Progress Display → Completion 213 | ``` 214 | 215 | ## Best Practices 216 | 217 | 1. 
**Request Management**: 218 | - Use appropriate concurrency limits 219 | - Implement proper error handling 220 | - Monitor token usage 221 | 222 | 2. **Caching**: 223 | - Choose appropriate cache provider 224 | - Configure TTL based on needs 225 | - Monitor cache hit ratios 226 | 227 | 3. **Provider Implementation**: 228 | - Handle rate limits properly 229 | - Implement comprehensive error handling 230 | - Follow provider-specific best practices -------------------------------------------------------------------------------- /docs/development-guide.md: -------------------------------------------------------------------------------- 1 | # Development Guide 2 | 3 | This guide provides information for developers who want to contribute to FastLLM or extend its functionality. 4 | 5 | ## Development Setup 6 | 7 | ### Prerequisites 8 | 9 | - Python 3.9 or higher 10 | - uv for dependency management 11 | - Git for version control 12 | 13 | ### Setting Up Development Environment 14 | 15 | 1. Clone the repository: 16 | ```bash 17 | git clone https://github.com/Rexhaif/fastllm.git 18 | cd fastllm 19 | ``` 20 | 21 | 2. Install dependencies: 22 | ```bash 23 | uv pip install --with dev 24 | ``` 25 | 26 | 3. Activate virtual environment: 27 | ```bash 28 | uv venv 29 | source .venv/bin/activate # On Unix/macOS 30 | # or 31 | .venv\Scripts\activate # On Windows 32 | ``` 33 | 34 | ## Project Structure 35 | 36 | ``` 37 | fastllm/ 38 | ├── fastllm/ 39 | │ ├── __init__.py 40 | │ ├── cache.py # Caching system 41 | │ ├── cli.py # CLI interface 42 | │ ├── core.py # Core functionality 43 | │ └── providers/ # Provider implementations 44 | │ ├── __init__.py 45 | │ ├── base.py # Base provider class 46 | │ └── openai.py # OpenAI provider 47 | ├── tests/ # Test suite 48 | ├── examples/ # Example code 49 | ├── pyproject.toml # Project configuration 50 | └── README.md # Project documentation 51 | ``` 52 | 53 | ## Adding a New Provider 54 | 55 | To add support for a new LLM provider: 56 | 57 | 1. Create a new file in `fastllm/providers/`: 58 | 59 | ```python 60 | # fastllm/providers/custom_provider.py 61 | 62 | from typing import Any, Optional 63 | import httpx 64 | from fastllm.providers.base import Provider 65 | from fastllm.core import ResponseT 66 | 67 | class CustomProvider(Provider[ResponseT]): 68 | def __init__( 69 | self, 70 | api_key: str, 71 | api_base: str, 72 | headers: Optional[dict[str, str]] = None, 73 | **kwargs: Any 74 | ): 75 | super().__init__(api_key, api_base, headers, **kwargs) 76 | # Add provider-specific initialization 77 | 78 | def get_request_headers(self) -> dict[str, str]: 79 | return { 80 | "Authorization": f"Bearer {self.api_key}", 81 | "Content-Type": "application/json", 82 | **self.headers 83 | } 84 | 85 | async def make_request( 86 | self, 87 | client: httpx.AsyncClient, 88 | request: dict[str, Any], 89 | timeout: float 90 | ) -> ResponseT: 91 | # Implement provider-specific request handling 92 | url = self.get_request_url("endpoint") 93 | response = await client.post( 94 | url, 95 | headers=self.get_request_headers(), 96 | json=request, 97 | timeout=timeout 98 | ) 99 | response.raise_for_status() 100 | return response.json() 101 | ``` 102 | 103 | 2. 
Add tests for the new provider: 104 | 105 | ```python 106 | # tests/test_custom_provider.py 107 | 108 | import pytest 109 | from fastllm.providers.custom_provider import CustomProvider 110 | 111 | @pytest.fixture 112 | def provider(): 113 | return CustomProvider( 114 | api_key="test-key", 115 | api_base="https://api.example.com" 116 | ) 117 | 118 | async def test_make_request(provider): 119 | # Implement provider tests 120 | pass 121 | ``` 122 | 123 | ## Adding a New Cache Provider 124 | 125 | To implement a new cache provider: 126 | 127 | 1. Create a new cache implementation: 128 | 129 | ```python 130 | from fastllm.cache import CacheProvider 131 | 132 | class CustomCache(CacheProvider): 133 | def __init__(self, **options): 134 | # Initialize your cache 135 | 136 | async def exists(self, key: str) -> bool: 137 | # Implement key existence check 138 | 139 | async def get(self, key: str): 140 | # Implement value retrieval 141 | 142 | async def put(self, key: str, value) -> None: 143 | # Implement value storage 144 | 145 | async def clear(self) -> None: 146 | # Implement cache clearing 147 | 148 | async def close(self) -> None: 149 | # Implement cleanup 150 | ``` 151 | 152 | 2. Add tests for the new cache: 153 | 154 | ```python 155 | # tests/test_custom_cache.py 156 | 157 | import pytest 158 | from your_cache_module import CustomCache 159 | 160 | @pytest.fixture 161 | async def cache(): 162 | cache = CustomCache() 163 | yield cache 164 | await cache.close() 165 | 166 | async def test_cache_operations(cache): 167 | # Implement cache tests 168 | pass 169 | ``` 170 | 171 | ## Testing 172 | 173 | ### Running Tests 174 | 175 | ```bash 176 | # Run all tests 177 | pytest 178 | 179 | # Run with coverage 180 | pytest --cov=fastllm 181 | 182 | # Run specific test file 183 | pytest tests/test_specific.py 184 | ``` 185 | 186 | ### Writing Tests 187 | 188 | 1. Use pytest fixtures for common setup: 189 | 190 | ```python 191 | @pytest.fixture 192 | def request_manager(provider, cache): 193 | return RequestManager( 194 | provider=provider, 195 | caching_provider=cache 196 | ) 197 | ``` 198 | 199 | 2. Test async functionality: 200 | 201 | ```python 202 | @pytest.mark.asyncio 203 | async def test_async_function(): 204 | # Test implementation 205 | pass 206 | ``` 207 | 208 | 3. Use mocking when appropriate: 209 | 210 | ```python 211 | from unittest.mock import Mock, patch 212 | 213 | def test_with_mock(): 214 | with patch('module.function') as mock_func: 215 | # Test implementation 216 | pass 217 | ``` 218 | 219 | ## Code Style 220 | 221 | ### Style Guide 222 | 223 | - Follow PEP 8 guidelines 224 | - Use type hints 225 | - Document public APIs 226 | - Keep functions focused and small 227 | 228 | ### Code Formatting 229 | 230 | The project uses: 231 | - Black for code formatting 232 | - Ruff for linting 233 | - isort for import sorting 234 | 235 | ```bash 236 | # Format code 237 | black . 238 | 239 | # Sort imports 240 | isort . 241 | 242 | # Run linter 243 | ruff check . 244 | ``` 245 | 246 | ## Documentation 247 | 248 | ### Docstring Format 249 | 250 | Use Google-style docstrings: 251 | 252 | ```python 253 | def function(param1: str, param2: int) -> bool: 254 | """Short description. 255 | 256 | Longer description if needed. 
257 | 258 | Args: 259 | param1: Parameter description 260 | param2: Parameter description 261 | 262 | Returns: 263 | Description of return value 264 | 265 | Raises: 266 | ValueError: Description of when this error occurs 267 | """ 268 | ``` 269 | 270 | ### Building Documentation 271 | 272 | The project uses Markdown for documentation: 273 | 274 | 1. Place documentation in `docs/` 275 | 2. Use clear, concise language 276 | 3. Include code examples 277 | 4. Keep documentation up to date 278 | 279 | ## Performance Considerations 280 | 281 | When developing new features: 282 | 283 | 1. **Concurrency** 284 | - Use async/await properly 285 | - Avoid blocking operations 286 | - Handle resources correctly 287 | 288 | 2. **Memory Usage** 289 | - Monitor memory consumption 290 | - Clean up resources 291 | - Use appropriate data structures 292 | 293 | 3. **Caching** 294 | - Implement efficient caching 295 | - Handle cache invalidation 296 | - Consider memory vs. speed tradeoffs 297 | 298 | ## Error Handling 299 | 300 | 1. Use appropriate exception types: 301 | 302 | ```python 303 | class CustomProviderError(Exception): 304 | """Base exception for custom provider.""" 305 | pass 306 | 307 | class CustomProviderAuthError(CustomProviderError): 308 | """Authentication error for custom provider.""" 309 | pass 310 | ``` 311 | 312 | 2. Implement proper error handling: 313 | 314 | ```python 315 | async def make_request(self, ...): 316 | try: 317 | response = await self._make_api_call() 318 | except httpx.TimeoutException as e: 319 | raise CustomProviderError(f"API timeout: {e}") 320 | except httpx.HTTPError as e: 321 | raise CustomProviderError(f"HTTP error: {e}") 322 | ``` 323 | 324 | ## Contributing 325 | 326 | 1. Fork the repository 327 | 2. Create a feature branch 328 | 3. Make your changes 329 | 4. Add tests 330 | 5. Update documentation 331 | 6. Submit a pull request 332 | 333 | ### Commit Guidelines 334 | 335 | - Use clear commit messages 336 | - Reference issues when applicable 337 | - Keep commits focused and atomic 338 | 339 | ### Pull Request Process 340 | 341 | 1. Update documentation 342 | 2. Add tests 343 | 3. Ensure CI passes 344 | 4. Request review 345 | 5. Address feedback 346 | 347 | ## Release Process 348 | 349 | 1. Update version in pyproject.toml 350 | 2. Update CHANGELOG.md 351 | 3. Create release commit 352 | 4. Tag release 353 | 5. Push to repository 354 | 6. Build and publish to PyPI -------------------------------------------------------------------------------- /docs/getting-started.md: -------------------------------------------------------------------------------- 1 | # Getting Started with FastLLM 2 | 3 | This guide will help you get started with FastLLM for efficient parallel LLM API requests. 
4 | 5 | ## Installation 6 | 7 | Install FastLLM using pip: 8 | 9 | ```bash 10 | pip install fastllm 11 | ``` 12 | 13 | ## Quick Start 14 | 15 | Here's a simple example to get you started: 16 | 17 | ```python 18 | from fastllm import RequestManager 19 | from fastllm.providers import OpenAIProvider 20 | 21 | # Initialize the provider 22 | provider = OpenAIProvider( 23 | api_key="your-api-key", 24 | organization="your-org-id" # Optional 25 | ) 26 | 27 | # Create request manager 28 | manager = RequestManager( 29 | provider=provider, 30 | concurrency=10, # Number of concurrent requests 31 | show_progress=True # Show progress bar 32 | ) 33 | 34 | # Create a batch of requests 35 | from fastllm import RequestBatch 36 | 37 | with RequestBatch() as batch: 38 | # Add multiple requests to the batch 39 | batch.chat.completions.create( 40 | model="gpt-3.5-turbo", 41 | messages=[{"role": "user", "content": "Hello!"}] 42 | ) 43 | batch.chat.completions.create( 44 | model="gpt-3.5-turbo", 45 | messages=[{"role": "user", "content": "How are you?"}] 46 | ) 47 | 48 | # Process the batch 49 | responses = manager.process_batch(batch) 50 | 51 | # Work with responses 52 | for response in responses: 53 | print(response.content) 54 | ``` 55 | 56 | ## Adding Caching 57 | 58 | Enable caching to improve performance and reduce API calls: 59 | 60 | ```python 61 | from fastllm.cache import DiskCache 62 | 63 | # Initialize disk cache 64 | cache = DiskCache( 65 | directory="cache", # Cache directory 66 | ttl=3600 # Cache TTL in seconds 67 | ) 68 | 69 | # Create request manager with cache 70 | manager = RequestManager( 71 | provider=provider, 72 | concurrency=10, 73 | caching_provider=cache, 74 | show_progress=True 75 | ) 76 | ``` 77 | 78 | ## Processing Large Batches 79 | 80 | For large batches of requests: 81 | 82 | ```python 83 | # Create many requests 84 | with RequestBatch() as batch: 85 | for i in range(1000): 86 | batch.chat.completions.create( 87 | model="gpt-3.5-turbo", 88 | messages=[ 89 | {"role": "user", "content": f"Process item {i}"} 90 | ] 91 | ) 92 | 93 | # Process with automatic chunking and progress tracking 94 | responses = manager.process_batch(batch) 95 | ``` 96 | 97 | ## Custom Configuration 98 | 99 | ### Timeout and Retry Settings 100 | 101 | ```python 102 | manager = RequestManager( 103 | provider=provider, 104 | timeout=30.0, # Request timeout in seconds 105 | retry_attempts=3, # Number of retry attempts 106 | retry_delay=1.0, # Delay between retries 107 | ) 108 | ``` 109 | 110 | ### Provider Configuration 111 | 112 | ```python 113 | provider = OpenAIProvider( 114 | api_key="your-api-key", 115 | api_base="https://api.openai.com/v1", # Custom API endpoint 116 | headers={ 117 | "Custom-Header": "value" 118 | } 119 | ) 120 | ``` 121 | 122 | ## Progress Tracking 123 | 124 | FastLLM provides detailed progress information: 125 | 126 | ```python 127 | manager = RequestManager( 128 | provider=provider, 129 | show_progress=True # Enable progress bar 130 | ) 131 | 132 | # Progress will show: 133 | # - Completion percentage 134 | # - Token usage rates 135 | # - Cache hit ratio 136 | # - Estimated time remaining 137 | ``` 138 | 139 | ## Error Handling 140 | 141 | ```python 142 | try: 143 | responses = manager.process_batch(batch) 144 | except Exception as e: 145 | print(f"Error processing batch: {e}") 146 | ``` 147 | 148 | ## Advanced Usage 149 | 150 | ### Custom Message Formatting 151 | 152 | ```python 153 | from fastllm import Message 154 | 155 | # Create structured messages 156 | messages = [ 157 | 
Message(role="system", content="You are a helpful assistant"), 158 | Message(role="user", content="Hello!") 159 | ] 160 | 161 | with RequestBatch() as batch: 162 | batch.chat.completions.create( 163 | model="gpt-3.5-turbo", 164 | messages=messages 165 | ) 166 | ``` 167 | 168 | ### Working with Response Data 169 | 170 | ```python 171 | for response in responses: 172 | # Access response content 173 | print(response.content) 174 | 175 | # Access usage statistics 176 | if response.usage: 177 | print(f"Prompt tokens: {response.usage.prompt_tokens}") 178 | print(f"Completion tokens: {response.usage.completion_tokens}") 179 | print(f"Total tokens: {response.usage.total_tokens}") 180 | ``` 181 | 182 | ## Best Practices 183 | 184 | 1. **Batch Processing** 185 | - Group related requests together 186 | - Use appropriate concurrency limits 187 | - Enable progress tracking for large batches 188 | 189 | 2. **Caching** 190 | - Use disk cache for persistence 191 | - Set appropriate TTL values 192 | - Monitor cache hit ratios 193 | 194 | 3. **Error Handling** 195 | - Implement proper error handling 196 | - Use retry mechanisms 197 | - Monitor token usage 198 | 199 | 4. **Resource Management** 200 | - Close cache providers when done 201 | - Monitor memory usage 202 | - Use appropriate chunk sizes 203 | 204 | ## Next Steps 205 | 206 | - Read the [Architecture](architecture.md) documentation 207 | - Explore [Core Components](core-components.md) 208 | - Check the [API Reference](api-reference.md) -------------------------------------------------------------------------------- /docs/overview.md: -------------------------------------------------------------------------------- 1 | # FastLLM Overview 2 | 3 | FastLLM is a powerful Python library that enables efficient parallel processing of Large Language Model (LLM) API requests. It provides a robust solution for applications that need to make multiple LLM API calls simultaneously while managing resources effectively. 4 | 5 | ## Key Features 6 | 7 | ### 1. Parallel Request Processing 8 | - Efficient handling of concurrent API requests 9 | - Configurable concurrency limits 10 | - Built-in request batching and chunking 11 | - Progress tracking with detailed statistics 12 | 13 | ### 2. Intelligent Caching 14 | - Multiple cache provider options (In-memory and Disk-based) 15 | - Consistent request hashing for cache keys 16 | - Configurable TTL for cached responses 17 | - Automatic cache management 18 | 19 | ### 3. Multiple Provider Support 20 | - Modular provider architecture 21 | - Built-in OpenAI provider implementation 22 | - Extensible base classes for adding new providers 23 | - Provider-agnostic request/response models 24 | 25 | ### 4. Resource Management 26 | - Automatic rate limiting 27 | - Configurable timeout handling 28 | - Retry mechanism with exponential backoff 29 | - Connection pooling 30 | 31 | ### 5. Progress Monitoring 32 | - Real-time progress tracking 33 | - Token usage statistics 34 | - Cache hit ratio monitoring 35 | - Performance metrics (tokens/second) 36 | 37 | ### 6. Developer-Friendly Interface 38 | - OpenAI-like API design 39 | - Type hints throughout the codebase 40 | - Comprehensive error handling 41 | - Async/await support 42 | 43 | ## Use Cases 44 | 45 | FastLLM is particularly useful for: 46 | 47 | 1. **Batch Processing**: Processing large numbers of LLM requests efficiently 48 | - Data analysis pipelines 49 | - Content generation at scale 50 | - Bulk text processing 51 | 52 | 2. 
**High-Performance Applications**: Applications requiring optimal LLM API usage 53 | - Real-time applications 54 | - Interactive systems 55 | - API proxies and middleware 56 | 57 | 3. **Resource-Conscious Systems**: Systems that need to optimize API usage 58 | - Cost optimization through caching 59 | - Rate limit management 60 | - Token usage optimization 61 | 62 | 4. **Development and Testing**: Tools for LLM application development 63 | - Rapid prototyping 64 | - Testing and benchmarking 65 | - Performance optimization 66 | 67 | ## Performance Considerations 68 | 69 | FastLLM is designed with performance in mind: 70 | 71 | - **Parallel Processing**: Efficiently processes multiple requests concurrently 72 | - **Smart Batching**: Automatically chunks requests for optimal throughput 73 | - **Cache Optimization**: Reduces API calls through intelligent caching 74 | - **Resource Management**: Prevents overload through concurrency control 75 | - **Memory Efficiency**: Manages memory usage through chunked processing -------------------------------------------------------------------------------- /docs/rate-limiting.md: -------------------------------------------------------------------------------- 1 | # Rate Limiting Design 2 | 3 | ## Overview 4 | 5 | This document outlines the design for implementing a flexible rate limiting system for FastLLM providers. The system will support both request-based and token-based rate limits with human-readable time units. 6 | 7 | ## Requirements 8 | 9 | 1. Support multiple rate limits per provider 10 | 2. Handle both request-per-time-unit and tokens-per-time-unit limits 11 | 3. Support human-readable time unit specifications (e.g., "10s", "1m", "1h", "1d") 12 | 4. Track token usage from provider responses 13 | 5. Integrate with existing Provider and RequestManager architecture 14 | 15 | ## System Design 16 | 17 | ### Rate Limit Configuration 18 | 19 | Rate limits will be specified in the provider configuration using a list structure: 20 | 21 | ```python 22 | rate_limits = [ 23 | { 24 | "type": "requests", 25 | "limit": 100, 26 | "period": "1m" 27 | }, 28 | { 29 | "type": "tokens", 30 | "limit": 100000, 31 | "period": "1h" 32 | } 33 | ] 34 | ``` 35 | 36 | ### Components 37 | 38 | #### 1. RateLimitManager 39 | 40 | A new component responsible for tracking and enforcing rate limits: 41 | 42 | ```python 43 | class RateLimitManager: 44 | def __init__(self, limits: List[RateLimit]): 45 | self.limits = limits 46 | self.windows = {} # Tracks usage windows 47 | 48 | async def check_limits(self) -> bool: 49 | # Check all limits 50 | 51 | async def record_usage(self, usage: Usage): 52 | # Record request and token usage 53 | ``` 54 | 55 | #### 2. RateLimit Models 56 | 57 | ```python 58 | @dataclass 59 | class RateLimit: 60 | type: Literal["requests", "tokens"] 61 | limit: int 62 | period: str # Human-readable period 63 | 64 | @dataclass 65 | class Usage: 66 | request_count: int 67 | token_count: Optional[int] 68 | timestamp: datetime 69 | ``` 70 | 71 | #### 3. TimeWindow 72 | 73 | Handles the sliding window implementation for rate limiting: 74 | 75 | ```python 76 | class TimeWindow: 77 | def __init__(self, period: str): 78 | self.period = self._parse_period(period) 79 | self.usage_records = [] 80 | 81 | def add_usage(self, usage: Usage): 82 | # Add usage and cleanup old records 83 | 84 | def current_usage(self) -> int: 85 | # Calculate current usage in window 86 | ``` 87 | 88 | ### Integration with Existing Architecture 89 | 90 | #### 1. 
Provider Base Class Enhancement 91 | 92 | ```python 93 | class Provider: 94 | def __init__(self, rate_limits: List[Dict]): 95 | self.rate_limit_manager = RateLimitManager( 96 | [RateLimit(**limit) for limit in rate_limits] 97 | ) 98 | 99 | async def _check_rate_limits(self): 100 | # Check before making requests 101 | 102 | async def _record_usage(self, response): 103 | # Record after successful requests 104 | ``` 105 | 106 | #### 2. RequestManager Integration 107 | 108 | The RequestManager will be enhanced to: 109 | - Check rate limits before processing requests 110 | - Handle rate limit errors gracefully 111 | - Implement backoff strategies when limits are reached 112 | 113 | ### Time Unit Parsing 114 | 115 | Time unit parsing will support: 116 | - Seconds: "s", "sec", "second", "seconds" 117 | - Minutes: "m", "min", "minute", "minutes" 118 | - Hours: "h", "hr", "hour", "hours" 119 | - Days: "d", "day", "days" 120 | 121 | Example implementation: 122 | 123 | ```python 124 | def parse_time_unit(period: str) -> timedelta: 125 | match = re.match(r"(\d+)([smhd])", period) 126 | if not match: 127 | raise ValueError("Invalid time unit format") 128 | 129 | value, unit = match.groups() 130 | value = int(value) 131 | 132 | return { 133 | 's': timedelta(seconds=value), 134 | 'm': timedelta(minutes=value), 135 | 'h': timedelta(hours=value), 136 | 'd': timedelta(days=value) 137 | }[unit] 138 | ``` 139 | 140 | ## Usage Example 141 | 142 | ```python 143 | # Provider configuration 144 | openai_provider = OpenAIProvider( 145 | api_key="...", 146 | rate_limits=[ 147 | { 148 | "type": "requests", 149 | "limit": 100, 150 | "period": "1m" 151 | }, 152 | { 153 | "type": "tokens", 154 | "limit": 100000, 155 | "period": "1h" 156 | } 157 | ] 158 | ) 159 | 160 | # Usage in code 161 | async with openai_provider as provider: 162 | response = await provider.complete(prompt) 163 | # Rate limits are automatically checked and updated 164 | ``` 165 | 166 | ## Error Handling 167 | 168 | 1. Rate Limit Exceeded 169 | ```python 170 | class RateLimitExceeded(Exception): 171 | def __init__(self, limit_type: str, retry_after: float): 172 | self.limit_type = limit_type 173 | self.retry_after = retry_after 174 | ``` 175 | 176 | 2. Recovery Strategy 177 | - Implement exponential backoff 178 | - Queue requests when limits are reached 179 | - Provide retry-after information to clients 180 | 181 | ## Implementation Phases 182 | 183 | 1. Phase 1: Basic Implementation 184 | - Implement RateLimitManager 185 | - Add time unit parsing 186 | - Integrate with Provider base class 187 | 188 | 2. Phase 2: Enhanced Features 189 | - Add token tracking 190 | - Implement sliding windows 191 | - Add multiple limit support 192 | 193 | 3. Phase 3: Optimization 194 | - Add request queuing 195 | - Implement smart backoff strategies 196 | - Add monitoring and metrics 197 | 198 | ## Monitoring and Metrics 199 | 200 | The rate limiting system will expose metrics for: 201 | - Current usage per window 202 | - Remaining quota 203 | - Rate limit hits 204 | - Average usage patterns 205 | 206 | These metrics can be integrated with the existing monitoring system. 207 | 208 | ## Future Considerations 209 | 210 | 1. Distributed Rate Limiting 211 | - Support for Redis-based rate limiting 212 | - Cluster-aware rate limiting 213 | 214 | 2. Dynamic Rate Limits 215 | - Allow providers to update limits based on API responses 216 | - Support for dynamic quota adjustments 217 | 218 | 3. 
Rate Limit Optimization 219 | - Predictive rate limiting 220 | - Smart request scheduling -------------------------------------------------------------------------------- /examples/embedding_example.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Test script for embeddings API with parallel processing and caching.""" 3 | 4 | import json 5 | import os 6 | from pathlib import Path 7 | from typing import Any, List, Optional, Union, Literal 8 | import asyncio 9 | import numpy as np 10 | 11 | import typer 12 | from rich.console import Console 13 | from rich.panel import Panel 14 | from rich.table import Table 15 | from dotenv import load_dotenv 16 | 17 | from fastllm import ( 18 | RequestBatch, RequestManager, 19 | ResponseWrapper, InMemoryCache, DiskCache, 20 | OpenAIProvider 21 | ) 22 | 23 | # Default values for command options 24 | DEFAULT_MODEL = "text-embedding-3-small" 25 | DEFAULT_DIMENSIONS = 384 26 | DEFAULT_CONCURRENCY = 10 27 | DEFAULT_OUTPUT = "embedding_results.json" 28 | 29 | app = typer.Typer() 30 | 31 | load_dotenv() 32 | 33 | # Sample texts for embedding 34 | DEFAULT_TEXTS = [ 35 | f"The quick brown fox jumps over the lazy dog {i}" for i in range(100) 36 | ] 37 | 38 | 39 | def cosine_similarity(a, b): 40 | """Calculate cosine similarity between two vectors.""" 41 | return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)) 42 | 43 | 44 | def normalize_l2(x): 45 | """L2 normalize a vector or batch of vectors.""" 46 | x = np.array(x) 47 | if x.ndim == 1: 48 | norm = np.linalg.norm(x) 49 | if norm == 0: 50 | return x 51 | return x / norm 52 | else: 53 | norm = np.linalg.norm(x, 2, axis=1, keepdims=True) 54 | return np.where(norm == 0, x, x / norm) 55 | 56 | 57 | def process_response( 58 | response: ResponseWrapper, index: int 59 | ) -> dict[str, Any]: 60 | """Process an embedding response into a serializable format.""" 61 | if not response.is_embedding_response: 62 | return { 63 | "index": index, 64 | "type": "error", 65 | "error": "Not an embedding response" 66 | } 67 | 68 | return { 69 | "index": index, 70 | "type": "success", 71 | "request_id": response.request_id, 72 | "model": response.response.get("model"), 73 | "usage": response.response.get("usage"), 74 | "embedding_count": len(response.embeddings), 75 | "embedding_dimensions": [len(emb) for emb in response.embeddings], 76 | } 77 | 78 | 79 | def create_embedding_batch( 80 | texts: List[str], 81 | model: str, 82 | dimensions: Optional[int] = None, 83 | encoding_format: Optional[str] = None 84 | ) -> RequestBatch: 85 | """Create a batch of embedding requests.""" 86 | batch = RequestBatch() 87 | 88 | # Add individual text embedding requests 89 | for text in texts: 90 | batch.embeddings.create( 91 | model=model, 92 | input=text, 93 | dimensions=dimensions, 94 | encoding_format=encoding_format, 95 | ) 96 | 97 | # Add a batch request with multiple texts (last 3 texts) 98 | if len(texts) >= 3: 99 | batch.embeddings.create( 100 | model=model, 101 | input=texts[-3:], 102 | dimensions=dimensions, 103 | encoding_format=encoding_format, 104 | ) 105 | 106 | return batch 107 | 108 | 109 | def run_test( 110 | *, # Force keyword arguments 111 | api_key: str, 112 | model: str, 113 | texts: List[str], 114 | dimensions: Optional[int], 115 | encoding_format: Optional[str], 116 | concurrency: int, 117 | output: Path | str, 118 | no_progress: bool = False, 119 | cache_type: Literal["memory", "disk"] = "memory", 120 | cache_ttl: Optional[int] = None, 121 | ) -> None: 122 | 
"""Run the embedding test with given parameters.""" 123 | console = Console() 124 | 125 | # Create batch of requests 126 | batch = create_embedding_batch(texts, model, dimensions, encoding_format) 127 | 128 | # Show configuration 129 | console.print( 130 | Panel.fit( 131 | "\n".join( 132 | [ 133 | f"Model: {model}", 134 | f"Dimensions: {dimensions or 'default'}", 135 | f"Format: {encoding_format or 'default'}", 136 | f"Texts: {len(texts)}", 137 | f"Requests: {len(batch)}", 138 | f"Concurrency: {concurrency}", 139 | ] 140 | ), 141 | title="[bold blue]Embedding Test Configuration", 142 | ) 143 | ) 144 | 145 | # Create cache provider based on type 146 | if cache_type == "memory": 147 | cache_provider = InMemoryCache() 148 | else: 149 | cache_provider = DiskCache( 150 | directory="./cache", 151 | ttl=cache_ttl, 152 | size_limit=int(2e9), # 2GB size limit 153 | ) 154 | 155 | provider = OpenAIProvider( 156 | api_key=api_key, 157 | ) 158 | manager = RequestManager( 159 | provider=provider, 160 | concurrency=concurrency, 161 | show_progress=not no_progress, 162 | caching_provider=cache_provider, 163 | ) 164 | 165 | try: 166 | # First run: Process batch 167 | console.print("[bold blue]Starting first run (no cache)...[/bold blue]") 168 | responses_first = manager.process_batch(batch) 169 | 170 | # Extract and process results 171 | all_embeddings = [] 172 | results_data_first = [] 173 | 174 | for i, response in enumerate(responses_first): 175 | result = process_response(response, i) 176 | results_data_first.append(result) 177 | 178 | if response.is_embedding_response: 179 | embeddings = response.embeddings 180 | for embedding in embeddings: 181 | all_embeddings.append(embedding) 182 | 183 | # Summary table for first run 184 | console.print("\n[bold green]First Run Results:[/bold green]") 185 | table = Table(show_header=True) 186 | table.add_column("Request #") 187 | table.add_column("Model") 188 | table.add_column("Embeddings") 189 | table.add_column("Dimensions") 190 | table.add_column("Tokens") 191 | 192 | for i, result in enumerate(results_data_first): 193 | if result["type"] == "success": 194 | dims = ", ".join([str(d) for d in result["embedding_dimensions"]]) 195 | table.add_row( 196 | f"{i+1}", 197 | result["model"], 198 | str(result["embedding_count"]), 199 | dims, 200 | str(result["usage"]["prompt_tokens"]) 201 | ) 202 | 203 | console.print(table) 204 | 205 | # Compare some embeddings 206 | if len(all_embeddings) >= 2: 207 | console.print("\n[bold blue]Embedding Similarities:[/bold blue]") 208 | similarity_table = Table(show_header=True) 209 | similarity_table.add_column("Embedding Pair") 210 | similarity_table.add_column("Cosine Similarity") 211 | 212 | # Compare a few pairs (not all combinations to keep output manageable) 213 | num_comparisons = min(5, len(all_embeddings) * (len(all_embeddings) - 1) // 2) 214 | compared = set() 215 | comparison_count = 0 216 | 217 | for i in range(len(all_embeddings)): 218 | for j in range(i+1, len(all_embeddings)): 219 | if comparison_count >= num_comparisons: 220 | break 221 | if (i, j) not in compared: 222 | sim = cosine_similarity(all_embeddings[i], all_embeddings[j]) 223 | similarity_table.add_row(f"{i+1} and {j+1}", f"{sim:.4f}") 224 | compared.add((i, j)) 225 | comparison_count += 1 226 | 227 | console.print(similarity_table) 228 | 229 | # Second run: Process the same batch, expecting cached results 230 | console.print("\n[bold blue]Starting second run (cached)...[/bold blue]") 231 | responses_second = manager.process_batch(batch) 232 | 
results_data_second = [] 233 | 234 | for i, response in enumerate(responses_second): 235 | result = process_response(response, i) 236 | results_data_second.append(result) 237 | 238 | # Compare cache performance 239 | console.print( 240 | Panel.fit( 241 | "\n".join( 242 | [ 243 | f"First Run - Successful: [green]{sum(1 for r in results_data_first if r['type'] == 'success')}[/green]", 244 | f"Second Run - Successful: [green]{sum(1 for r in results_data_second if r['type'] == 'success')}[/green]", 245 | f"Total Requests: {len(batch)}", 246 | ] 247 | ), 248 | title="[bold green]Cache Performance", 249 | ) 250 | ) 251 | 252 | # Save results from both runs 253 | if output != "NO_OUTPUT": 254 | output_path = Path(output) 255 | output_path.write_text( 256 | json.dumps( 257 | { 258 | "config": { 259 | "model": model, 260 | "dimensions": dimensions, 261 | "encoding_format": encoding_format, 262 | "concurrency": concurrency, 263 | "cache_type": cache_type, 264 | "cache_ttl": cache_ttl, 265 | }, 266 | "first_run_results": results_data_first, 267 | "second_run_results": results_data_second, 268 | }, 269 | indent=2, 270 | ) 271 | ) 272 | console.print(f"\nResults saved to [bold]{output_path}[/bold]") 273 | 274 | finally: 275 | # Clean up disk cache if used 276 | if cache_type == "disk": 277 | # Run close in asyncio event loop 278 | asyncio.run(cache_provider.close()) 279 | 280 | 281 | @app.command() 282 | def main( 283 | model: str = typer.Option( 284 | DEFAULT_MODEL, 285 | "--model", 286 | "-m", 287 | help="Embedding model to use", 288 | ), 289 | dimensions: Optional[int] = typer.Option( 290 | DEFAULT_DIMENSIONS, 291 | "--dimensions", 292 | "-d", 293 | help="Dimensions for the embeddings (None uses model default)", 294 | ), 295 | encoding_format: Optional[str] = typer.Option( 296 | None, 297 | "--format", 298 | "-f", 299 | help="Encoding format (float or base64)", 300 | ), 301 | concurrency: int = typer.Option( 302 | DEFAULT_CONCURRENCY, 303 | "--concurrency", 304 | "-c", 305 | help="Concurrent requests", 306 | ), 307 | output: str = typer.Option( 308 | DEFAULT_OUTPUT, 309 | "--output", 310 | "-o", 311 | help="Output file", 312 | ), 313 | no_progress: bool = typer.Option( 314 | False, 315 | "--no-progress", 316 | help="Disable progress tracking", 317 | ), 318 | cache_type: str = typer.Option( 319 | "memory", 320 | "--cache-type", 321 | help="Cache type to use (memory or disk)", 322 | ), 323 | cache_ttl: Optional[int] = typer.Option( 324 | None, 325 | "--cache-ttl", 326 | help="Time to live in seconds for cached items (disk cache only)", 327 | ), 328 | input_file: Optional[str] = typer.Option( 329 | None, 330 | "--input-file", 331 | "-i", 332 | help="File with texts to embed (one per line)", 333 | ), 334 | ) -> None: 335 | """Run embeddings test with caching.""" 336 | # Get API key from environment variable 337 | api_key = os.environ.get("OPENAI_API_KEY") 338 | if not api_key: 339 | raise ValueError("Please set the OPENAI_API_KEY environment variable") 340 | 341 | # Load texts from file if provided, otherwise use default texts 342 | texts = DEFAULT_TEXTS 343 | if input_file: 344 | try: 345 | with open(input_file, 'r') as f: 346 | texts = [line.strip() for line in f if line.strip()] 347 | except Exception as e: 348 | typer.echo(f"Error reading input file: {e}") 349 | raise typer.Exit(code=1) 350 | 351 | run_test( 352 | api_key=api_key, 353 | model=model, 354 | texts=texts, 355 | dimensions=dimensions, 356 | encoding_format=encoding_format, 357 | concurrency=concurrency, 358 | output=output, 359 | 
no_progress=no_progress, 360 | cache_type=cache_type, 361 | cache_ttl=cache_ttl, 362 | ) 363 | 364 | 365 | if __name__ == "__main__": 366 | app() -------------------------------------------------------------------------------- /examples/notebook_test.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from fastllm.core import RequestBatch, RequestManager\n", 10 | "from fastllm.providers.openai import OpenAIProvider\n", 11 | "from fastllm.cache import DiskCache\n", 12 | "\n", 13 | "%load_ext rich" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 2, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "manager = RequestManager(\n", 23 | " provider=OpenAIProvider(\n", 24 | " api_key=\"..\",\n", 25 | " api_base=\"https://openrouter.ai/api/v1\",\n", 26 | " ),\n", 27 | " caching_provider=DiskCache(directory=\"cache\"),\n", 28 | " concurrency=10,\n", 29 | " show_progress=True,\n", 30 | ")" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 3, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "ids = []\n", 40 | "with RequestBatch() as batch:\n", 41 | " for i in range(100):\n", 42 | " ids.append(\n", 43 | " batch.chat.completions.create(\n", 44 | " model=\"meta-llama/llama-3.2-3b-instruct\",\n", 45 | " messages=[\n", 46 | " {\n", 47 | " \"role\": \"user\",\n", 48 | " \"content\": f\"Print only number, number is {i}. Do not include any other text.\",\n", 49 | " }\n", 50 | " ],\n", 51 | " max_completion_tokens=100,\n", 52 | " temperature=0.5,\n", 53 | " )\n", 54 | " )" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 4, 60 | "metadata": {}, 61 | "outputs": [ 62 | { 63 | "data": { 64 | "application/vnd.jupyter.widget-view+json": { 65 | "model_id": "7b2a3a604b714056be4a38ad673ce7e0", 66 | "version_major": 2, 67 | "version_minor": 0 68 | }, 69 | "text/plain": [ 70 | "Output()" 71 | ] 72 | }, 73 | "metadata": {}, 74 | "output_type": "display_data" 75 | }, 76 | { 77 | "data": { 78 | "text/html": [ 79 | "
\n"
 80 |       ],
 81 |       "text/plain": []
 82 |      },
 83 |      "metadata": {},
 84 |      "output_type": "display_data"
 85 |     }
 86 |    ],
 87 |    "source": [
 88 |     "responses = manager.process_batch(batch)"
 89 |    ]
 90 |   },
 91 |   {
 92 |    "cell_type": "code",
 93 |    "execution_count": 8,
 94 |    "metadata": {},
 95 |    "outputs": [
 96 |     {
 97 |      "data": {
 98 |       "text/html": [
 99 |        "\n"
100 |       ],
101 |       "text/plain": []
102 |      },
103 |      "metadata": {},
104 |      "output_type": "display_data"
105 |     },
106 |     {
107 |      "data": {
108 |       "text/plain": [
109 |        "\u001b[32m'b3cf61e94e5ae858'\u001b[0m"
110 |       ]
111 |      },
112 |      "execution_count": 8,
113 |      "metadata": {},
114 |      "output_type": "execute_result"
115 |     }
116 |    ],
117 |    "source": [
118 |     "ids[-1]"
119 |    ]
120 |   },
121 |   {
122 |    "cell_type": "code",
123 |    "execution_count": 7,
124 |    "metadata": {},
125 |    "outputs": [
126 |     {
127 |      "data": {
128 |       "text/html": [
129 |        "\n"
130 |       ],
131 |       "text/plain": []
132 |      },
133 |      "metadata": {},
134 |      "output_type": "display_data"
135 |     },
136 |     {
137 |      "data": {
138 |       "text/plain": [
139 |        "\u001b[32m'b3cf61e94e5ae858'\u001b[0m"
140 |       ]
141 |      },
142 |      "execution_count": 7,
143 |      "metadata": {},
144 |      "output_type": "execute_result"
145 |     }
146 |    ],
147 |    "source": [
148 |     "responses[-1].request_id"
149 |    ]
150 |   },
151 |   {
152 |    "cell_type": "code",
153 |    "execution_count": null,
154 |    "metadata": {},
155 |    "outputs": [],
156 |    "source": []
157 |   }
158 |  ],
159 |  "metadata": {
160 |   "kernelspec": {
161 |    "display_name": ".venv",
162 |    "language": "python",
163 |    "name": "python3"
164 |   },
165 |   "language_info": {
166 |    "codemirror_mode": {
167 |     "name": "ipython",
168 |     "version": 3
169 |    },
170 |    "file_extension": ".py",
171 |    "mimetype": "text/x-python",
172 |    "name": "python",
173 |    "nbconvert_exporter": "python",
174 |    "pygments_lexer": "ipython3",
175 |    "version": "3.12.7"
176 |   }
177 |  },
178 |  "nbformat": 4,
179 |  "nbformat_minor": 2
180 | }
181 | 


--------------------------------------------------------------------------------
/examples/parallel_test.py:
--------------------------------------------------------------------------------
  1 | """Test script for parallel request handling."""
  2 | 
  3 | import json
  4 | import os
  5 | from pathlib import Path
  6 | from typing import Any, Optional, Union, Literal
  7 | import asyncio
  8 | 
  9 | import typer
 10 | from openai.types.chat import ChatCompletion
 11 | from rich.console import Console
 12 | from rich.panel import Panel
 13 | from dotenv import load_dotenv
 14 | 
 15 | from fastllm.core import RequestBatch, RequestManager, ResponseWrapper
 16 | from fastllm.providers.openai import OpenAIProvider
 17 | from fastllm.cache import InMemoryCache, DiskCache
 18 | 
 19 | # Default values for command options
 20 | DEFAULT_REPEATS = 10
 21 | DEFAULT_CONCURRENCY = 50
 22 | DEFAULT_TEMPERATURE = 0.7
 23 | DEFAULT_OUTPUT = "results.json"
 24 | 
 25 | app = typer.Typer()
 26 | 
 27 | load_dotenv()
 28 | 
 29 | 
 30 | def process_response(
 31 |     response: ResponseWrapper[ChatCompletion], index: int
 32 | ) -> dict[str, Any]:
 33 |     """Process a response into a serializable format."""
 34 |     return {
 35 |         "index": index,
 36 |         "type": "success",
 37 |         "request_id": response.request_id,
 38 |         "raw_response": response.response,
 39 |     }
 40 | 
 41 | 
 42 | def run_test(
 43 |     *,  # Force keyword arguments
 44 |     api_key: str,
 45 |     model: str,
 46 |     repeats: int,
 47 |     concurrency: int,
 48 |     output: Union[Path, str],
 49 |     temperature: float,
 50 |     max_tokens: Optional[int],
 51 |     no_progress: bool = False,
 52 |     cache_type: Literal["memory", "disk"] = "memory",
 53 |     cache_ttl: Optional[int] = None,
 54 | ) -> None:
 55 |     """Run the test with given parameters."""
 56 |     console = Console()
 57 | 
 58 |     # Create batch of requests using OpenAI-style API
 59 |     with RequestBatch() as batch:
 60 |         # Add single prompt requests
 61 |         for i in range(repeats):
 62 |             batch.chat.completions.create(
 63 |                 model=model,
 64 |                 messages=[
 65 |                     {
 66 |                         "role": "user",
 67 |                         "content": f"Print only number, number is {i}. Do not include any other text.",
 68 |                     }
 69 |                 ],
 70 |                 temperature=temperature,
 71 |                 max_completion_tokens=max_tokens,
 72 |             )
 73 | 
 74 |     # Show configuration
 75 |     console.print(
 76 |         Panel.fit(
 77 |             "\n".join(
 78 |                 [
 79 |                     f"Model: {model}",
 80 |                     f"Temperature: {temperature}",
 81 |                     f"Max Tokens: {max_tokens or 'default'}",
 82 |                     f"Requests: {len(batch)}",
 83 |                     f"Concurrency: {concurrency}",
 84 |                 ]
 85 |             ),
 86 |             title="[bold blue]Test Configuration",
 87 |         )
 88 |     )
 89 | 
 90 |     # Create cache provider based on type
 91 |     if cache_type == "memory":
 92 |         cache_provider = InMemoryCache()
 93 |     else:
 94 |         cache_provider = DiskCache(
 95 |             directory="./cache",
 96 |             ttl=cache_ttl,
 97 |             size_limit=int(2e9),  # 2GB size limit
 98 |         )
 99 | 
100 |     provider = OpenAIProvider(
101 |         api_key=api_key,
102 |         api_base="https://llm.buffedby.ai/v1",
103 |     )
104 |     manager = RequestManager(
105 |         provider=provider,
106 |         concurrency=concurrency,
107 |         show_progress=not no_progress,
108 |         caching_provider=cache_provider,
109 |     )
110 | 
111 |     try:
112 |         # First run: Process batch
113 |         responses_first = manager.process_batch(batch)
114 |         successful_first = 0
115 |         results_data_first = []
116 |         for i, response in enumerate(responses_first):
117 |             result = process_response(response, i)
118 |             successful_first += 1
119 |             results_data_first.append(result)
120 |             if i + 3 > len(responses_first):  # print only the last two responses
121 |                 console.print(f"Response #{i+1} of {len(responses_first)}")
122 |                 console.print(result)
123 | 
124 |         console.print(
125 |             Panel.fit(
126 |                 "\n".join(
127 |                     [
128 |                         f"First Run - Successful: [green]{successful_first}[/green]",
129 |                         f"Total: {len(responses_first)} (matches {len(batch)} requests)",
130 |                     ]
131 |                 ),
132 |                 title="[bold green]Results - First Run",
133 |             )
134 |         )
135 | 
136 |         # Second run: Process the same batch, expecting cached results
137 |         responses_second = manager.process_batch(batch)
138 |         successful_second = 0
139 |         results_data_second = []
140 |         for i, response in enumerate(responses_second):
141 |             result = process_response(response, i)
142 |             successful_second += 1
143 |             results_data_second.append(result)
144 |             if i + 2 > len(responses_second):  # print only the last response
145 |                 console.print(f"Response #{i+1} of {len(responses_second)}")
146 |                 console.print(result)
147 |         console.print(
148 |             Panel.fit(
149 |                 "\n".join(
150 |                     [
151 |                         f"Second Run - Successful: [green]{successful_second}[/green]",
152 |                         f"Total: {len(responses_second)} (matches {len(batch)} requests)",
153 |                     ]
154 |                 ),
155 |                 title="[bold green]Results - Second Run (Cached)",
156 |             )
157 |         )
158 | 
159 |         # Save results from both runs
160 |         if output != "NO_OUTPUT":
161 |             output = Path(output)
162 |             output.write_text(
163 |                 json.dumps(
164 |                     {
165 |                     "config": {
166 |                         "model": model,
167 |                         "temperature": temperature,
168 |                         "max_tokens": max_tokens,
169 |                         "repeats": repeats,
170 |                         "concurrency": concurrency,
171 |                         "cache_type": cache_type,
172 |                         "cache_ttl": cache_ttl,
173 |                     },
174 |                     "first_run_results": results_data_first,
175 |                     "second_run_results": results_data_second,
176 |                     "first_run_summary": {
177 |                         "successful": successful_first,
178 |                         "total": len(responses_first),
179 |                     },
180 |                     "second_run_summary": {
181 |                         "successful": successful_second,
182 |                         "total": len(responses_second),
183 |                         },
184 |                     },
185 |                     indent=2,
186 |                 )
187 |             )
188 |     finally:
189 |         # Clean up disk cache if used
190 |         if cache_type == "disk":
191 |             # Run close in asyncio event loop
192 |             asyncio.run(cache_provider.close())
193 | 
194 | 
195 | @app.command()
196 | def main(
197 |     model: str = typer.Option(
198 |         "meta-llama/llama-3.2-3b-instruct",
199 |         "--model",
200 |         "-m",
201 |         help="Model to use",
202 |     ),
203 |     repeats: int = typer.Option(
204 |         DEFAULT_REPEATS,
205 |         "--repeats",
206 |         "-n",
207 |         help="Number of repeats",
208 |     ),
209 |     concurrency: int = typer.Option(
210 |         DEFAULT_CONCURRENCY,
211 |         "--concurrency",
212 |         "-c",
213 |         help="Concurrent requests",
214 |     ),
215 |     output: str = typer.Option(
216 |         DEFAULT_OUTPUT,
217 |         "--output",
218 |         "-o",
219 |         help="Output file",
220 |     ),
221 |     temperature: float = typer.Option(
222 |         DEFAULT_TEMPERATURE,
223 |         "--temperature",
224 |         "-t",
225 |         help="Temperature for generation",
226 |     ),
227 |     max_tokens: Optional[int] = typer.Option(
228 |         None,
229 |         "--max-tokens",
230 |         help="Maximum tokens to generate",
231 |     ),
232 |     no_progress: bool = typer.Option(
233 |         False,
234 |         "--no-progress",
235 |         help="Disable progress tracking",
236 |     ),
237 |     cache_type: str = typer.Option(
238 |         "memory",
239 |         "--cache-type",
240 |         help="Cache type to use (memory or disk)",
241 |     ),
242 |     cache_ttl: Optional[int] = typer.Option(
243 |         None,
244 |         "--cache-ttl",
245 |         help="Time to live in seconds for cached items (disk cache only)",
246 |     ),
247 | ) -> None:
248 |     """Run parallel request test."""
249 |     api_key = os.environ["BB_AI_API_KEY"]
250 | 
251 |     run_test(
252 |         api_key=api_key,
253 |         model=model,
254 |         repeats=repeats,
255 |         concurrency=concurrency,
256 |         output=output,
257 |         temperature=temperature,
258 |         max_tokens=max_tokens,
259 |         no_progress=no_progress,
260 |         cache_type=cache_type,
261 |         cache_ttl=cache_ttl,
262 |     )
263 | 
264 | 
265 | if __name__ == "__main__":
266 |     app()
267 | 
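# Example invocation (a sketch with illustrative values; requires the
# BB_AI_API_KEY environment variable to be set):
#
#   python examples/parallel_test.py \
#       --model meta-llama/llama-3.2-3b-instruct \
#       --repeats 20 --concurrency 10 \
#       --cache-type disk --cache-ttl 3600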


--------------------------------------------------------------------------------
/fastllm/__init__.py:
--------------------------------------------------------------------------------
 1 | """FastLLM - High-performance parallel LLM API request tool."""
 2 | 
 3 | __version__ = "0.1.0"
 4 | 
 5 | from fastllm.core import (
 6 |     RequestBatch,
 7 |     RequestManager,
 8 |     ResponseWrapper,
 9 |     TokenStats,
10 | )
11 | from fastllm.cache import (
12 |     CacheProvider,
13 |     InMemoryCache,
14 |     DiskCache,
15 |     compute_request_hash,
16 | )
17 | from fastllm.providers.base import Provider
18 | from fastllm.providers.openai import OpenAIProvider
19 | 
20 | __all__ = [
21 |     # Core components
22 |     "RequestBatch",
23 |     "RequestManager",
24 |     "ResponseWrapper",
25 |     "TokenStats",
26 |     
27 |     # Cache components
28 |     "CacheProvider",
29 |     "InMemoryCache",
30 |     "DiskCache",
31 |     "compute_request_hash",
32 |     
33 |     # Provider components
34 |     "Provider",
35 |     "OpenAIProvider",
36 | ]
37 | 


--------------------------------------------------------------------------------
/fastllm/cache.py:
--------------------------------------------------------------------------------
  1 | import json
  2 | import xxhash
  3 | import asyncio
  4 | import logging
  5 | from typing import Any, Dict, Optional
  6 | from diskcache import Cache
  7 | 
  8 | # Configure logging
  9 | logger = logging.getLogger(__name__)
 10 | 
 11 | 
 12 | class CacheProvider:
 13 |     """Base class for cache providers with async interface."""
 14 |     
 15 |     async def exists(self, key: str) -> bool:
 16 |         """Check if a key exists in the cache."""
 17 |         raise NotImplementedError
 18 | 
 19 |     async def get(self, key: str):
 20 |         """Get a value from the cache."""
 21 |         raise NotImplementedError
 22 | 
 23 |     async def put(self, key: str, value) -> None:
 24 |         """Put a value in the cache."""
 25 |         raise NotImplementedError
 26 | 
 27 |     async def clear(self) -> None:
 28 |         """Clear all items from the cache."""
 29 |         raise NotImplementedError
 30 | 
 31 |     async def close(self) -> None:
 32 |         """Close the cache when done."""
 33 |         pass
 34 | 
 35 | 
 36 | def compute_request_hash(request: dict) -> str:
 37 |     """Compute a hash for a request that can be used as a cache key.
 38 |     
 39 |     Args:
 40 |         request: The request dictionary to hash
 41 |         
 42 |     Returns:
 43 |         str: A hex string hash of the request
 44 |         
 45 |     Note:
 46 |         - None values and empty values are removed from the request before hashing
 47 |         - Internal fields (_request_id, _order_id) are removed
 48 |         - Default values are not added if not present
 49 |     """
 50 |     # Create a copy of the request and remove any fields that are not part of the request content
 51 |     temp_request = request.copy()
 52 |     
 53 |     # Remove internal tracking fields that shouldn't affect caching
 54 |     temp_request.pop("_request_id", None)
 55 |     temp_request.pop("_order_id", None)
 56 |     
 57 |     # Extract known fields and extra params
 58 |     known_fields = {
 59 |         "provider", "model", "messages", "temperature", "max_completion_tokens",
 60 |         "top_p", "presence_penalty", "frequency_penalty", "stop", "stream",
 61 |         # Embedding specific fields
 62 |         "type", "input", "dimensions", "encoding_format", "user"
 63 |     }
 64 |     
 65 |     def clean_value(v):
 66 |         """Remove empty values and normalize None values."""
 67 |         if v is None:
 68 |             return None
 69 |         if isinstance(v, (dict, list)):
 70 |             return clean_dict_or_list(v)
 71 |         if isinstance(v, str) and not v.strip():
 72 |             return None
 73 |         return v
 74 |     
 75 |     def clean_dict_or_list(obj):
 76 |         """Recursively clean dictionaries and lists."""
 77 |         if isinstance(obj, dict):
 78 |             cleaned = {k: clean_value(v) for k, v in obj.items()}
 79 |             return {k: v for k, v in cleaned.items() if v is not None}
 80 |         if isinstance(obj, list):
 81 |             cleaned = [clean_value(v) for v in obj]
 82 |             return [v for v in cleaned if v is not None]
 83 |         return obj
 84 |     
 85 |     # Clean and separate core parameters and extra parameters
 86 |     core_params = {k: clean_value(v) for k, v in temp_request.items() if k in known_fields}
 87 |     extra_params = {k: clean_value(v) for k, v in temp_request.items() if k not in known_fields}
 88 |     
 89 |     # Remove None values and empty values
 90 |     core_params = {k: v for k, v in core_params.items() if v is not None}
 91 |     extra_params = {k: v for k, v in extra_params.items() if v is not None}
 92 |     
 93 |     # Create a combined dictionary with sorted extra params
 94 |     hash_dict = {
 95 |         "core": core_params,
 96 |         "extra": dict(sorted(extra_params.items()))  # Sort extra params for consistent hashing
 97 |     }
 98 |     
 99 |     # Serialize with sorted keys for a consistent representation
100 |     request_str = json.dumps(hash_dict, sort_keys=True, ensure_ascii=False)
101 |     return xxhash.xxh64(request_str.encode("utf-8")).hexdigest()
102 | 
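# Minimal usage sketch of the hashing behaviour documented above (request
# values are illustrative): internal tracking fields and None-valued
# parameters do not change the resulting cache key.
#
#     a = compute_request_hash({
#         "model": "gpt-3.5-turbo",
#         "messages": [{"role": "user", "content": "Hello!"}],
#         "_order_id": 0,
#     })
#     b = compute_request_hash({
#         "model": "gpt-3.5-turbo",
#         "messages": [{"role": "user", "content": "Hello!"}],
#         "_order_id": 7,
#         "stop": None,
#     })
#     assert a == b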
103 | 
104 | class InMemoryCache(CacheProvider):
105 |     """Simple in-memory cache implementation using a dictionary."""
106 |     
107 |     def __init__(self):
108 |         self._cache: Dict[str, Any] = {}
109 | 
110 |     async def exists(self, key: str) -> bool:
111 |         """Check if a key exists in the cache."""
112 |         return key in self._cache
113 | 
114 |     async def get(self, key: str):
115 |         """Get a value from the cache."""
116 |         if not await self.exists(key):
117 |             raise KeyError(f"Cache for key {key} does not exist")
118 |         return self._cache[key]
119 | 
120 |     async def put(self, key: str, value) -> None:
121 |         """Put a value in the cache."""
122 |         self._cache[key] = value
123 | 
124 |     async def clear(self) -> None:
125 |         """Clear all items from the cache."""
126 |         self._cache.clear()
127 | 
128 | 
129 | class DiskCache(CacheProvider):
130 |     """Disk-based cache implementation using diskcache with async support."""
131 |     
132 |     def __init__(self, directory: str, ttl: Optional[int] = None, **cache_options):
133 |         """Initialize disk cache.
134 |         
135 |         Args:
136 |             directory: Directory where cache files will be stored
137 |             ttl: Time to live in seconds for cached items (None means no expiration)
138 |             **cache_options: Additional options to pass to diskcache.Cache
139 |             
140 |         Raises:
141 |             OSError: If the directory is invalid or cannot be created
142 |         """
143 |         try:
144 |             self._cache = Cache(directory, **cache_options)
145 |             self._ttl = ttl
146 |         except Exception as e:
147 |             # Convert any cache initialization error to OSError
148 |             raise OSError(f"Failed to initialize disk cache: {str(e)}")
149 | 
150 |     async def _run_in_executor(self, func, *args):
151 |         """Run a blocking cache operation in the default executor."""
152 |         loop = asyncio.get_running_loop()
153 |         return await loop.run_in_executor(None, func, *args)
154 | 
155 |     async def exists(self, key: str) -> bool:
156 |         """Check if a key exists in the cache."""
157 |         try:
158 |             # Use the internal __contains__ method which is faster than get
159 |             return await self._run_in_executor(self._cache.__contains__, key)
160 |         except Exception as e:
161 |             raise OSError(f"Failed to check cache key: {str(e)}")
162 | 
163 |     async def get(self, key: str):
164 |         """Get a value from the cache."""
165 |         try:
166 |             value = await self._run_in_executor(self._cache.get, key)
167 |             if value is None:
168 |                 # Convert None value to KeyError for consistent behavior with InMemoryCache
169 |                 raise KeyError(f"Cache for key {key} does not exist")
170 |             return value
171 |         except Exception as e:
172 |             if isinstance(e, KeyError):
173 |                 raise
174 |             raise OSError(f"Failed to get cache value: {str(e)}")
175 | 
176 |     async def put(self, key: str, value) -> None:
177 |         """Put a value in the cache with optional TTL."""
178 |         try:
179 |             await self._run_in_executor(self._cache.set, key, value, self._ttl)
180 |         except Exception as e:
181 |             raise OSError(f"Failed to store cache value: {str(e)}")
182 | 
183 |     async def clear(self) -> None:
184 |         """Clear all items from the cache."""
185 |         try:
186 |             await self._run_in_executor(self._cache.clear)
187 |         except Exception as e:
188 |             raise OSError(f"Failed to clear cache: {str(e)}")
189 | 
190 |     async def close(self) -> None:
191 |         """Close the cache when done."""
192 |         try:
193 |             await self._run_in_executor(self._cache.close)
194 |         except Exception as e:
195 |             raise OSError(f"Failed to close cache: {str(e)}") 
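
# Minimal usage sketch for the cache providers (directory, TTL and key are
# illustrative; the coroutines must be awaited from inside an event loop):
#
#     cache = DiskCache(directory="./cache", ttl=3600)
#     await cache.put("some-key", {"cached": True})
#     assert await cache.exists("some-key")
#     value = await cache.get("some-key")
#     await cache.close()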


--------------------------------------------------------------------------------
/fastllm/core.py:
--------------------------------------------------------------------------------
  1 | """Core functionality for parallel LLM API requests."""
  2 | 
  3 | import asyncio
  4 | import time
  5 | import logging
  6 | from contextlib import AbstractContextManager, nullcontext
  7 | from dataclasses import dataclass
  8 | from typing import (
  9 |     Any,
 10 |     Generic,
 11 |     Optional,
 12 |     TypeVar,
 13 |     Union,
 14 | )
 15 | 
 16 | import httpx
 17 | from openai.types import CompletionUsage
 18 | from openai.types.chat import ChatCompletion
 19 | from openai.types.chat.chat_completion import Choice
 20 | from openai.types.chat.chat_completion_message import ChatCompletionMessage
 21 | from pydantic import BaseModel
 22 | from rich.progress import (
 23 |     BarColumn,
 24 |     Progress,
 25 |     SpinnerColumn,
 26 |     TaskProgressColumn,
 27 |     TextColumn,
 28 |     TimeElapsedColumn,
 29 |     TimeRemainingColumn,
 30 |     MofNCompleteColumn
 31 | )
 32 | 
 33 | from fastllm.cache import compute_request_hash
 34 | 
 35 | # Configure logging
 36 | logger = logging.getLogger(__name__)
 37 | 
 38 | 
 39 | # Define a type variable for provider-specific response types
 40 | ResponseT = TypeVar("ResponseT", bound=Union[ChatCompletion, Any])
 41 | 
 42 | DUMMY_RESPONSE = ChatCompletion(
 43 |     id="dummy_id",
 44 |     choices=[
 45 |         Choice(
 46 |             index=0,
 47 |             message=ChatCompletionMessage(content="dummy_content", role="assistant"),
 48 |             finish_reason="stop"
 49 |         )
 50 |     ],
 51 |     created=0,
 52 |     model="dummy_model",
 53 |     object="chat.completion",
 54 |     service_tier="default",
 55 |     system_fingerprint="dummy_system_fingerprint",
 56 |     usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
 57 | )
 58 | 
 59 | 
 60 | class ResponseWrapper(Generic[ResponseT]):
 61 |     """Wrapper for provider responses that includes request ID for sorting."""
 62 | 
 63 |     def __init__(self, response: ResponseT, request_id: str, order_id: int):
 64 |         self.response = response
 65 |         self.request_id = request_id
 66 |         self._order_id = order_id
 67 | 
 68 |     @property
 69 |     def usage(self) -> Optional[CompletionUsage]:
 70 |         """Get usage statistics if available."""
 71 |         if isinstance(self.response, ChatCompletion):
 72 |             return self.response.usage
 73 |         elif isinstance(self.response, dict) and 'usage' in self.response:
 74 |             # Handle dict responses (like embeddings)
 75 |             usage = self.response['usage']
 76 |             # Convert to CompletionUsage if not already
 77 |             if not isinstance(usage, CompletionUsage):
 78 |                 return CompletionUsage(
 79 |                     prompt_tokens=usage.get('prompt_tokens', 0),
 80 |                     completion_tokens=usage.get('completion_tokens', 0),
 81 |                     total_tokens=usage.get('total_tokens', 0)
 82 |                 )
 83 |             return usage
 84 |         return None
 85 |         
 86 |     @property
 87 |     def is_embedding_response(self) -> bool:
 88 |         """Check if this is an embedding response."""
 89 |         if isinstance(self.response, dict):
 90 |             return 'data' in self.response and all('embedding' in item for item in self.response.get('data', []))
 91 |         return False
 92 |         
 93 |     @property
 94 |     def embeddings(self) -> list:
 95 |         """Get embeddings from response if available."""
 96 |         if not self.is_embedding_response:
 97 |             return []
 98 |             
 99 |         if isinstance(self.response, dict) and 'data' in self.response:
100 |             return [item.get('embedding', []) for item in self.response.get('data', [])]
101 |         return []
102 | 
103 | 
104 | @dataclass
105 | class TokenStats:
106 |     """Statistics about token usage."""
107 | 
108 |     prompt_tokens: int = 0
109 |     completion_tokens: int = 0
110 |     total_tokens: int = 0
111 |     requests_completed: int = 0
112 |     cache_hits: int = 0  # Track cache hits
113 |     start_time: float = 0.0
114 |     token_limit: Optional[int] = None  # Rate limit for tokens per minute
115 |     request_limit: Optional[int] = None  # Rate limit for requests per minute
116 |     window_tokens: int = 0  # Tokens in current rate limit window
117 |     window_requests: int = 0  # Requests in current rate limit window
122 | 
123 |     @property
124 |     def elapsed_time(self) -> float:
125 |         return time.time() - self.start_time
126 | 
127 |     @property
128 |     def prompt_tokens_per_second(self) -> float:
129 |         if self.elapsed_time == 0:
130 |             return 0.0
131 |         return self.prompt_tokens / self.elapsed_time
132 | 
133 |     @property
134 |     def completion_tokens_per_second(self) -> float:
135 |         if self.elapsed_time == 0:
136 |             return 0.0
137 |         return self.completion_tokens / self.elapsed_time
138 | 
139 |     @property
140 |     def cache_hit_ratio(self) -> float:
141 |         if self.requests_completed == 0:
142 |             return 0.0
143 |         return self.cache_hits / self.requests_completed
144 | 
145 |     @property
146 |     def token_saturation(self) -> float:
147 |         """Calculate token usage saturation (0.0 to 1.0)."""
148 |         if not self.token_limit or self.elapsed_time == 0:
149 |             return 0.0
150 |         tokens_per_minute = (self.window_tokens / self.elapsed_time) * 60
151 |         return tokens_per_minute / self.token_limit
152 | 
153 |     @property
154 |     def request_saturation(self) -> float:
155 |         """Calculate request rate saturation (0.0 to 1.0)."""
156 |         if not self.request_limit or self.elapsed_time == 0:
157 |             return 0.0
158 |         requests_per_minute = (self.window_requests / self.elapsed_time) * 60
159 |         return requests_per_minute / self.request_limit
160 | 
177 |     def update(self, prompt_tokens: int, completion_tokens: int, is_cache_hit: bool = False) -> None:
178 |         """Update token statistics."""
179 |         self.prompt_tokens += prompt_tokens
180 |         self.completion_tokens += completion_tokens
181 |         self.total_tokens += prompt_tokens + completion_tokens
182 |         self.requests_completed += 1
183 |         if is_cache_hit:
184 |             self.cache_hits += 1
185 |         else:
186 |             # Only update window stats for non-cache hits
187 |             self.window_tokens += prompt_tokens + completion_tokens
188 |             self.window_requests += 1
189 | 
190 | 
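# Worked example for the saturation properties above (numbers are
# illustrative): with token_limit=100_000 tokens per minute, window_tokens=30_000
# and elapsed_time=12 seconds, tokens_per_minute = (30_000 / 12) * 60 = 150_000,
# so token_saturation = 150_000 / 100_000 = 1.5, i.e. the window runs at 150%
# of the configured per-minute budget.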
191 | class ProgressTracker:
192 |     """Tracks progress and token usage for batch requests."""
193 | 
194 |     def __init__(self, total_requests: int, show_progress: bool = True):
195 |         self.stats = TokenStats(start_time=time.time())
196 |         self.total_requests = total_requests
197 |         self.show_progress = show_progress
198 | 
199 |         # Create progress display
200 |         self.progress = Progress(
201 |             SpinnerColumn(),
202 |             TextColumn("[progress.description]{task.description}"),
203 |             BarColumn(),
204 |             TaskProgressColumn(),
205 |             MofNCompleteColumn(),
206 |             TimeRemainingColumn(),
207 |             TimeElapsedColumn(),
208 |             TextColumn("[blue]{task.fields[stats]}"),
209 |             TextColumn("[yellow]{task.fields[cache]}"),
210 |             disable=not show_progress,
211 |         )
212 | 
213 |         # Add main progress task
214 |         self.task_id = self.progress.add_task(
215 |             description="Processing requests",
216 |             total=total_requests,
217 |             stats="Starting...",
218 |             cache="",
219 |         )
220 | 
221 |     def __enter__(self):
222 |         """Start progress display."""
223 |         self.progress.start()
224 |         return self
225 | 
226 |     def __exit__(self, exc_type, exc_val, exc_tb):
227 |         """Stop progress display."""
228 |         self.progress.stop()
229 | 
230 |     def update(self, prompt_tokens: int, completion_tokens: int, is_cache_hit: bool = False):
231 |         """Update progress and token statistics."""
232 |         self.stats.update(prompt_tokens, completion_tokens, is_cache_hit)
233 | 
234 |         # Update progress display with token rates and cache stats
235 |         stats_text = (
236 |             f"[green]⬆ {self.stats.prompt_tokens_per_second:.1f}[/green] "
237 |             f"[red]⬇ {self.stats.completion_tokens_per_second:.1f}[/red] t/s"
238 |         )
239 |         
240 |         cache_text = (
241 |             f"Cache: [green]{self.stats.cache_hit_ratio*100:.1f}%[/green] hits, "
242 |             f"[yellow]{(1-self.stats.cache_hit_ratio)*100:.1f}%[/yellow] new"
243 |         )
244 | 
245 |         self.progress.update(
246 |             self.task_id,
247 |             advance=1,
248 |             stats=stats_text,
249 |             cache=cache_text,
250 |         )
251 | 
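# Minimal usage sketch for ProgressTracker (token counts are illustrative; in
# normal use RequestManager drives the tracker internally):
#
#     with ProgressTracker(total_requests=2) as tracker:
#         tracker.update(prompt_tokens=12, completion_tokens=30)
#         tracker.update(prompt_tokens=12, completion_tokens=0, is_cache_hit=True)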
252 | 
253 | class RequestManager:
254 |     """Manages parallel LLM API requests."""
255 | 
256 |     def __init__(
257 |         self,
258 |         provider: 'Provider[ResponseT]',
259 |         concurrency: int = 100,
260 |         timeout: float = 30.0,
261 |         retry_attempts: int = 3,
262 |         retry_delay: float = 1.0,
263 |         show_progress: bool = True,
264 |         caching_provider: Optional['CacheProvider'] = None,
265 |         return_dummy_on_error: bool = False,
266 |         dummy_response: Optional[ResponseT] = DUMMY_RESPONSE
267 |     ):
268 |         self.provider = provider
269 |         self.concurrency = concurrency
270 |         self.timeout = timeout
271 |         self.retry_attempts = retry_attempts
272 |         self.retry_delay = retry_delay
273 |         self.show_progress = show_progress
274 |         self.cache = caching_provider
275 |         self.return_dummy_on_error = return_dummy_on_error
276 |         self.dummy_response = dummy_response
277 | 
278 |     def _calculate_chunk_size(self) -> int:
279 |         """Calculate optimal chunk size based on concurrency.
280 |         
281 |         The chunk size is calculated as 10 * concurrency (capped at 25,000) to allow
282 |         for some overlap and better resource utilization while still maintaining
283 |         reasonable memory usage. This provides a balance between creating too many
284 |         tasks at once and underutilizing the available concurrency.
285 |         """
286 |         return min(self.concurrency * 10, 25000)  # Cap at 25000 to prevent excessive memory usage
287 | 
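    # Illustrative example of the chunk-size formula above: with concurrency=50
    # the chunk size is min(50 * 10, 25000) = 500, so only very high
    # concurrency settings ever reach the 25000 cap.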
288 |     def process_batch(
289 |         self,
290 |         batch: Union[list[dict[str, Any]], "RequestBatch"],
292 |     ) -> list[ResponseWrapper[ResponseT]]:
292 |         """Process a batch of LLM requests in parallel.
293 | 
294 |         This is the main synchronous API endpoint that users should call.
295 |         Internally it uses asyncio to handle requests concurrently.
296 |         Works in both regular Python environments and Jupyter notebooks.
297 | 
298 |         Args:
299 |             batch: Either a RequestBatch object or a list of request dictionaries
300 | 
301 |         Returns:
302 |             List of ResponseWrapper objects in the same order as the requests
303 |         """
304 |         try:
305 |             loop = asyncio.get_running_loop()
306 |             if loop.is_running():
307 |                 # We're in a Jupyter notebook or similar environment
308 |                 # where the loop is already running
309 |                 import nest_asyncio
310 |                 nest_asyncio.apply()
311 |             return loop.run_until_complete(self._process_batch_async(batch))
312 |         except RuntimeError:
313 |             # No event loop running, create a new one
314 |             return asyncio.run(self._process_batch_async(batch))
315 | 
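Because `process_batch` accepts either a `RequestBatch` or plain request dictionaries, the raw-dictionary path can be sketched as follows (the API key and model name are placeholders; each returned item is a `ResponseWrapper`):

```python
# Minimal sketch of the raw-dictionary input path; credentials and model name
# below are placeholders, not values from the project.
from fastllm import OpenAIProvider, RequestManager

provider = OpenAIProvider(api_key="your-api-key")
manager = RequestManager(provider=provider, concurrency=10, show_progress=False)

requests = [
    {
        "model": "gpt-4o-mini",
        "messages": [{"role": "user", "content": f"What is {i} squared?"}],
    }
    for i in range(5)
]

responses = manager.process_batch(requests)  # ResponseWrapper objects, in submission order
```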
316 |     async def _process_request_async(
317 |         self,
318 |         client: httpx.AsyncClient,
319 |         request: dict[str, Any],
320 |         progress: Optional[ProgressTracker] = None,
321 |     ) -> ResponseWrapper[ResponseT]:
322 |         """Process a single request with caching support."""
323 |         # Get order ID and request ID from request
324 |         order_id = request.get('_order_id', 0)
325 |         request_id = request.get('_request_id')
326 |         
327 |         if request_id is None:
328 |             # Only compute if not already present
329 |             request_id = compute_request_hash(request)
330 |             request['_request_id'] = request_id
331 | 
332 |         # Check cache first if available
333 |         if self.cache is not None:
334 |             try:
335 |                 if await self.cache.exists(request_id):
336 |                     cached_response = await self.cache.get(request_id)
337 |                     wrapped = ResponseWrapper(cached_response, request_id, order_id)
338 |                     if progress and wrapped.usage:
339 |                         # Update progress with cache hit
340 |                         progress.update(
341 |                             wrapped.usage.prompt_tokens,
342 |                             wrapped.usage.completion_tokens or 0,  # Handle embeddings having no completion tokens
343 |                             is_cache_hit=True
344 |                         )
345 |                     return wrapped
346 |             except Exception as e:
347 |                 logger.warning(f"Cache read error: {str(e)}")
348 | 
349 |         # Process request with retries
350 |         for attempt in range(self.retry_attempts):
351 |             try:
352 |                 response = await self.provider.make_request(
353 |                     client,
354 |                     request,
355 |                     self.timeout,
356 |                 )
357 | 
358 |                 # Create wrapper and update progress
359 |                 wrapped = ResponseWrapper(response, request_id, order_id)
360 |                 
361 |                 if progress:
362 |                     # For embeddings, usage only has prompt_tokens
363 |                     if isinstance(response, dict) and 'usage' in response:
364 |                         usage = response['usage']
365 |                         prompt_tokens = usage.get('prompt_tokens', 0)
366 |                         # Embeddings don't have completion tokens
367 |                         completion_tokens = usage.get('completion_tokens', 0)
368 |                         progress.update(
369 |                             prompt_tokens,
370 |                             completion_tokens,
371 |                             is_cache_hit=False
372 |                         )
373 |                     elif wrapped.usage:
374 |                         progress.update(
375 |                             wrapped.usage.prompt_tokens,
376 |                             wrapped.usage.completion_tokens or 0,  # Handle embeddings having no completion tokens
377 |                             is_cache_hit=False
378 |                         )
379 | 
380 |                 # Cache successful response
381 |                 if self.cache is not None:
382 |                     try:
383 |                         await self.cache.put(request_id, response)
384 |                     except Exception as e:
385 |                         logger.warning(f"Cache write error: {str(e)}")
386 |                 
387 |                 return wrapped
388 | 
389 |             except Exception as e:
390 |                 if attempt == self.retry_attempts - 1:
391 |                     if progress:
392 |                         # Update progress even for failed requests
393 |                         progress.update(0, 0, is_cache_hit=False)
394 |                     if self.return_dummy_on_error:
395 |                         # no caching for failed requests
396 |                         return ResponseWrapper(self.dummy_response, request_id, order_id)
397 |                     else:
398 |                         raise
399 |                 await asyncio.sleep(self.retry_delay * (attempt + 1))
400 | 
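The retry loop above backs off linearly (`retry_delay * (attempt + 1)`), and when `return_dummy_on_error` is enabled a request that exhausts its attempts yields the module-level `DUMMY_RESPONSE` instead of raising. A configuration sketch with placeholder credentials:

```python
# Sketch: linear backoff plus a dummy fallback for requests that keep failing.
from fastllm import OpenAIProvider, RequestManager

manager = RequestManager(
    provider=OpenAIProvider(api_key="your-api-key"),  # placeholder credentials
    retry_attempts=5,
    retry_delay=0.5,             # sleeps 0.5s, 1.0s, 1.5s, 2.0s between the 5 attempts
    return_dummy_on_error=True,  # failed requests return DUMMY_RESPONSE and are never cached
)
```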
401 |     async def _process_batch_async(
402 |         self,
403 |         batch: Union[list[dict[str, Any]], "RequestBatch"],
404 |     ) -> list[ResponseWrapper[ResponseT]]:
405 |         """Internal async implementation of batch processing."""
406 |         # Create semaphore for this batch processing run
407 |         semaphore = asyncio.Semaphore(self.concurrency)
408 | 
409 |         # Convert RequestBatch to list of requests if needed
410 |         if isinstance(batch, RequestBatch):
411 |             # Extract original requests from batch format
412 |             requests = []
413 |             for batch_req in batch.requests:
414 |                 # Extract request_id and order_id from custom_id
415 |                 request_id, order_id_str = batch_req["custom_id"].split("#")
416 |                 order_id = int(order_id_str)
417 |                 
418 |                 # Determine type from URL
419 |                 req_type = "chat_completion" if batch_req["url"] == "/v1/chat/completions" else "embedding"
420 |                 
421 |                 # Extract the original request from the batch format
422 |                 request = {
423 |                     **batch_req["body"],
424 |                     "_request_id": request_id,
425 |                     "_order_id": order_id,
426 |                     "type": req_type
427 |                 }
428 |                 requests.append(request)
429 |         else:
430 |             # Handle raw request list - compute request IDs and add order IDs
431 |             requests = []
432 |             for i, request in enumerate(batch):
433 |                 request = request.copy()  # Don't modify original request
434 |                 if "_request_id" not in request:
435 |                     request["_request_id"] = compute_request_hash(request)
436 |                 request["_order_id"] = i
437 |                 requests.append(request)
438 | 
439 |         # Create progress tracker if enabled
440 |         tracker = (
441 |             ProgressTracker(len(requests), show_progress=self.show_progress)
442 |             if self.show_progress
443 |             else None
444 |         )
445 | 
446 |         async def process_request_with_semaphore(
447 |             client: httpx.AsyncClient,
448 |             request: dict[str, Any],
449 |             progress: Optional[ProgressTracker] = None,
450 |         ) -> ResponseWrapper[ResponseT]:
451 |             """Process a single request with semaphore control."""
452 |             async with semaphore:
453 |                 return await self._process_request_async(client, request, progress)
454 | 
455 |         async def process_batch_chunk(
456 |             client: httpx.AsyncClient, chunk: list[dict[str, Any]]
457 |         ) -> list[ResponseWrapper[ResponseT]]:
458 |             """Process a chunk of requests."""
459 |             batch_tasks = [
460 |                 process_request_with_semaphore(client, req, tracker) for req in chunk
461 |             ]
462 |             results = await asyncio.gather(*batch_tasks)
463 |             return [(r._order_id, r) for r in results]
464 | 
465 |         # Process requests in chunks based on calculated chunk size
466 |         chunk_size = self._calculate_chunk_size()
467 |         all_results = []
468 |         context = tracker if tracker else nullcontext()
469 | 
470 |         # Create a single client for the entire batch
471 |         async with httpx.AsyncClient(timeout=self.timeout) as client:
472 |             with context:
473 |                 for batch_start in range(0, len(requests), chunk_size):
474 |                     batch_requests = requests[
475 |                         batch_start : batch_start + chunk_size
476 |                     ]
477 |                     batch_results = await process_batch_chunk(client, batch_requests)
478 |                     all_results.extend(batch_results)
479 | 
480 |         # Sort responses by order ID and return just the responses
481 |         return [r for _, r in sorted(all_results, key=lambda x: x[0])]
482 | 
483 | 
484 | class RequestBatch(AbstractContextManager):
485 |     """A batch of requests to be processed together in OpenAI Batch format."""
486 | 
487 |     def __init__(self):
488 |         self.requests = []
489 |         self._next_order_id = 0
490 | 
491 |     def __enter__(self):
492 |         return self
493 | 
494 |     def __exit__(self, exc_type, exc_val, exc_tb):
495 |         pass
496 | 
497 |     def __len__(self):
498 |         return len(self.requests)
499 | 
500 |     def _add_request(self, request: dict[str, Any]) -> str:
501 |         """Add a request to the batch and return its request ID (cache key).
502 |         
503 |         Args:
504 |             request: The request to add to the batch
505 |             
506 |         Returns:
507 |             str: The request ID (cache key) for this request
508 |         """
509 |         
510 |         # Compute the request ID (cache key) and assign the next order ID
511 |         request_id = compute_request_hash(request)
512 |         order_id = self._next_order_id
513 |         self._next_order_id += 1
514 |         
515 |         # Determine the endpoint URL based on request type
516 |         url = "/v1/chat/completions"
517 |         if request.get("type") == "embedding":
518 |             url = "/v1/embeddings"
519 |         
520 |         # Create a custom_id from request_id and order_id
521 |         custom_id = f"{request_id}#{order_id}"
522 |         
523 |         # Create batch format request directly
524 |         batch_request = {
525 |             "custom_id": custom_id,
526 |             "url": url,
527 |             "body": {k: v for k, v in request.items() if k not in ["type"]}
528 |         }
529 |         
530 |         # Add to batch
531 |         self.requests.append(batch_request)
532 |         return request_id
533 | 
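The `custom_id` built here is split back apart in `_process_batch_async`; the convention is simply `"<request_id>#<order_id>"` (the values below are made up):

```python
# Sketch of the custom_id round trip (illustrative values).
request_id, order_id = "a1b2c3", 7
custom_id = f"{request_id}#{order_id}"

parsed_id, parsed_order = custom_id.split("#")
assert parsed_id == request_id and int(parsed_order) == order_id
```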
534 |     @classmethod
535 |     def merge(cls, batches: list["RequestBatch"]) -> "RequestBatch":
536 |         """Merge multiple request batches into a single batch."""
537 |         merged = cls()
538 |         for batch in batches:
539 |             merged.requests.extend(batch.requests)
540 |         return merged
541 | 
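`merge` simply concatenates the underlying request lists, so batches built independently (for example, per data shard) can be combined before processing; a small offline sketch with an illustrative model name:

```python
# Sketch: merging two independently built batches; no network calls are made here.
from fastllm import RequestBatch

batch_a, batch_b = RequestBatch(), RequestBatch()
batch_a.chat.completions.create(
    model="gpt-4o-mini", messages=[{"role": "user", "content": "A"}]
)
batch_b.chat.completions.create(
    model="gpt-4o-mini", messages=[{"role": "user", "content": "B"}]
)

combined = RequestBatch.merge([batch_a, batch_b])
assert len(combined) == 2
```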
542 |     @property
543 |     def chat(self):
544 |         """Access chat completion methods."""
545 |         return self.Chat(self)
546 |         
547 |     @property
548 |     def embeddings(self):
549 |         """Access embeddings methods."""
550 |         return self.Embeddings(self)
551 | 
552 |     class Chat:
553 |         """Chat API that mimics OpenAI's interface."""
554 | 
555 |         def __init__(self, batch):
556 |             self.batch = batch
557 |             self.completions = self.Completions(batch)
558 | 
559 |         class Completions:
560 |             """Chat completions API that mimics OpenAI's interface."""
561 | 
562 |             def __init__(self, batch):
563 |                 self.batch = batch
564 | 
565 |             def create(
566 |                 self,
567 |                 *,
568 |                 model: str,
569 |                 messages: list[dict[str, str]],
570 |                 temperature: Optional[float] = 0.7,
571 |                 top_p: Optional[float] = 1.0,
572 |                 n: Optional[int] = 1,
573 |                 stop: Optional[Union[str, list[str]]] = None,
574 |                 max_completion_tokens: Optional[int] = None,
575 |                 presence_penalty: Optional[float] = 0.0,
576 |                 frequency_penalty: Optional[float] = 0.0,
577 |                 logit_bias: Optional[dict[str, float]] = None,
578 |                 user: Optional[str] = None,
579 |                 response_format: Optional[dict[str, str]] = None,
580 |                 seed: Optional[int] = None,
581 |                 tools: Optional[list[dict[str, Any]]] = None,
582 |                 tool_choice: Optional[Union[str, dict[str, str]]] = None,
583 |                 **kwargs: Any
584 |             ) -> str:
585 |                 """Add a chat completion request to the batch.
586 |                 
587 |                 Args:
588 |                     model: The model to use for completion
589 |                     messages: The messages to generate a completion for
590 |                     temperature: Sampling temperature (0-2)
591 |                     top_p: Nucleus sampling parameter (0-1)
592 |                     n: Number of completions to generate
593 |                     stop: Stop sequences to use
594 |                     max_completion_tokens: Maximum tokens to generate
595 |                     presence_penalty: Presence penalty (-2 to 2)
596 |                     frequency_penalty: Frequency penalty (-2 to 2)
597 |                     logit_bias: Token biases to use
598 |                     user: User identifier
599 |                     response_format: Format for the response
600 |                     seed: Random seed for reproducibility
601 |                     tools: List of tools available to the model
602 |                     tool_choice: Tool choice configuration
603 |                     **kwargs: Additional provider-specific parameters
604 | 
605 |                 Returns:
606 |                     str: The request ID (cache key) for this request
607 |                 """
608 |                 # Create the request body
609 |                 body = {
610 |                     "model": model,
611 |                     "messages": messages,
612 |                     "temperature": temperature,
613 |                     "top_p": top_p,
614 |                     "n": n,
615 |                     "stop": stop,
616 |                     "max_completion_tokens": max_completion_tokens,
617 |                     "presence_penalty": presence_penalty,
618 |                     "frequency_penalty": frequency_penalty,
619 |                     "logit_bias": logit_bias,
620 |                     "user": user,
621 |                     "response_format": response_format,
622 |                     "seed": seed,
623 |                     "tools": tools,
624 |                     "tool_choice": tool_choice,
625 |                     **kwargs,
626 |                 }
627 |                 
628 |                 # Remove None values to match OpenAI's behavior
629 |                 body = {k: v for k, v in body.items() if v is not None}
630 |                 
631 |                 # Compute request_id at creation time
632 |                 
633 |                 request_id = compute_request_hash({"type": "chat_completion", **body})
634 |                 order_id = self.batch._next_order_id
635 |                 self.batch._next_order_id += 1
636 |                 
637 |                 # Create custom_id from request_id and order_id
638 |                 custom_id = f"{request_id}#{order_id}"
639 |                 
640 |                 # Create the batch request directly in OpenAI Batch format
641 |                 batch_request = {
642 |                     "custom_id": custom_id,
643 |                     "url": "/v1/chat/completions",
644 |                     "body": body
645 |                 }
646 |                 
647 |                 # Add to batch
648 |                 self.batch.requests.append(batch_request)
649 |                 
650 |                 return request_id
651 |     
652 |     class Embeddings:
653 |         """Embeddings API that mimics OpenAI's interface."""
654 | 
655 |         def __init__(self, batch):
656 |             self.batch = batch
657 | 
658 |         def create(
659 |             self,
660 |             *,
661 |             model: str,
662 |             input: Union[str, list[str]],
663 |             dimensions: Optional[int] = None,
664 |             encoding_format: Optional[str] = None,
665 |             user: Optional[str] = None,
666 |             **kwargs: Any
667 |         ) -> str:
668 |             """Add an embedding request to the batch.
669 |             
670 |             Args:
671 |                 model: The model to use for embeddings (e.g., text-embedding-3-small)
672 |                 input: The text to embed (either a string or a list of strings)
673 |                 dimensions: The number of dimensions to return. Only supported with 
674 |                             text-embedding-3 models. Defaults to the model's max dimensions.
675 |                 encoding_format: The format to return the embeddings in (float or base64)
676 |                 user: A unique identifier for the end-user
677 |                 **kwargs: Additional provider-specific parameters
678 |                 
679 |             Returns:
680 |                 str: The request ID (cache key) for this request
681 |             """
682 |             # Create the request body
683 |             body = {
684 |                 "model": model,
685 |                 "input": input,
686 |                 "dimensions": dimensions,
687 |                 "encoding_format": encoding_format,
688 |                 "user": user,
689 |                 **kwargs,
690 |             }
691 |             
692 |             # Remove None values to match OpenAI's behavior
693 |             body = {k: v for k, v in body.items() if v is not None}
694 |             
695 |             request_id = compute_request_hash({"type": "embedding", **body})
696 |             order_id = self.batch._next_order_id
697 |             self.batch._next_order_id += 1
698 |             
699 |             # Create custom_id from request_id and order_id
700 |             custom_id = f"{request_id}#{order_id}"
701 |             
702 |             # Create the batch request directly in OpenAI Batch format
703 |             batch_request = {
704 |                 "custom_id": custom_id,
705 |                 "url": "/v1/embeddings",
706 |                 "body": body
707 |             }
708 |             
709 |             # Add to batch
710 |             self.batch.requests.append(batch_request)
711 |             
712 |             return request_id
713 | 
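An embeddings batch is assembled the same way as a chat batch; a short sketch (the model name is illustrative, and nothing is sent until the batch is handed to a `RequestManager`):

```python
# Sketch: building an embeddings batch; create() returns the request ID (cache key).
from fastllm import RequestBatch

with RequestBatch() as batch:
    ids = [
        batch.embeddings.create(model="text-embedding-3-small", input=text)
        for text in ["first document", "second document"]
    ]

assert len(batch) == 2 and len(ids) == 2
```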


--------------------------------------------------------------------------------
/fastllm/providers/__init__.py:
--------------------------------------------------------------------------------
1 | """Provider implementations."""
2 | 
3 | from .base import Provider
4 | from .openai import OpenAIProvider
5 | 
6 | __all__ = ["Provider", "OpenAIProvider"]
7 | 


--------------------------------------------------------------------------------
/fastllm/providers/base.py:
--------------------------------------------------------------------------------
 1 | """Base classes for LLM providers."""
 2 | 
 3 | from abc import ABC, abstractmethod
 4 | from typing import Any, Generic, Optional
 5 | 
 6 | import httpx
 7 | 
 8 | from fastllm.core import ResponseT
 9 | 
10 | class Provider(Generic[ResponseT], ABC):
11 |     """Base class for LLM providers."""
12 | 
13 |     # Internal tracking headers
14 |     _REFERER = "https://github.com/Rexhaif/fastllm"
15 |     _APP_NAME = "FastLLM"
16 | 
17 |     def __init__(
18 |         self,
19 |         api_key: str,
20 |         api_base: str,
21 |         headers: Optional[dict[str, str]] = None,
22 |         **kwargs: Any,
23 |     ):
24 |         self.api_key = api_key
25 |         self.api_base = api_base.rstrip("/")  # Remove trailing slash if present
26 |         self._default_headers = {
27 |             "HTTP-Referer": self._REFERER,
28 |             "X-Title": self._APP_NAME,
29 |         }
30 |         self.headers = {
31 |             **self._default_headers,
32 |             **(headers or {}),
33 |         }
34 | 
35 |     def get_request_url(self, endpoint: str) -> str:
36 |         """Get full URL for API endpoint."""
37 |         return f"{self.api_base}/{endpoint.lstrip('/')}"
38 | 
39 |     @abstractmethod
40 |     def get_request_headers(self) -> dict[str, str]:
41 |         """Get headers for API requests."""
42 |         pass
43 | 
44 |     @abstractmethod
45 |     async def make_request(
46 |         self,
47 |         client: httpx.AsyncClient,
48 |         request: dict[str, Any],
49 |         timeout: float,
50 |     ) -> ResponseT:
51 |         """Make a request to the provider API."""
52 |         pass 
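For reference, a hypothetical minimal subclass of this ABC (not part of the library; the endpoint path and raw-JSON handling are illustrative only):

```python
# Hypothetical provider sketch built on the ABC above; not a real fastllm provider.
from typing import Any

import httpx

from fastllm.providers.base import Provider


class EchoJSONProvider(Provider[dict]):
    """Posts the request and returns the provider's raw JSON as a dict."""

    def get_request_headers(self) -> dict[str, str]:
        return {"Authorization": f"Bearer {self.api_key}", **self.headers}

    async def make_request(
        self,
        client: httpx.AsyncClient,
        request: dict[str, Any],
        timeout: float,
    ) -> dict:
        response = await client.post(
            self.get_request_url("chat/completions"),  # assumed endpoint path
            headers=self.get_request_headers(),
            json={k: v for k, v in request.items() if not k.startswith("_")},
            timeout=timeout,
        )
        response.raise_for_status()
        return response.json()


# provider = EchoJSONProvider(api_key="key", api_base="https://example.invalid/v1")
```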


--------------------------------------------------------------------------------
/fastllm/providers/openai.py:
--------------------------------------------------------------------------------
  1 | """OpenAI API provider implementation."""
  2 | 
  3 | from typing import Any, Optional, Union
  4 | 
  5 | import httpx
  6 | from openai.types.chat import ChatCompletion
  7 | from openai.types import CreateEmbeddingResponse
  8 | 
  9 | from fastllm.providers.base import Provider
 10 | 
 11 | DEFAULT_API_BASE = "https://api.openai.com/v1"
 12 | 
 13 | 
 14 | class OpenAIProvider(Provider[ChatCompletion]):
 15 |     """OpenAI provider."""
 16 | 
 17 |     def __init__(
 18 |         self,
 19 |         api_key: str,
 20 |         api_base: str = DEFAULT_API_BASE,
 21 |         organization: Optional[str] = None,
 22 |         headers: Optional[dict[str, str]] = None,
 23 |         **kwargs: Any,
 24 |     ):
 25 |         super().__init__(api_key, api_base, headers, **kwargs)
 26 |         self.organization = organization
 27 | 
 28 |     def get_request_headers(self) -> dict[str, str]:
 29 |         """Get headers for OpenAI API requests."""
 30 |         headers = {
 31 |             "Authorization": f"Bearer {self.api_key}",
 32 |             "Content-Type": "application/json",
 33 |             **self.headers,
 34 |         }
 35 |         if self.organization:
 36 |             headers["OpenAI-Organization"] = self.organization
 37 |         return headers
 38 | 
 39 |     async def make_request(
 40 |         self,
 41 |         client: httpx.AsyncClient,
 42 |         request: dict[str, Any],
 43 |         timeout: float,
 44 |     ) -> Union[ChatCompletion, CreateEmbeddingResponse]:
 45 |         """Make a request to the OpenAI API."""
 46 |         # Determine request type from the request or infer from content
 47 |         if isinstance(request, dict):
 48 |             # Extract request type from the request data
 49 |             request_type = request.get("type")
 50 |             if request_type is None:
 51 |                 # Infer type based on content
 52 |                 if "messages" in request:
 53 |                     request_type = "chat_completion"
 54 |                 elif "input" in request:
 55 |                     request_type = "embedding"
 56 |                 else:
 57 |                     request_type = "chat_completion"  # Default
 58 |         else:
 59 |             # Handle unexpected input
 60 |             raise ValueError(f"Unexpected request type: {type(request)}")
 61 |         
 62 |         # Determine API path based on request type
 63 |         if request_type == "embedding":
 64 |             api_path = "embeddings"
 65 |         else:
 66 |             api_path = "chat/completions"
 67 | 
 68 |         url = self.get_request_url(api_path)
 69 |         payload = self._prepare_payload(request, request_type)
 70 | 
 71 |         response = await client.post(
 72 |             url,
 73 |             headers=self.get_request_headers(),
 74 |             json=payload,
 75 |             timeout=timeout,
 76 |         )
 77 |         response.raise_for_status()
 78 |         data = response.json()
 79 | 
 80 |         if request_type == "embedding":
 81 |             return CreateEmbeddingResponse(**data)
 82 |         else:
 83 |             return ChatCompletion(**data)
 84 | 
 85 |         
 86 |     
 87 |     def _prepare_payload(self, request: dict[str, Any], request_type: str) -> dict[str, Any]:
 88 |         """Prepare the API payload from the request data."""
 89 |         # Extract known fields and extra params
 90 |         known_fields = {
 91 |             "provider", "model", "messages", "temperature", "max_completion_tokens",
 92 |             "top_p", "presence_penalty", "frequency_penalty", "stop", "stream",
 93 |             "type", "input", "dimensions", "encoding_format", "user"
 94 |         }
 95 |         
 96 |         # Start with a copy of the request
 97 |         payload = {k: v for k, v in request.items() if k not in ["provider", "type", "_order_id", "_request_id"]}
 98 |         
 99 |         # Handle embedding requests
100 |         if request_type == "embedding":
101 |             # Ensure required fields are present
102 |             if "model" not in payload:
103 |                 raise ValueError("Model is required for embedding requests")
104 |             if "input" not in payload:
105 |                 raise ValueError("Input is required for embedding requests")
106 |             
107 |             # Keep only relevant fields for embeddings
108 |             embedding_fields = {"model", "input", "dimensions", "encoding_format", "user"}
109 |             return {k: v for k, v in payload.items() if k in embedding_fields}
110 |         
111 |         # Handle chat completion requests
112 |         if "model" not in payload:
113 |             raise ValueError("Model is required for chat completion requests")
114 |         if "messages" not in payload:
115 |             raise ValueError("Messages are required for chat completion requests")
116 |         
117 |         # Map max_completion_tokens to max_tokens if present
118 |         if "max_completion_tokens" in payload:
119 |             payload["max_tokens"] = payload.pop("max_completion_tokens")
120 |         
121 |         # Remove any None values
122 |         return {k: v for k, v in payload.items() if v is not None}
123 | 
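A quick illustration of what `_prepare_payload` does to a chat request: internal bookkeeping fields are dropped, `max_completion_tokens` is renamed to `max_tokens`, and `None` values are removed (all values below are made up):

```python
# Sketch: payload preparation for a chat completion request (illustrative values).
from fastllm.providers.openai import OpenAIProvider

provider = OpenAIProvider(api_key="test-key")
payload = provider._prepare_payload(
    {
        "model": "gpt-4o-mini",
        "messages": [{"role": "user", "content": "Hi"}],
        "max_completion_tokens": 64,
        "stop": None,              # None values are removed
        "_request_id": "a1b2c3",   # internal fields are stripped
        "_order_id": 0,
    },
    request_type="chat_completion",
)
assert payload["max_tokens"] == 64
assert "_request_id" not in payload and "stop" not in payload
```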


--------------------------------------------------------------------------------
/justfile:
--------------------------------------------------------------------------------
 1 | set shell := ["bash", "-c"]
 2 | set dotenv-load := true
 3 | 
 4 | # List all available commands
 5 | default:
 6 |     @just --list
 7 | 
 8 | # Install all dependencies using uv
 9 | install:
10 |     uv venv
11 |     uv pip install -e .
12 |     uv pip install -e ".[dev]"
13 | 
14 | # Run tests
15 | test:
16 |     uv run pytest tests/ -v
17 | 
18 | # Format code using ruff
19 | format:
20 |     uv run ruff format .
21 |     uv run ruff check . --fix
22 | 
23 | # Run linting checks
24 | lint:
25 |     uv run ruff check .
26 | 
27 | # Clean up cache files
28 | clean:
29 |     rm -rf .pytest_cache
30 |     rm -rf .coverage
31 |     rm -rf .ruff_cache
32 |     rm -rf dist
33 |     rm -rf build
34 |     rm -rf *.egg-info
35 |     find . -type d -name __pycache__ -exec rm -rf {} +
36 | 
37 | live_test_completions:
38 |     uv run python examples/parallel_test.py --model openai/gpt-4o-mini --repeats 100 --concurrency 75 --cache-type memory --output NO_OUTPUT
39 | 
40 | # Build the package for distribution
41 | build:
42 |     uv build --no-sources
43 | 
44 | # Build and publish to PyPI
45 | publish: clean build
46 |     UV_PUBLISH_TOKEN=$UV_PUBLISH_TOKEN_PROD uv publish
47 | 
48 | # Build and publish to TestPyPI
49 | publish-test: clean build
50 |     UV_PUBLISH_TOKEN=$UV_PUBLISH_TOKEN_TEST uv publish --index testpypi
51 | 


--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
 1 | [project]
 2 | name = "fastllm-kit"
 3 | version = "0.1.9"
 4 | description = "High-performance parallel LLM API request tool with caching and multiple provider support"
 5 | authors = [
 6 |     {name = "Daniil Larionov", email = "rexhaif.io@gmail.com"}
 7 | ]
 8 | readme = "README.md"
 9 | license = {text = "MIT"}
10 | requires-python = ">=3.9"
11 | keywords = ["llm", "ai", "openai", "parallel", "caching", "api"]
12 | classifiers = [
13 |     "Development Status :: 4 - Beta",
14 |     "Intended Audience :: Developers",
15 |     "License :: OSI Approved :: MIT License",
16 |     "Programming Language :: Python :: 3.9",
17 |     "Programming Language :: Python :: 3.10",
18 |     "Programming Language :: Python :: 3.11",
19 |     "Topic :: Software Development :: Libraries :: Python Modules",
20 |     "Topic :: Scientific/Engineering :: Artificial Intelligence",
21 | ]
22 | dependencies = [
23 |     "httpx>=0.27.2",
24 |     "pydantic>=2.10.6",
25 |     "rich>=13.9.4",
26 |     "diskcache>=5.6.3",
27 |     "anyio>=4.8.0",
28 |     "typing_extensions>=4.12.2",
29 |     "tqdm>=4.67.1",
30 |     "nest-asyncio>=1.6.0",
31 |     "openai>=1.61.0",
32 |     "xxhash>=3.0.0",
33 | ]
34 | 
35 | [project.optional-dependencies]
36 | dev = [
37 |     "ruff>=0.3.7",
38 |     "pytest>=8.3.4",
39 |     "pytest-asyncio>=0.23.8",
40 |     "pytest-cov>=4.1.0",
41 |     "black>=24.10.0",
42 |     "coverage>=7.6.10",
43 |     "ipywidgets>=8.1.5",
44 |     "dotenv>=0.9.9",
45 |     "typer>=0.15.2",
46 |     "numpy>=2.0.2",
47 | ]
48 | 
49 | [tool.ruff]
50 | line-length = 88
51 | target-version = "py39"
52 | select = ["E", "F", "B", "I", "N", "UP", "PL", "RUF"]
53 | ignore = []
54 | 
55 | [tool.ruff.format]
56 | quote-style = "double"
57 | indent-style = "space"
58 | skip-magic-trailing-comma = false
59 | line-ending = "auto"
60 | 
61 | [tool.pytest.ini_options]
62 | asyncio_mode = "strict"
63 | asyncio_default_fixture_loop_scope = "function"
64 | 
65 | [build-system]
66 | requires = ["hatchling"]
67 | build-backend = "hatchling.build"
68 | 
69 | [dependency-groups]
70 | dev = [
71 |     "pytest>=8.3.4",
72 |     "pytest-asyncio>=0.25.3",
73 | ]
74 | 
75 | [tool.hatch.build.targets.wheel]
76 | packages = ["fastllm"]
77 | 
78 | [tool.hatch.metadata]
79 | allow-direct-references = true
80 | 
81 | [[tool.uv.index]]
82 | name = "testpypi"
83 | url = "https://test.pypi.org/simple/"
84 | publish-url = "https://test.pypi.org/legacy/"
85 | explicit = true
86 | 


--------------------------------------------------------------------------------
/tests/test_asyncio.py:
--------------------------------------------------------------------------------
  1 | """Tests for async functionality."""
  2 | 
  3 | import asyncio
  4 | import time
  5 | 
  6 | import pytest
  8 | 
  9 | from fastllm.core import RequestManager
 10 | 
 11 | # Constants for testing
 12 | MAX_DURATION = 0.5  # Maximum expected duration for concurrent execution
 13 | EXPECTED_RESPONSES = 5  # Expected number of responses
 14 | EXPECTED_SUCCESSES = 2  # Expected number of successful responses
 15 | EXPECTED_FAILURES = 1  # Expected number of failed responses
 16 | 
 17 | 
 18 | class DummyRequestManager(RequestManager):
 19 |     async def _make_provider_request(self, client, request):
 20 |         await asyncio.sleep(0.1)  # Simulate network delay
 21 |         # Extract message content from request
 22 |         message_content = request.get("messages", [{}])[0].get("content", "No content")
 23 |         response_dict = {
 24 |             "request_id": id(request),  # Add unique request ID
 25 |             "content": f"Response to: {message_content}",
 26 |             "finish_reason": "dummy_end",
 27 |             "provider": "dummy",
 28 |             "raw_response": {"dummy_key": "dummy_value"},
 29 |             "usage": {
 30 |                 "prompt_tokens": 10,
 31 |                 "completion_tokens": 5,
 32 |                 "total_tokens": 15
 33 |             }
 34 |         }
 35 |         return response_dict
 36 | 
 37 | 
 38 | @pytest.mark.asyncio
 39 | async def test_dummy_manager_single_request():
 40 |     manager = DummyRequestManager(provider="dummy")
 41 |     request = {
 42 |         "provider": "dummy", 
 43 |         "messages": [{"role": "user", "content": "Hello async!"}], 
 44 |         "model": "dummy-model"
 45 |     }
 46 |     response = await manager._make_provider_request(None, request)
 47 |     assert response["content"] == "Response to: Hello async!"
 48 |     assert response["finish_reason"] == "dummy_end"
 49 |     assert response["usage"]["total_tokens"] == 15
 50 | 
 51 | 
 52 | @pytest.mark.asyncio
 53 | async def test_dummy_manager_concurrent_requests():
 54 |     manager = DummyRequestManager(provider="dummy")
 55 |     requests = [
 56 |         {
 57 |             "provider": "dummy", 
 58 |             "messages": [{"role": "user", "content": f"Message {i}"}], 
 59 |             "model": "dummy-model"
 60 |         }
 61 |         for i in range(5)
 62 |     ]
 63 | 
 64 |     start = time.perf_counter()
 65 |     responses = await asyncio.gather(
 66 |         *[manager._make_provider_request(None, req) for req in requests]
 67 |     )
 68 |     end = time.perf_counter()
 69 | 
 70 |     # All requests should complete successfully
 71 |     assert len(responses) == EXPECTED_RESPONSES
 72 |     for i, response in enumerate(responses):
 73 |         assert response["content"] == f"Response to: Message {i}"
 74 |         assert response["finish_reason"] == "dummy_end"
 75 |         assert response["usage"]["total_tokens"] == 15
 76 | 
 77 |     # Requests should be processed concurrently
 78 |     # Total time should be less than sequential time (5 * 0.1s)
 79 |     assert end - start < MAX_DURATION  # Allow some overhead
 80 | 
 81 | 
 82 | class FailingDummyManager(RequestManager):
 83 |     async def _make_provider_request(self, client, request):
 84 |         message_content = request.get("messages", [{}])[0].get("content", "")
 85 |         if "fail" in message_content.lower():
 86 |             raise Exception("Provider failure")
 87 |         await asyncio.sleep(0.1)
 88 |         response_dict = {
 89 |             "request_id": id(request),  # Add unique request ID
 90 |             "content": f"Response to: {message_content}",
 91 |             "finish_reason": "dummy_end",
 92 |             "provider": "dummy",
 93 |             "raw_response": {"dummy_key": "dummy_value"},
 94 |             "usage": {
 95 |                 "prompt_tokens": 10,
 96 |                 "completion_tokens": 5,
 97 |                 "total_tokens": 15
 98 |             }
 99 |         }
100 |         return response_dict
101 | 
102 | 
103 | @pytest.mark.asyncio
104 | async def test_dummy_manager_request_failure():
105 |     manager = FailingDummyManager(provider="dummy")
106 |     request = {
107 |         "provider": "dummy", 
108 |         "messages": [{"role": "user", "content": "fail this request"}], 
109 |         "model": "dummy-model"
110 |     }
111 |     with pytest.raises(Exception) as exc_info:
112 |         await manager._make_provider_request(None, request)
113 |     assert "Provider failure" in str(exc_info.value)
114 | 
115 | 
116 | @pytest.mark.asyncio
117 | async def test_gather_with_mixed_success_and_failure():
118 |     manager = FailingDummyManager(provider="dummy")
119 |     requests = [
120 |         {
121 |             "provider": "dummy", 
122 |             "messages": [{"role": "user", "content": "Message 1"}], 
123 |             "model": "dummy-model"
124 |         },
125 |         {
126 |             "provider": "dummy", 
127 |             "messages": [{"role": "user", "content": "fail this one"}], 
128 |             "model": "dummy-model"
129 |         },
130 |         {
131 |             "provider": "dummy", 
132 |             "messages": [{"role": "user", "content": "Message 3"}], 
133 |             "model": "dummy-model"
134 |         },
135 |     ]
136 | 
137 |     responses = await asyncio.gather(
138 |         *[manager._make_provider_request(None, req) for req in requests],
139 |         return_exceptions=True,
140 |     )
141 |     successes = [resp for resp in responses if not isinstance(resp, Exception)]
142 |     failures = [resp for resp in responses if isinstance(resp, Exception)]
143 | 
144 |     assert len(successes) == EXPECTED_SUCCESSES
145 |     assert len(failures) == EXPECTED_FAILURES
146 |     assert "Provider failure" in str(failures[0])
147 | 
148 | 
149 | @pytest.mark.asyncio
150 | async def test_task_scheduling_order():
151 |     manager = DummyRequestManager(provider="dummy")
152 |     requests = [
153 |         {
154 |             "provider": "dummy", 
155 |             "messages": [{"role": "user", "content": f"Message {i}"}], 
156 |             "model": "dummy-model"
157 |         }
158 |         for i in range(3)
159 |     ]
160 | 
161 |     # Create coroutines but don't await them yet
162 |     tasks = [manager._make_provider_request(None, req) for req in requests]
163 |     
164 |     # Await them in reverse order
165 |     responses = []
166 |     for task in reversed(tasks):
167 |         responses.append(await task)
168 | 
169 |     # Responses are collected in await order (the reverse of request order),
170 |     # so responses[i] corresponds to request 2-i
170 |     for i, response in enumerate(responses):
171 |         assert f"Message {2-i}" in response["content"]
172 | 
173 | 
174 | @pytest.mark.asyncio
175 | async def test_task_cancellation():
176 |     manager = DummyRequestManager(provider="dummy")
177 |     request = {
178 |         "provider": "dummy", 
179 |         "messages": [{"role": "user", "content": "Cancel me"}], 
180 |         "model": "dummy-model"
181 |     }
182 | 
183 |     # Start the task
184 |     task = asyncio.create_task(manager._make_provider_request(None, request))
185 |     
186 |     # Cancel it immediately
187 |     task.cancel()
188 |     
189 |     with pytest.raises(asyncio.CancelledError):
190 |         await task
191 | 


--------------------------------------------------------------------------------
/tests/test_cache.py:
--------------------------------------------------------------------------------
  1 | import asyncio
  2 | import os
  3 | 
  4 | import pytest
  5 | 
  6 | from fastllm.cache import DiskCache, InMemoryCache, compute_request_hash
  7 | 
  8 | 
  9 | @pytest.mark.asyncio
 10 | async def test_inmemory_cache_put_get_exists_and_clear():
 11 |     cache = InMemoryCache()
 12 |     key = "test_key"
 13 |     value = {"foo": "bar"}
 14 | 
 15 |     # Initially, key should not exist
 16 |     assert not await cache.exists(key)
 17 | 
 18 |     # Put value
 19 |     await cache.put(key, value)
 20 | 
 21 |     # Check existence
 22 |     assert await cache.exists(key)
 23 | 
 24 |     # Get the value back
 25 |     retrieved_val = await cache.get(key)
 26 |     assert retrieved_val == value
 27 | 
 28 |     # Clear cache
 29 |     await cache.clear()
 30 |     assert not await cache.exists(key)
 31 |     with pytest.raises(KeyError):
 32 |         await cache.get(key)
 33 | 
 34 | 
 35 | @pytest.mark.asyncio
 36 | async def test_disk_cache_put_get_exists_and_clear(tmp_path):
 37 |     # Create a temporary directory for the disk cache
 38 |     cache_dir = os.path.join(tmp_path, "disk_cache")
 39 |     os.makedirs(cache_dir, exist_ok=True)
 40 |     cache = DiskCache(directory=cache_dir)
 41 |     key = "disk_test"
 42 |     value = {"alpha": 123}
 43 | 
 44 |     # Initially, key should not exist
 45 |     assert not await cache.exists(key)
 46 | 
 47 |     # Put the value
 48 |     await cache.put(key, value)
 49 | 
 50 |     # Check existence
 51 |     assert await cache.exists(key)
 52 | 
 53 |     # Retrieve the value
 54 |     result = await cache.get(key)
 55 |     assert result == value
 56 | 
 57 |     # Clear the cache and verify
 58 |     await cache.clear()
 59 |     assert not await cache.exists(key)
 60 |     with pytest.raises(KeyError):
 61 |         await cache.get(key)
 62 | 
 63 |     # Close the disk cache
 64 |     await cache.close()
 65 | 
 66 | 
 67 | @pytest.mark.asyncio
 68 | async def test_disk_cache_ttl(tmp_path):
 69 |     # Create temporary directory for disk cache with TTL
 70 |     cache_dir = os.path.join(tmp_path, "disk_cache_ttl")
 71 |     os.makedirs(cache_dir, exist_ok=True)
 72 |     # Set TTL to 1 second
 73 |     cache = DiskCache(directory=cache_dir, ttl=1)
 74 |     key = "disk_ttl"
 75 |     value = "temporary"
 76 |     await cache.put(key, value)
 77 | 
 78 |     # Immediately, key should exist
 79 |     assert await cache.exists(key)
 80 | 
 81 |     # Wait for TTL expiry
 82 |     await asyncio.sleep(1.1)
 83 | 
 84 |     # After TTL expiration, key should be expired (get raises KeyError)
 85 |     with pytest.raises(KeyError):
 86 |         await cache.get(key)
 87 | 
 88 |     await cache.clear()
 89 |     await cache.close()
 90 | 
 91 | 
 92 | def test_compute_request_hash_consistency():
 93 |     req1 = {
 94 |         "provider": "test",
 95 |         "model": "dummy",
 96 |         "messages": [{"role": "user", "content": "Hello"}],
 97 |         "temperature": 0.5,
 98 |         "_request_id": "ignore_me",
 99 |         "extra_param": "value",
100 |     }
101 |     req2 = {
102 |         "provider": "test",
103 |         "model": "dummy",
104 |         "messages": [{"role": "user", "content": "Hello"}],
105 |         "temperature": 0.5,
106 |         "extra_param": "value",
107 |         "_order_id": "should_be_removed",
108 |     }
109 | 
110 |     hash1 = compute_request_hash(req1)
111 |     hash2 = compute_request_hash(req2)
112 | 
113 |     # The hashes should be identical because _request_id and _order_id are removed
114 |     assert hash1 == hash2
115 | 
116 | 
117 | @pytest.mark.asyncio
118 | async def test_disk_cache_invalid_directory(tmp_path):
119 |     """Test DiskCache behavior with invalid directory."""
120 |     # Try to create cache in a non-existent directory that can't be created
121 |     # (using a file as a directory should raise OSError)
122 |     file_path = os.path.join(tmp_path, "file")
123 |     with open(file_path, "w") as f:
124 |         f.write("test")
125 |     
126 |     with pytest.raises(OSError):
127 |         DiskCache(directory=file_path)
128 | 
129 | 
130 | @pytest.mark.asyncio
131 | async def test_disk_cache_concurrent_access(tmp_path):
132 |     """Test concurrent access to DiskCache."""
133 |     cache_dir = os.path.join(tmp_path, "concurrent_cache")
134 |     os.makedirs(cache_dir, exist_ok=True)
135 |     cache = DiskCache(directory=cache_dir)
136 | 
137 |     # Create multiple concurrent operations
138 |     async def concurrent_operation(key, value):
139 |         await cache.put(key, value)
140 |         assert await cache.exists(key)
141 |         result = await cache.get(key)
142 |         assert result == value
143 | 
144 |     # Run multiple operations concurrently
145 |     tasks = [
146 |         concurrent_operation(f"key_{i}", f"value_{i}")
147 |         for i in range(5)
148 |     ]
149 |     await asyncio.gather(*tasks)
150 | 
151 |     # Verify all values are still accessible
152 |     for i in range(5):
153 |         assert await cache.get(f"key_{i}") == f"value_{i}"
154 | 
155 |     await cache.clear()
156 |     await cache.close()
157 | 
158 | 
159 | @pytest.mark.asyncio
160 | async def test_inmemory_cache_large_values():
161 |     """Test InMemoryCache with large values."""
162 |     cache = InMemoryCache()
163 |     key = "large_value"
164 |     # Create a large value (1MB string)
165 |     large_value = "x" * (1024 * 1024)
166 |     
167 |     await cache.put(key, large_value)
168 |     result = await cache.get(key)
169 |     assert result == large_value
170 | 
171 |     await cache.clear()
172 | 
173 | 
174 | def test_compute_request_hash_edge_cases():
175 |     """Test compute_request_hash with edge cases."""
176 |     # Empty request
177 |     empty_req = {}
178 |     empty_hash = compute_request_hash(empty_req)
179 |     assert empty_hash  # Should return a hash even for empty dict
180 | 
181 |     # Request with only internal fields
182 |     internal_req = {
183 |         "_request_id": "123",
184 |         "_order_id": "456",
185 |     }
186 |     internal_hash = compute_request_hash(internal_req)
187 |     assert internal_hash == compute_request_hash({})  # Should be same as empty
188 | 
189 |     # Request with nested structures
190 |     nested_req = {
191 |         "provider": "test",
192 |         "messages": [
193 |             {"role": "system", "content": "You are a bot"},
194 |             {"role": "user", "content": "Hi"},
195 |             {"role": "assistant", "content": "Hello!"},
196 |         ],
197 |         "options": {
198 |             "temperature": 0.7,
199 |             "max_tokens": 100,
200 |             "stop": [".", "!"],
201 |         },
202 |     }
203 |     nested_hash = compute_request_hash(nested_req)
204 |     assert nested_hash  # Should handle nested structures
205 | 
206 |     # Verify order independence
207 |     reordered_req = {
208 |         "messages": [
209 |             {"role": "system", "content": "You are a bot"},
210 |             {"role": "user", "content": "Hi"},
211 |             {"role": "assistant", "content": "Hello!"},
212 |         ],
213 |         "provider": "test",
214 |         "options": {
215 |             "max_tokens": 100,
216 |             "temperature": 0.7,
217 |             "stop": [".", "!"],
218 |         },
219 |     }
220 |     reordered_hash = compute_request_hash(reordered_req)
221 |     assert nested_hash == reordered_hash  # Hash should be independent of field order
222 | 


--------------------------------------------------------------------------------
/tests/test_core.py:
--------------------------------------------------------------------------------
  1 | """Tests for core functionality."""
  2 | 
  3 | import time
  4 | from unittest import mock
  5 | 
  9 | import pytest
 10 | 
 11 | from fastllm.core import (
 12 |     RequestBatch,
 13 |     RequestManager,
 14 |     ResponseWrapper,
 15 |     TokenStats,
 16 |     ProgressTracker,
 17 | )
 18 | from fastllm.cache import InMemoryCache, compute_request_hash
 19 | 
 20 | # Constants for testing
 21 | DEFAULT_CHUNK_SIZE = 20
 22 | DEFAULT_MAX_CHUNK_SIZE = 1000
 23 | DEFAULT_PROMPT_TOKENS = 10
 24 | DEFAULT_COMPLETION_TOKENS = 5
 25 | DEFAULT_TOTAL_TOKENS = DEFAULT_PROMPT_TOKENS + DEFAULT_COMPLETION_TOKENS
 26 | DEFAULT_RETRY_ATTEMPTS = 2
 27 | DEFAULT_CONCURRENCY = 5
 28 | DEFAULT_TIMEOUT = 1.0
 29 | DEFAULT_RETRY_DELAY = 0.0
 30 | 
 31 | # TokenStats Tests
 32 | def test_token_stats():
 33 |     """Test basic TokenStats functionality."""
 34 |     ts = TokenStats(start_time=time.time() - 2)  # started 2 seconds ago
 35 |     assert ts.cache_hit_ratio == 0.0
 36 |     ts.update(DEFAULT_PROMPT_TOKENS, DEFAULT_COMPLETION_TOKENS, is_cache_hit=True)
 37 |     assert ts.prompt_tokens == DEFAULT_PROMPT_TOKENS
 38 |     assert ts.completion_tokens == DEFAULT_COMPLETION_TOKENS
 39 |     assert ts.total_tokens == DEFAULT_TOTAL_TOKENS
 40 |     assert ts.requests_completed == 1
 41 |     assert ts.cache_hits == 1
 42 |     assert ts.prompt_tokens_per_second > 0
 43 |     assert ts.completion_tokens_per_second > 0
 44 | 
 45 | 
 46 | def test_token_stats_rate_limits():
 47 |     """Test rate limit tracking in TokenStats."""
 48 |     current_time = time.time()
 49 |     stats = TokenStats(
 50 |         start_time=current_time,
 51 |         token_limit=1000,  # 1000 tokens per minute
 52 |         request_limit=100,  # 100 requests per minute
 53 |     )
 54 | 
 55 |     # Test initial state
 56 |     assert stats.token_saturation == 0.0
 57 |     assert stats.request_saturation == 0.0
 58 | 
 59 |     # Update with some usage (non-cache hits)
 60 |     stats.update(50, 50, is_cache_hit=False)  # 100 tokens total
 61 |     stats.update(25, 25, is_cache_hit=False)  # 50 tokens total
 62 | 
 63 |     # Cache hits should not affect rate limit tracking
 64 |     stats.update(100, 100, is_cache_hit=True)
 65 |     assert stats.window_tokens == 150  # Still 150 from non-cache hits
 66 |     assert stats.window_requests == 2  # Still 2 from non-cache hits
 67 | 
 68 | # RequestManager Tests
 69 | class DummyProvider:
 70 |     async def make_request(self, client, request, timeout):
 71 |         return {
 72 |             "content": "Test response",
 73 |             "finish_reason": "stop",
 74 |             "usage": {
 75 |                 "prompt_tokens": 10,
 76 |                 "completion_tokens": 5,
 77 |                 "total_tokens": 15
 78 |             }
 79 |         }
 80 | 
 81 | 
 82 | @pytest.mark.asyncio
 83 | async def test_request_manager():
 84 |     """Test basic RequestManager functionality."""
 85 |     provider = DummyProvider()
 86 |     manager = RequestManager(provider=provider)
 87 |     
 88 |     # Create a request dictionary
 89 |     request = {
 90 |         "provider": "dummy",
 91 |         "messages": [{"role": "user", "content": "Test message"}],
 92 |         "model": "dummy-model"
 93 |     }
 94 |     
 95 |     # Test make_provider_request
 96 |     response = await manager._make_provider_request(None, request)
 97 |     assert response["content"] == "Test response"
 98 |     assert response["finish_reason"] == "stop"
 99 |     assert response["usage"]["total_tokens"] == 15
100 | 
101 | 
102 | class FailingProvider:
103 |     async def make_request(self, client, request, timeout):
104 |         raise Exception("Provider error")
105 | 
106 | 
107 | @pytest.mark.asyncio
108 | async def test_request_manager_failure():
109 |     """Test RequestManager with failing provider."""
110 |     provider = FailingProvider()
111 |     manager = RequestManager(provider=provider)
112 |     request = {
113 |         "provider": "dummy",
114 |         "messages": [{"role": "user", "content": "Test message"}],
115 |         "model": "dummy-model"
116 |     }
117 |     
118 |     with pytest.raises(Exception) as exc_info:
119 |         await manager._make_provider_request(None, request)
120 |     assert "Provider error" in str(exc_info.value)
121 | 
122 | 
123 | def test_progress_tracker_update():
124 |     """Test ProgressTracker updates."""
125 |     tracker = ProgressTracker(total_requests=10, show_progress=False)
126 |     tracker.update(DEFAULT_PROMPT_TOKENS, DEFAULT_COMPLETION_TOKENS)
127 |     assert tracker.stats.prompt_tokens == DEFAULT_PROMPT_TOKENS
128 |     assert tracker.stats.completion_tokens == DEFAULT_COMPLETION_TOKENS
129 |     assert tracker.stats.total_tokens == DEFAULT_TOTAL_TOKENS
130 | 
131 | 
132 | def test_progress_tracker_context_manager():
133 |     """Test ProgressTracker as context manager."""
134 |     with ProgressTracker(total_requests=10, show_progress=False) as tracker:
135 |         tracker.update(DEFAULT_PROMPT_TOKENS, DEFAULT_COMPLETION_TOKENS)
136 |     assert tracker.stats.prompt_tokens == DEFAULT_PROMPT_TOKENS
137 | 
138 | 
139 | def test_progress_tracker_with_limits():
140 |     """Test ProgressTracker with rate limits."""
141 |     tracker = ProgressTracker(total_requests=10, show_progress=False)
142 |     tracker.stats.token_limit = 1000
143 |     tracker.stats.request_limit = 100
144 |     tracker.update(DEFAULT_PROMPT_TOKENS, DEFAULT_COMPLETION_TOKENS)
145 |     assert tracker.stats.token_saturation > 0
146 |     assert tracker.stats.request_saturation > 0
147 | 
148 | 
149 | class CachingProvider:
150 |     """Test provider that always returns the same response."""
151 |     
152 |     def __init__(self):
153 |         self.call_count = 0
154 |     
155 |     async def make_request(self, client, request, timeout):
156 |         self.call_count += 1
157 |         return {
158 |             "content": "Cached response",
159 |             "finish_reason": "stop",
160 |             "usage": {
161 |                 "prompt_tokens": 10,
162 |                 "completion_tokens": 5,
163 |                 "total_tokens": 15
164 |             }
165 |         }
166 | 
167 | 
168 | @pytest.mark.asyncio
169 | async def test_request_manager_caching():
170 |     """Test RequestManager with caching."""
171 |     provider = CachingProvider()
172 |     cache = InMemoryCache()
173 |     manager = RequestManager(
174 |         provider=provider,
175 |         caching_provider=cache,
176 |         show_progress=False,
177 |     )
178 |     
179 |     # Create identical requests
180 |     request1 = {
181 |         "provider": "dummy",
182 |         "messages": [{"role": "user", "content": "Test message"}],
183 |         "model": "dummy-model"
184 |     }
185 |     
186 |     # Add request_id
187 |     request_id = compute_request_hash(request1)
188 |     request1["_request_id"] = request_id
189 |     
190 |     # First request should hit the provider
191 |     response1 = await manager._process_request_async(None, request1, None)
192 |     assert provider.call_count == 1
193 |     
194 |     # Second request with the same content should hit the cache
195 |     request2 = request1.copy()
196 |     response2 = await manager._process_request_async(None, request2, None)
197 |     assert provider.call_count == 1  # Should not increase
198 |     
199 |     # Different request should hit the provider
200 |     request3 = {
201 |         "provider": "dummy",
202 |         "messages": [{"role": "user", "content": "Different message"}],
203 |         "model": "dummy-model"
204 |     }
205 |     request3["_request_id"] = compute_request_hash(request3)
206 |     
207 |     response3 = await manager._process_request_async(None, request3, None)
208 |     assert provider.call_count == 2
209 | 
210 | 
211 | @pytest.mark.asyncio
212 | async def test_request_manager_cache_errors():
213 |     """Test RequestManager handles cache errors gracefully."""
214 |     
215 |     class ErrorCache(InMemoryCache):
216 |         async def exists(self, key: str) -> bool:
217 |             raise Exception("Cache error")
218 |         
219 |         async def get(self, key: str):
220 |             raise Exception("Cache error")
221 |         
222 |         async def put(self, key: str, value) -> None:
223 |             raise Exception("Cache error")
224 |     
225 |     provider = DummyProvider()
226 |     manager = RequestManager(
227 |         provider=provider,
228 |         caching_provider=ErrorCache(),
229 |         show_progress=False,
230 |     )
231 |     
232 |     # Create a request
233 |     request = {
234 |         "provider": "dummy",
235 |         "messages": [{"role": "user", "content": "Test message"}],
236 |         "model": "dummy-model",
237 |         "_request_id": "test-id",
238 |     }
239 |     
240 |     # Request should succeed despite cache errors
241 |     response = await manager._process_request_async(None, request, None)
242 |     # Access response content from the wrapped response
243 |     assert response.response["content"] == "Test response"
244 | 
245 | 
246 | @pytest.mark.asyncio
247 | async def test_request_manager_failed_response_not_cached():
248 |     """Test that failed responses are not cached."""
249 |     
250 |     class FailingProvider:
251 |         def __init__(self, fail_first=True):
252 |             self.call_count = 0
253 |             self.fail_first = fail_first
254 |             self.has_failed = False
255 |         
256 |         async def make_request(self, client, request, timeout):
257 |             self.call_count += 1
258 |             
259 |             if self.fail_first and not self.has_failed:
260 |                 self.has_failed = True
261 |                 raise Exception("Provider failure")
262 |             
263 |             return {
264 |                 "content": "Success response",
265 |                 "finish_reason": "stop",
266 |                 "usage": {
267 |                     "prompt_tokens": 10,
268 |                     "completion_tokens": 5,
269 |                     "total_tokens": 15
270 |                 }
271 |             }
272 |     
273 |     provider = FailingProvider()
274 |     cache = InMemoryCache()
275 |     manager = RequestManager(
276 |         provider=provider,
277 |         caching_provider=cache,
278 |         retry_attempts=1,  # Only retry once
279 |         show_progress=False,
280 |     )
281 |     
282 |     # Create a request
283 |     request = {
284 |         "provider": "dummy",
285 |         "messages": [{"role": "user", "content": "Test message"}],
286 |         "model": "dummy-model",
287 |         "_request_id": "test-id",
288 |     }
289 |     
290 |     # First attempt should fail
291 |     with pytest.raises(Exception):
292 |         await manager._process_request_async(None, request, None)
293 |     
294 |     # Verify the call was made
295 |     assert provider.call_count == 1
296 |     
297 |     # Second attempt with same request should hit the provider again
298 |     # since the failed response should not be cached
299 |     try:
300 |         await manager._process_request_async(None, request, None)
301 |     except Exception:
302 |         pass
303 |     
304 |     assert provider.call_count == 2
305 | 
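The cache behaviour exercised above rests on two pieces: `compute_request_hash` derives a stable key from the request content, and the caching providers expose an async `exists`/`get`/`put` interface (the same methods `ErrorCache` overrides). Below is a minimal sketch of those pieces in isolation; the request dicts are illustrative, and the async cache API is assumed to match the interface overridden above.

```python
import asyncio

from fastllm import InMemoryCache
from fastllm.cache import compute_request_hash


async def main() -> None:
    request_a = {
        "provider": "dummy",
        "model": "dummy-model",
        "messages": [{"role": "user", "content": "Test message"}],
    }
    request_b = dict(request_a)  # same content, different dict object
    request_c = {**request_a, "messages": [{"role": "user", "content": "Other"}]}

    # Identical content -> identical cache key; different content -> different key.
    assert compute_request_hash(request_a) == compute_request_hash(request_b)
    assert compute_request_hash(request_a) != compute_request_hash(request_c)

    # Assumed async cache interface, mirroring the methods ErrorCache overrides.
    cache = InMemoryCache()
    key = compute_request_hash(request_a)
    await cache.put(key, {"content": "cached response"})
    assert await cache.exists(key)
    print(await cache.get(key))


asyncio.run(main())
```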


--------------------------------------------------------------------------------
/tests/test_openai.py:
--------------------------------------------------------------------------------
  1 | import httpx
  2 | import pytest
  3 | 
  4 | from fastllm.providers.openai import OpenAIProvider
  5 | 
  6 | # Constants for testing
  7 | DEFAULT_TEMPERATURE = 0.9
  8 | DEFAULT_TEMPERATURE_ALT = 0.6
  9 | HTTP_OK_MIN = 200
 10 | HTTP_OK_MAX = 300
 11 | 
 12 | 
 13 | def test_prepare_payload_from_simple_request():
 14 |     # Test preparing a simple request payload
 15 |     provider = OpenAIProvider(api_key="testkey")
 16 |     request = {
 17 |         "model": "gpt-3.5-turbo",
 18 |         "messages": [{"role": "user", "content": "Test message"}],
 19 |         "temperature": DEFAULT_TEMPERATURE
 20 |     }
 21 |     payload = provider._prepare_payload(request, "chat_completion")
 22 |     assert payload["model"] == "gpt-3.5-turbo"
 23 |     assert "messages" in payload
 24 |     messages = payload["messages"]
 25 |     assert isinstance(messages, list)
 26 |     assert messages[0]["role"] == "user"
 27 |     assert messages[0]["content"] == "Test message"
 28 |     assert payload["temperature"] == DEFAULT_TEMPERATURE
 29 | 
 30 | 
 31 | def test_prepare_payload_from_system_message():
 32 |     # Test preparing a payload with a system message
 33 |     provider = OpenAIProvider(api_key="testkey")
 34 |     request = {
 35 |         "model": "gpt-3.5-turbo",
 36 |         "messages": [{"role": "system", "content": "System message"}],
 37 |     }
 38 |     payload = provider._prepare_payload(request, "chat_completion")
 39 |     assert payload["model"] == "gpt-3.5-turbo"
 40 |     messages = payload["messages"]
 41 |     assert len(messages) == 1
 42 |     assert messages[0]["role"] == "system"
 43 |     assert messages[0]["content"] == "System message"
 44 | 
 45 | 
 46 | def test_prepare_payload_omits_none_values():
 47 |     # Test that None values are omitted from the payload
 48 |     provider = OpenAIProvider(api_key="testkey")
 49 |     request = {
 50 |         "model": "gpt-3.5-turbo",
 51 |         "messages": [{"role": "user", "content": "Ignore None"}],
 52 |         "top_p": None, 
 53 |         "stop": None
 54 |     }
 55 |     payload = provider._prepare_payload(request, "chat_completion")
 56 |     # top_p and stop should not be in payload if they are None
 57 |     assert "top_p" not in payload
 58 |     assert "stop" not in payload
 59 | 
 60 | 
 61 | def test_prepare_payload_omits_internal_tracking_ids():
 62 |     # Test that internal tracking ids are never sent to providers
 63 |     provider = OpenAIProvider(api_key="testkey")
 64 |     request = {
 65 |         "model": "gpt-3.5-turbo",
 66 |         "messages": [{"role": "user", "content": "Hello world"}],
 67 |         "_order_id": 123,
 68 |         "_request_id": "abcd1234"
 69 |     }
 70 |     payload = provider._prepare_payload(request, "chat_completion")
 71 |     # _order_id and _request_id should not be in payload
 72 |     assert "_order_id" not in payload
 73 |     assert "_request_id" not in payload
 74 | 
 75 | 
 76 | def test_prepare_payload_with_extra_params():
 77 |     # Test that extra parameters are included in the payload
 78 |     provider = OpenAIProvider(api_key="testkey")
 79 |     request = {
 80 |         "model": "gpt-3.5-turbo",
 81 |         "messages": [{"role": "user", "content": "Extra params"}],
 82 |         "custom_param": "custom_value"
 83 |     }
 84 |     payload = provider._prepare_payload(request, "chat_completion")
 85 |     assert "custom_param" in payload
 86 |     assert payload["custom_param"] == "custom_value"
 87 | 
 88 | 
 89 | def test_openai_provider_get_request_url():
 90 |     # Test that the OpenAIProvider constructs the correct request URL
 91 |     provider = OpenAIProvider(api_key="testkey", api_base="https://api.openai.com")
 92 |     url = provider.get_request_url("completions")
 93 |     assert url == "https://api.openai.com/completions"
 94 | 
 95 | 
 96 | def test_openai_provider_get_request_headers():
 97 |     provider = OpenAIProvider(
 98 |         api_key="testkey",
 99 |         api_base="https://api.openai.com",
100 |         organization="org-123",
101 |         headers={"X-Custom": "custom-value"},
102 |     )
103 |     headers = provider.get_request_headers()
104 |     assert headers["Authorization"] == "Bearer testkey"
105 |     assert headers["Content-Type"] == "application/json"
106 |     assert headers["OpenAI-Organization"] == "org-123"
107 |     assert headers["X-Custom"] == "custom-value"
108 | 
109 | 
110 | def test_prepare_payload_for_embeddings():
111 |     # Test preparing a payload for embeddings
112 |     provider = OpenAIProvider(api_key="testkey")
113 |     request = {
114 |         "model": "text-embedding-ada-002",
115 |         "input": "Test input",
116 |         "dimensions": 1536,
117 |         "user": "test-user"
118 |     }
119 |     payload = provider._prepare_payload(request, "embedding")
120 |     assert payload["model"] == "text-embedding-ada-002"
121 |     assert payload["input"] == "Test input"
122 |     assert payload["dimensions"] == 1536
123 |     assert payload["user"] == "test-user"
124 | 
125 | 
126 | def test_prepare_payload_for_embeddings_with_array_input():
127 |     # Test preparing a payload for embeddings with array input
128 |     provider = OpenAIProvider(api_key="testkey")
129 |     request = {
130 |         "model": "text-embedding-ada-002",
131 |         "input": ["Test input 1", "Test input 2"],
132 |         "encoding_format": "float"
133 |     }
134 |     payload = provider._prepare_payload(request, "embedding")
135 |     assert payload["model"] == "text-embedding-ada-002"
136 |     assert isinstance(payload["input"], list)
137 |     assert len(payload["input"]) == 2
138 |     assert payload["input"][0] == "Test input 1"
139 |     assert payload["input"][1] == "Test input 2"
140 |     assert payload["encoding_format"] == "float"
141 | 
142 | 
143 | def test_map_max_completion_tokens():
144 |     # Test that max_completion_tokens is properly mapped to max_tokens
145 |     provider = OpenAIProvider(api_key="testkey")
146 |     request = {
147 |         "model": "gpt-3.5-turbo",
148 |         "messages": [{"role": "user", "content": "Test message"}],
149 |         "max_completion_tokens": 100
150 |     }
151 |     payload = provider._prepare_payload(request, "chat_completion")
152 |     assert "max_tokens" in payload
153 |     assert payload["max_tokens"] == 100
154 |     assert "max_completion_tokens" not in payload
155 | 
156 | 
157 | class FakeResponse:
158 |     def __init__(self, json_data, status_code=HTTP_OK_MIN):
159 |         self._json_data = json_data
160 |         self.status_code = status_code
161 | 
162 |     def raise_for_status(self):
163 |         if not (HTTP_OK_MIN <= self.status_code < HTTP_OK_MAX):
164 |             raise httpx.HTTPStatusError("Error", request=None, response=self)
165 | 
166 |     def json(self):
167 |         return self._json_data
168 | 
169 | 
170 | class FakeAsyncClient:
171 |     async def post(self, url, headers, json, timeout):
172 |         # Return appropriate fake response based on request type
173 |         if "embeddings" in url:
174 |             # Return a fake embeddings response
175 |             fake_json = {
176 |                 "object": "list",
177 |                 "data": [
178 |                     {
179 |                         "object": "embedding",
180 |                         "embedding": [0.1, 0.2, 0.3],
181 |                         "index": 0
182 |                     }
183 |                 ],
184 |                 "model": "text-embedding-ada-002",
185 |                 "usage": {"prompt_tokens": 8, "total_tokens": 8}
186 |             }
187 |         else:
188 |             # Return a fake chat completion response
189 |             fake_json = {
190 |                 "id": "chatcmpl-xyz",
191 |                 "object": "chat.completion",
192 |                 "model": "gpt-3.5-turbo",
193 |                 "created": 1690000000,
194 |                 "choices": [
195 |                     {
196 |                         "index": 0,
197 |                         "message": {"role": "assistant", "content": "Test reply"},
198 |                         "finish_reason": "stop",
199 |                     }
200 |                 ],
201 |                 "usage": {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3},
202 |             }
203 |         return FakeResponse(fake_json)
204 | 
205 | 
206 | @pytest.mark.asyncio
207 | async def test_openai_provider_make_request():
208 |     provider = OpenAIProvider(api_key="testkey", api_base="https://api.openai.com")
209 |     fake_client = FakeAsyncClient()
210 |     request_data = {
211 |         "model": "gpt-3.5-turbo",
212 |         "messages": [{"role": "user", "content": "Tell me a joke."}],
213 |         "temperature": 0.5,
214 |     }
215 |     # Pass a dict directly to make_request
216 |     result = await provider.make_request(fake_client, request_data, timeout=1.0)
217 |     # Check that the result has the expected fake response data
218 |     assert result.id == "chatcmpl-xyz"
219 |     assert result.object == "chat.completion"
220 |     assert isinstance(result.choices, list)
221 |     # Access the content via attributes
222 |     assert result.choices[0].message.content == "Test reply"
223 | 
224 | 
225 | @pytest.mark.asyncio
226 | async def test_openai_provider_make_embedding_request():
227 |     provider = OpenAIProvider(api_key="testkey", api_base="https://api.openai.com")
228 |     fake_client = FakeAsyncClient()
229 |     request_data = {
230 |         "model": "text-embedding-ada-002",
231 |         "input": "Sample text for embedding",
232 |         "type": "embedding"  # Add type indicator for path determination
233 |     }
234 |     # Pass an embedding request to make_request
235 |     result = await provider.make_request(
236 |         fake_client, 
237 |         request_data, 
238 |         timeout=1.0
239 |     )
240 |     # For embeddings, we expect a dict since it doesn't parse as ChatCompletion
241 |     assert isinstance(result, dict)
242 |     assert result["object"] == "list"
243 |     assert "data" in result
244 |     assert isinstance(result["data"], list)
245 |     assert len(result["data"]) > 0
246 |     assert "embedding" in result["data"][0]
247 |     assert isinstance(result["data"][0]["embedding"], list)
248 | 
249 | 
250 | class FakeAsyncClientError:
251 |     async def post(self, url, headers, json, timeout):
252 |         # Return a fake response with an error status code
253 |         return FakeResponse({"error": "Bad Request"}, status_code=400)
254 | 
255 | 
256 | @pytest.mark.asyncio
257 | async def test_openai_provider_make_request_error():
258 |     provider = OpenAIProvider(api_key="testkey", api_base="https://api.openai.com")
259 |     fake_client_error = FakeAsyncClientError()
260 |     request_data = {
261 |         "model": "gpt-3.5-turbo",
262 |         "messages": [{"role": "user", "content": "This will error."}],
263 |         "temperature": 0.5,
264 |     }
265 |     with pytest.raises(httpx.HTTPStatusError):
266 |         await provider.make_request(fake_client_error, request_data, timeout=1.0)
267 | 
268 | 
269 | def test_embedding_request_validation():
270 |     # Test validation for embedding requests
271 |     provider = OpenAIProvider(api_key="testkey")
272 |     
273 |     # Test missing model
274 |     with pytest.raises(ValueError):
275 |         provider._prepare_payload({"input": "test"}, "embedding")
276 |     
277 |     # Test missing input
278 |     with pytest.raises(ValueError):
279 |         provider._prepare_payload({"model": "text-embedding-ada-002"}, "embedding")
280 |     
281 |     # Test valid minimal embedding request
282 |     payload = provider._prepare_payload({
283 |         "model": "text-embedding-ada-002",
284 |         "input": "test"
285 |     }, "embedding")
286 |     assert payload == {"model": "text-embedding-ada-002", "input": "test"}
287 | 
288 | 
289 | def test_embedding_request_type_detection():
290 |     # _prepare_payload needs no HTTP client; a minimal embedding request
291 |     # should pass through it unchanged.
292 |     provider = OpenAIProvider(api_key="testkey")
293 | 
294 |     # A request with an "input" field is an embedding-style request
295 |     request = {
296 |         "model": "text-embedding-ada-002",
297 |         "input": "test",
298 |     }
299 | 
300 |     # The prepared embedding payload should keep model and input intact
301 |     payload = provider._prepare_payload(request, "embedding")
302 |     assert payload == {"model": "text-embedding-ada-002", "input": "test"}
303 | 
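`FakeAsyncClient` and `FakeResponse` above stand in for `httpx`; against a real endpoint the same `make_request` call shape applies. The following is a sketch only, not part of the test suite: the API key is a placeholder, and it assumes `make_request` accepts any httpx-compatible async client, as the fakes suggest.

```python
import asyncio

import httpx

from fastllm.providers.openai import OpenAIProvider


async def tell_a_joke(api_key: str) -> str:
    provider = OpenAIProvider(api_key=api_key, api_base="https://api.openai.com/v1")
    async with httpx.AsyncClient() as client:
        # Same call shape as in the tests: an async client, a plain request
        # dict, and a timeout in seconds.
        result = await provider.make_request(
            client,
            {
                "model": "gpt-3.5-turbo",
                "messages": [{"role": "user", "content": "Tell me a joke."}],
                "temperature": 0.5,
            },
            timeout=30.0,
        )
    return result.choices[0].message.content


# asyncio.run(tell_a_joke("sk-your-key"))  # requires a real API key
```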


--------------------------------------------------------------------------------
/tests/test_progress_tracker.py:
--------------------------------------------------------------------------------
 1 | import time
 2 | 
 3 | from fastllm.core import ProgressTracker
 4 | 
 5 | # Constants for testing
 6 | INITIAL_PROMPT_TOKENS = 10
 7 | INITIAL_COMPLETION_TOKENS = 20
 8 | INITIAL_TOTAL_TOKENS = INITIAL_PROMPT_TOKENS + INITIAL_COMPLETION_TOKENS
 9 | 
10 | UPDATE_PROMPT_TOKENS = 5
11 | UPDATE_COMPLETION_TOKENS = 5
12 | FINAL_PROMPT_TOKENS = INITIAL_PROMPT_TOKENS + UPDATE_PROMPT_TOKENS
13 | FINAL_COMPLETION_TOKENS = INITIAL_COMPLETION_TOKENS + UPDATE_COMPLETION_TOKENS
14 | FINAL_TOTAL_TOKENS = FINAL_PROMPT_TOKENS + FINAL_COMPLETION_TOKENS
15 | 
16 | REQUESTS_COMPLETED = 2
17 | 
18 | 
19 | def test_progress_tracker():
20 |     tracker = ProgressTracker(total_requests=1, show_progress=False)
21 | 
22 |     # Initial update
23 |     tracker.update(INITIAL_PROMPT_TOKENS, INITIAL_COMPLETION_TOKENS, False)
24 | 
25 |     # Assert that the token stats are updated
26 |     assert tracker.stats.prompt_tokens == INITIAL_PROMPT_TOKENS
27 |     assert tracker.stats.completion_tokens == INITIAL_COMPLETION_TOKENS
28 |     assert tracker.stats.total_tokens == INITIAL_TOTAL_TOKENS
29 | 
30 |     # Test cache hit scenario
31 |     tracker.update(UPDATE_PROMPT_TOKENS, UPDATE_COMPLETION_TOKENS, True)
32 |     assert tracker.stats.prompt_tokens == FINAL_PROMPT_TOKENS
33 |     assert tracker.stats.completion_tokens == FINAL_COMPLETION_TOKENS
34 |     assert tracker.stats.total_tokens == FINAL_TOTAL_TOKENS
35 | 
36 |     # The cache hit count should be incremented by 1
37 |     assert tracker.stats.cache_hits == 1
38 | 
39 |     # Test that requests_completed is incremented correctly
40 |     # It started at 0, and two updates have been recorded
41 |     assert tracker.stats.requests_completed == REQUESTS_COMPLETED
42 | 
43 | 
44 | def test_progress_tracker_update():
45 |     # Create a ProgressTracker with a fixed total_requests and disable progress display
46 |     tracker = ProgressTracker(total_requests=5, show_progress=False)
47 |     # Simulate 1 second elapsed
48 |     tracker.stats.start_time = time.time() - 1
49 | 
50 |     # Update tracker with some token counts
51 |     tracker.update(INITIAL_PROMPT_TOKENS, INITIAL_COMPLETION_TOKENS, False)
52 | 
53 |     # Assert that the token stats are updated
54 |     assert tracker.stats.prompt_tokens == INITIAL_PROMPT_TOKENS
55 |     assert tracker.stats.completion_tokens == INITIAL_COMPLETION_TOKENS
56 |     assert tracker.stats.total_tokens == INITIAL_TOTAL_TOKENS
57 | 
58 |     # Test cache hit scenario
59 |     tracker.update(UPDATE_PROMPT_TOKENS, UPDATE_COMPLETION_TOKENS, True)
60 |     assert tracker.stats.prompt_tokens == FINAL_PROMPT_TOKENS
61 |     assert tracker.stats.completion_tokens == FINAL_COMPLETION_TOKENS
62 |     assert tracker.stats.total_tokens == FINAL_TOTAL_TOKENS
63 |     # The cache hit count should be incremented by 1
64 |     assert tracker.stats.cache_hits == 1
65 | 
66 |     # Test that requests_completed is incremented correctly
67 |     # It started at 0, and two updates have been recorded
68 |     assert tracker.stats.requests_completed == REQUESTS_COMPLETED
69 | 
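Both tests above drive the same accumulation logic; outside the test suite the tracker can be used the same way with the progress bar disabled. A short sketch using only the fields asserted above:

```python
from fastllm.core import ProgressTracker

tracker = ProgressTracker(total_requests=2, show_progress=False)

# A fresh (non-cached) response: 10 prompt tokens, 20 completion tokens.
tracker.update(10, 20, False)
# A cache hit: tokens are still accumulated and cache_hits is incremented.
tracker.update(5, 5, True)

stats = tracker.stats
print(stats.requests_completed, stats.cache_hits, stats.total_tokens)  # 2 1 40
```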


--------------------------------------------------------------------------------
/tests/test_providers.py:
--------------------------------------------------------------------------------
  1 | from fastllm.providers.base import Provider
  2 | from fastllm.providers.openai import OpenAIProvider
  3 | from openai.types.chat import ChatCompletion
  4 | from openai.types.chat.chat_completion import Choice, ChatCompletionMessage
  5 | from openai.types.completion_usage import CompletionUsage
  6 | 
  7 | 
  8 | class DummyProvider(Provider):
  9 |     def __init__(
 10 |         self,
 11 |         api_key="dummy",
 12 |         api_base="https://api.example.com",
 13 |         organization=None,
 14 |         headers=None,
 15 |     ):
 16 |         super().__init__(api_key, api_base, headers)
 17 |         self.organization = organization
 18 | 
 19 |     def get_request_headers(self):
 20 |         headers = {
 21 |             "Authorization": f"Bearer {self.api_key}",
 22 |             "Content-Type": "application/json",
 23 |         }
 24 |         if self.organization:
 25 |             headers["OpenAI-Organization"] = self.organization
 26 |         return headers
 27 | 
 28 |     async def _make_actual_request(self, client, request, timeout):
 29 |         # Simulate a dummy response
 30 |         return ChatCompletion(
 31 |             id="chatcmpl-dummy",
 32 |             object="chat.completion",
 33 |             created=1234567890,
 34 |             model=request.get("model", "dummy-model"),
 35 |             choices=[
 36 |                 Choice(
 37 |                     index=0,
 38 |                     message=ChatCompletionMessage(
 39 |                         role="assistant",
 40 |                         content="This is a dummy response",
 41 |                     ),
 42 |                     finish_reason="stop",
 43 |                 )
 44 |             ],
 45 |             usage=CompletionUsage(
 46 |                 prompt_tokens=10,
 47 |                 completion_tokens=10,
 48 |                 total_tokens=20,
 49 |             ),
 50 |         )
 51 | 
 52 |     async def make_request(self, client, request, timeout):
 53 |         return await self._make_actual_request(client, request, timeout)
 54 | 
 55 | 
 56 | def test_dummy_provider_get_request_url():
 57 |     provider = DummyProvider(api_key="testkey", api_base="https://api.test.com")
 58 |     url = provider.get_request_url("endpoint")
 59 |     assert url == "https://api.test.com/endpoint"
 60 | 
 61 | 
 62 | def test_dummy_provider_get_request_headers():
 63 |     provider = DummyProvider(api_key="testkey")
 64 |     headers = provider.get_request_headers()
 65 |     assert headers["Authorization"] == "Bearer testkey"
 66 |     assert headers["Content-Type"] == "application/json"
 67 | 
 68 | 
 69 | def test_openai_provider_get_request_headers_org():
 70 |     provider = OpenAIProvider(api_key="testkey", organization="org123")
 71 |     headers = provider.get_request_headers()
 72 |     assert headers["Authorization"] == "Bearer testkey"
 73 |     assert headers["Content-Type"] == "application/json"
 74 |     assert headers.get("OpenAI-Organization") == "org123"
 75 | 
 76 | 
 77 | def test_openai_provider_prepare_chat_completion_payload():
 78 |     # Test conversion from a simple request to payload
 79 |     provider = OpenAIProvider(api_key="testkey")
 80 |     request = {
 81 |         "model": "gpt-dummy",
 82 |         "messages": [{"role": "user", "content": "Hello world!"}]
 83 |     }
 84 |     payload = provider._prepare_payload(request, "chat_completion")
 85 |     
 86 |     # Verify that the model and messages are set correctly
 87 |     assert payload["model"] == "gpt-dummy"
 88 |     assert "messages" in payload
 89 |     assert isinstance(payload["messages"], list)
 90 |     assert payload["messages"][0]["role"] == "user"
 91 |     assert payload["messages"][0]["content"] == "Hello world!"
 92 | 
 93 | 
 94 | def test_openai_provider_prepare_embedding_payload():
 95 |     # Test conversion from embedding request to payload
 96 |     provider = OpenAIProvider(api_key="testkey")
 97 |     request = {
 98 |         "model": "text-embedding-3-small",
 99 |         "input": ["Sample text 1", "Sample text 2"]
100 |     }
101 |     payload = provider._prepare_payload(request, "embedding")
102 |     
103 |     # Verify model and input are set correctly
104 |     assert payload["model"] == "text-embedding-3-small"
105 |     assert payload["input"] == ["Sample text 1", "Sample text 2"]
106 | 
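`DummyProvider` above shows the minimal surface a custom provider needs: delegate construction to `Provider`, emit auth headers, and implement an async `make_request`. The sketch below follows the same shape; `EchoProvider` and its echo behaviour are invented for illustration, and the plain-dict response format relies on the same handling that the `FailingProvider` test in `test_core.py` exercises.

```python
from fastllm.providers.base import Provider


class EchoProvider(Provider):
    """Toy provider that echoes the last user message instead of calling an API."""

    def __init__(self, api_key="unused", api_base="https://example.invalid", headers=None):
        super().__init__(api_key, api_base, headers)

    def get_request_headers(self):
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    async def _make_actual_request(self, client, request, timeout):
        last_message = request["messages"][-1]["content"]
        return {
            "content": f"echo: {last_message}",
            "finish_reason": "stop",
            "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
        }

    async def make_request(self, client, request, timeout):
        return await self._make_actual_request(client, request, timeout)


provider = EchoProvider()
print(provider.get_request_url("chat/completions"))  # base URL joined with the path
```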


--------------------------------------------------------------------------------
/tests/test_request_batch.py:
--------------------------------------------------------------------------------
  1 | """Tests for request batching functionality."""
  2 | 
  3 | import pytest
  4 | from fastllm.core import RequestBatch
  5 | from fastllm.cache import compute_request_hash
  6 | 
  7 | # Constants for testing
  8 | EXPECTED_REQUESTS = 2
  9 | FIRST_REQUEST_ID = 0
 10 | SECOND_REQUEST_ID = 1
 11 | 
 12 | # Constants for batch addition testing
 13 | BATCH_SIZE_ONE = 1
 14 | BATCH_SIZE_TWO = 2
 15 | BATCH_SIZE_THREE = 3
 16 | FIRST_BATCH_START_ID = 0
 17 | SECOND_BATCH_START_ID = 1
 18 | THIRD_BATCH_START_ID = 2
 19 | 
 20 | # Constants for multiple additions testing
 21 | INITIAL_BATCH_SIZE = 1
 22 | FINAL_BATCH_SIZE = 3
 23 | FIRST_MULTIPLE_ID = 0
 24 | SECOND_MULTIPLE_ID = 1
 25 | THIRD_MULTIPLE_ID = 2
 26 | 
 27 | 
 28 | def test_request_batch():
 29 |     """Test basic request batch functionality and request_id generation."""
 30 |     batch = RequestBatch()
 31 |     request_id = batch.chat.completions.create(
 32 |         model="dummy-model",
 33 |         messages=[{"role": "user", "content": "Hi"}],
 34 |     )
 35 |     
 36 |     # Verify request was added
 37 |     assert len(batch.requests) == 1
 38 |     
 39 |     # Verify OpenAI Batch format
 40 |     assert "custom_id" in batch.requests[0]
 41 |     assert "url" in batch.requests[0]
 42 |     assert "body" in batch.requests[0]
 43 |     
 44 |     # Verify custom_id format and extract request_id and order_id
 45 |     custom_id_parts = batch.requests[0]["custom_id"].split("#")
 46 |     assert len(custom_id_parts) == 2
 47 |     extracted_request_id, order_id_str = custom_id_parts
 48 |     assert extracted_request_id == request_id
 49 |     assert order_id_str == "0"
 50 |     
 51 |     # Verify URL indicates chat completion
 52 |     assert batch.requests[0]["url"] == "/v1/chat/completions"
 53 |     
 54 |     # Verify request_id is computed correctly
 55 |     # Include all fields that affect the hash
 56 |     expected_request = {"type": "chat_completion", **batch.requests[0]["body"]}
 57 |     assert request_id == compute_request_hash(expected_request)
 58 | 
 59 | 
 60 | def test_request_batch_merge():
 61 |     """Test merging request batches and request_id preservation."""
 62 |     # Create first batch
 63 |     batch1 = RequestBatch()
 64 |     request_id1 = batch1.chat.completions.create(
 65 |         model="dummy-model",
 66 |         messages=[{"role": "user", "content": "Hi"}],
 67 |     )
 68 |     assert len(batch1.requests) == 1
 69 |     assert batch1.requests[0]["custom_id"].split("#")[0] == request_id1
 70 | 
 71 |     # Create second batch
 72 |     batch2 = RequestBatch()
 73 |     request_id2 = batch2.chat.completions.create(
 74 |         model="dummy-model",
 75 |         messages=[{"role": "user", "content": "Hello"}],
 76 |     )
 77 |     request_id3 = batch2.chat.completions.create(
 78 |         model="dummy-model",
 79 |         messages=[{"role": "user", "content": "Hey"}],
 80 |     )
 81 |     assert len(batch2.requests) == 2
 82 |     assert batch2.requests[0]["custom_id"].split("#")[0] == request_id2
 83 |     assert batch2.requests[1]["custom_id"].split("#")[0] == request_id3
 84 | 
 85 |     # Test merging batches
 86 |     merged_batch = RequestBatch.merge([batch1, batch2])
 87 |     assert len(merged_batch.requests) == 3
 88 |     
 89 |     # Verify request_ids are preserved after merge
 90 |     assert merged_batch.requests[0]["custom_id"].split("#")[0] == request_id1
 91 |     assert merged_batch.requests[1]["custom_id"].split("#")[0] == request_id2
 92 |     assert merged_batch.requests[2]["custom_id"].split("#")[0] == request_id3
 93 | 
 94 | 
 95 | def test_request_batch_multiple_merges():
 96 |     """Test merging multiple request batches and request_id preservation."""
 97 |     # Create first batch
 98 |     batch1 = RequestBatch()
 99 |     request_id1 = batch1.chat.completions.create(
100 |         model="dummy-model",
101 |         messages=[{"role": "user", "content": "Hi"}],
102 |     )
103 |     assert len(batch1.requests) == 1
104 |     assert batch1.requests[0]["custom_id"].split("#")[0] == request_id1
105 | 
106 |     # Create second batch
107 |     batch2 = RequestBatch()
108 |     request_id2 = batch2.chat.completions.create(
109 |         model="dummy-model",
110 |         messages=[{"role": "user", "content": "Hello"}],
111 |     )
112 |     assert batch2.requests[0]["custom_id"].split("#")[0] == request_id2
113 | 
114 |     # Create third batch
115 |     batch3 = RequestBatch()
116 |     request_id3 = batch3.chat.completions.create(
117 |         model="dummy-model",
118 |         messages=[{"role": "user", "content": "Hey"}],
119 |     )
120 |     assert batch3.requests[0]["custom_id"].split("#")[0] == request_id3
121 | 
122 |     # Test merging multiple batches
123 |     final_batch = RequestBatch.merge([batch1, batch2, batch3])
124 |     assert len(final_batch.requests) == 3
125 |     
126 |     # Verify request_ids are preserved after merge
127 |     assert final_batch.requests[0]["custom_id"].split("#")[0] == request_id1
128 |     assert final_batch.requests[1]["custom_id"].split("#")[0] == request_id2
129 |     assert final_batch.requests[2]["custom_id"].split("#")[0] == request_id3
130 | 
131 | 
132 | def test_request_id_consistency():
133 |     """Test that identical requests get the same request_id."""
134 |     batch = RequestBatch()
135 |     
136 |     # Create two identical requests
137 |     request_id1 = batch.chat.completions.create(
138 |         model="dummy-model",
139 |         messages=[{"role": "user", "content": "Hi"}],
140 |     )
141 |     
142 |     # Create a new batch to avoid order_id interference
143 |     batch2 = RequestBatch()
144 |     request_id2 = batch2.chat.completions.create(
145 |         model="dummy-model",
146 |         messages=[{"role": "user", "content": "Hi"}],
147 |     )
148 |     
149 |     # Verify that identical requests get the same request_id
150 |     assert request_id1 == request_id2
151 |     
152 |     # Create a different request
153 |     request_id3 = batch2.chat.completions.create(
154 |         model="dummy-model",
155 |         messages=[{"role": "user", "content": "Different"}],
156 |     )
157 |     
158 |     # Verify that different requests get different request_ids
159 |     assert request_id1 != request_id3
160 | 
161 | 
162 | def test_request_id_with_none_values():
163 |     """Test that None values are properly handled in request_id computation."""
164 |     batch = RequestBatch()
165 |     
166 |     # Create request with some None values
167 |     request_id = batch.chat.completions.create(
168 |         model="dummy-model",
169 |         messages=[{"role": "user", "content": "Hi"}],
170 |         temperature=None,  # Should be replaced with default
171 |         top_p=None,  # Should be replaced with default
172 |     )
173 |     
174 |     # Get the actual request body that was created
175 |     body1 = batch.requests[0]["body"]
176 |     
177 |     # Create identical request without None values
178 |     batch2 = RequestBatch()
179 |     request_id2 = batch2.chat.completions.create(
180 |         model="dummy-model",
181 |         messages=[{"role": "user", "content": "Hi"}],
182 |     )
183 |     
184 |     # Get the second actual request body
185 |     body2 = batch2.requests[0]["body"]
186 |     
187 |     # Both requests should have the same content
188 |     assert body1 == body2
189 |     
190 |     # Both request IDs should be identical since None values are replaced with defaults
191 |     assert request_id == request_id2
192 | 
193 | 
194 | def test_embeddings_request_format():
195 |     """Test the format of embeddings requests."""
196 |     batch = RequestBatch()
197 |     request_id = batch.embeddings.create(
198 |         model="text-embedding-ada-002",
199 |         input="Hello world",
200 |     )
201 |     
202 |     # Verify request was added
203 |     assert len(batch.requests) == 1
204 |     
205 |     # Verify OpenAI Batch format
206 |     assert "custom_id" in batch.requests[0]
207 |     assert batch.requests[0]["url"] == "/v1/embeddings"
208 |     assert "body" in batch.requests[0]
209 |     
210 |     # Verify custom_id format and extract request_id and order_id
211 |     custom_id_parts = batch.requests[0]["custom_id"].split("#")
212 |     assert len(custom_id_parts) == 2
213 |     extracted_request_id, order_id_str = custom_id_parts
214 |     assert extracted_request_id == request_id
215 |     assert order_id_str == "0"
216 |     
217 |     # Verify body content
218 |     assert batch.requests[0]["body"]["model"] == "text-embedding-ada-002"
219 |     assert batch.requests[0]["body"]["input"] == "Hello world"
220 | 
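Building and merging batches outside the test suite follows the same pattern: every `create()` call returns the content-derived request ID, and each stored entry's `custom_id` is `<request_id>#<order_id>`. A short sketch of that flow; the merged batch can then be handed to a `RequestManager` as in the README Quick Start.

```python
from fastllm.core import RequestBatch

chat_batch = RequestBatch()
chat_id = chat_batch.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hi"}],
)

embedding_batch = RequestBatch()
embedding_batch.embeddings.create(
    model="text-embedding-ada-002",
    input="Hello world",
)

merged = RequestBatch.merge([chat_batch, embedding_batch])
assert len(merged.requests) == 2
assert merged.requests[0]["url"] == "/v1/chat/completions"
assert merged.requests[1]["url"] == "/v1/embeddings"

# custom_id encodes "<request_id>#<order_id>" for every stored entry.
request_id, _order_id = merged.requests[0]["custom_id"].split("#")
assert request_id == chat_id
```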


--------------------------------------------------------------------------------