├── .gitignore ├── CHANGELOG.md ├── README.md ├── docs ├── README.md ├── RELEASE.md ├── api-reference.md ├── architecture.md ├── configuration.md ├── core-components.md ├── development-guide.md ├── getting-started.md ├── overview.md └── rate-limiting.md ├── embedding_results.json ├── examples ├── embedding_example.py ├── notebook_test.ipynb └── parallel_test.py ├── fastllm ├── __init__.py ├── cache.py ├── core.py └── providers │ ├── __init__.py │ ├── base.py │ └── openai.py ├── justfile ├── pyproject.toml └── tests ├── test_asyncio.py ├── test_cache.py ├── test_core.py ├── test_openai.py ├── test_progress_tracker.py ├── test_providers.py └── test_request_batch.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.so 6 | .Python 7 | build/ 8 | develop-eggs/ 9 | dist/ 10 | downloads/ 11 | eggs/ 12 | .eggs/ 13 | lib/ 14 | lib64/ 15 | parts/ 16 | sdist/ 17 | var/ 18 | wheels/ 19 | *.egg-info/ 20 | .installed.cfg 21 | *.egg 22 | 23 | # Virtual Environment 24 | .env 25 | .venv 26 | env/ 27 | venv/ 28 | ENV/ 29 | 30 | # IDE 31 | .idea/ 32 | .vscode/ 33 | *.swp 34 | *.swo 35 | .DS_Store 36 | 37 | # Project specific 38 | results.json 39 | parallel_test_results.json 40 | .coverage 41 | htmlcov/ 42 | .pytest_cache/ 43 | .ruff_cache/ 44 | .mypy_cache/ 45 | cache/ 46 | uv.lock 47 | *.lock -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | All notable changes to this project will be documented in this file. 4 | 5 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), 6 | and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 7 | 8 | ## [0.1.0] - 2024-03 9 | 10 | ### Added 11 | - Initial release with core functionality 12 | - Parallel request processing with configurable concurrency 13 | - In-memory and disk-based caching support 14 | - Multiple provider support (OpenAI, OpenRouter) 15 | - Request batching with OpenAI-style API 16 | - Progress tracking and statistics 17 | - Request deduplication and response ordering 18 | - Configurable retry mechanism 19 | - Rich progress bar with detailed statistics 20 | - Support for existing asyncio event loops 21 | - Jupyter notebook compatibility 22 | - Request ID (cache key) return from batch creation methods 23 | 24 | ### Changed 25 | - Optimized request processing for better performance 26 | - Improved error handling and reporting 27 | - Enhanced request ID handling and exposure 28 | - Added compatibility with existing asyncio event loops 29 | - Fixed asyncio loop handling in Jupyter notebooks 30 | - Made request IDs accessible to users for cache management 31 | 32 | ### Fixed 33 | - Request ID validation and string conversion 34 | - Cache persistence issues 35 | - Response ordering in parallel processing -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FastLLM 2 | 3 | High-performance parallel LLM API request tool with support for multiple providers and caching capabilities. 4 | 5 | ## Features 6 | 7 | - Parallel request processing with configurable concurrency 8 | - Allows you to process 20000+ prompt tokens per second and 1500+ output tokens per second even for extremely large LLMs, such as Deepseek-V3. 
9 | - Built-in caching support (in-memory and disk-based) 10 | - Progress tracking with token usage statistics 11 | - Support for multiple LLM providers (OpenAI, OpenRouter, etc.) 12 | - OpenAI-style API for request batching 13 | - Retry mechanism with configurable attempts and delays 14 | - Request deduplication and response ordering 15 | 16 | ## Installation 17 | 18 | Use pip: 19 | ```bash 20 | pip install fastllm-kit 21 | ``` 22 | 23 | Alternatively, use uv: 24 | ```bash 25 | uv pip install fastllm-kit 26 | ``` 27 | 28 | > **Important:** fastllm does not support yet libsqlite3.49.1, please use libsqlite3.49.0 or lower. See [this issue](https://github.com/grantjenks/python-diskcache/issues/343) for more details. This might be an issue for users with conda environments. 29 | 30 | For development: 31 | ```bash 32 | # Clone the repository 33 | git clone https://github.com/Rexhaif/fastllm.git 34 | cd fastllm 35 | 36 | # Create a virtual environment and install dependencies 37 | uv venv 38 | uv pip install -e ".[dev]" 39 | ``` 40 | 41 | ## Dependencies 42 | 43 | FastLLM requires Python 3.9 or later and depends on the following packages: 44 | 45 | - `httpx` (^0.27.2) - For async HTTP requests 46 | - `pydantic` (^2.10.6) - For data validation and settings management 47 | - `rich` (^13.9.4) - For beautiful terminal output and progress bars 48 | - `diskcache` (^5.6.3) - For persistent disk caching 49 | - `asyncio` (^3.4.3) - For asynchronous operations 50 | - `anyio` (^4.8.0) - For async I/O operations 51 | - `tqdm` (^4.67.1) - For progress tracking 52 | - `typing_extensions` (^4.12.2) - For enhanced type hints 53 | 54 | Development dependencies: 55 | - `ruff` (^0.3.7) - For linting and formatting 56 | - `pytest` (^8.3.4) - For testing 57 | - `pytest-asyncio` (^0.23.8) - For async tests 58 | - `pytest-cov` (^4.1.0) - For test coverage 59 | - `black` (^24.10.0) - For code formatting 60 | - `coverage` (^7.6.10) - For code coverage reporting 61 | 62 | ## Development 63 | 64 | The project uses [just](https://github.com/casey/just) for task automation and [uv](https://github.com/astral/uv) for dependency management. 65 | 66 | Common tasks: 67 | ```bash 68 | # Install dependencies 69 | just install 70 | 71 | # Run tests 72 | just test 73 | 74 | # Format code 75 | just format 76 | 77 | # Run linting 78 | just lint 79 | 80 | # Clean up cache files 81 | just clean 82 | ``` 83 | 84 | ## Quick Start 85 | 86 | ```python 87 | from fastllm import RequestBatch, RequestManager, OpenAIProvider, InMemoryCache 88 | 89 | # Create a provider 90 | provider = OpenAIProvider( 91 | api_key="your-api-key", 92 | # Optional: custom API base URL 93 | api_base="https://api.openai.com/v1", 94 | ) 95 | 96 | # Create a cache provider (optional) 97 | cache = InMemoryCache() # or DiskCache(directory="./cache") 98 | 99 | # Create a request manager 100 | manager = RequestManager( 101 | provider=provider, 102 | concurrency=50, # Number of concurrent requests 103 | show_progress=True, # Show progress bar 104 | caching_provider=cache, # Enable caching 105 | ) 106 | 107 | # Create a batch of requests 108 | request_ids = [] # Store request IDs for later use 109 | with RequestBatch() as batch: 110 | # Add requests to the batch 111 | for i in range(10): 112 | # create() returns the request ID (caching key) 113 | request_id = batch.chat.completions.create( 114 | model="gpt-3.5-turbo", 115 | messages=[{ 116 | "role": "user", 117 | "content": f"What is {i} + {i}?" 
118 | }], 119 | temperature=0.7, 120 | include_reasoning=True, # Optional: include model reasoning 121 | ) 122 | request_ids.append(request_id) 123 | 124 | # Process the batch 125 | responses = manager.process_batch(batch) 126 | 127 | # Process responses 128 | for request_id, response in zip(request_ids, responses): 129 | print(f"Request {request_id}: {response.response.choices[0].message.content}") 130 | 131 | # You can use request IDs to check cache status 132 | for request_id in request_ids: 133 | is_cached = await cache.exists(request_id) 134 | print(f"Request {request_id} is {'cached' if is_cached else 'not cached'}") 135 | ``` 136 | 137 | ## Advanced Usage 138 | 139 | ### Async Support 140 | 141 | FastLLM can be used both synchronously and asynchronously, and works seamlessly in regular Python environments, async applications, and Jupyter notebooks: 142 | 143 | ```python 144 | import asyncio 145 | from fastllm import RequestBatch, RequestManager, OpenAIProvider 146 | 147 | # Works in Jupyter notebooks 148 | provider = OpenAIProvider(api_key="your-api-key") 149 | manager = RequestManager(provider=provider) 150 | responses = manager.process_batch(batch) # Just works! 151 | 152 | # Works in async applications 153 | async def process_requests(): 154 | provider = OpenAIProvider(api_key="your-api-key") 155 | manager = RequestManager(provider=provider) 156 | 157 | with RequestBatch() as batch: 158 | batch.chat.completions.create( 159 | model="gpt-3.5-turbo", 160 | messages=[{"role": "user", "content": "Hello!"}] 161 | ) 162 | 163 | responses = manager.process_batch(batch) 164 | return responses 165 | 166 | # Run in existing event loop 167 | async def main(): 168 | responses = await process_requests() 169 | print(responses) 170 | 171 | asyncio.run(main()) 172 | ``` 173 | 174 | ### Caching Configuration 175 | 176 | FastLLM supports both in-memory and disk-based caching, with request IDs serving as cache keys: 177 | 178 | ```python 179 | from fastllm import InMemoryCache, DiskCache, RequestBatch 180 | 181 | # Create a batch and get request IDs 182 | with RequestBatch() as batch: 183 | request_id = batch.chat.completions.create( 184 | model="gpt-3.5-turbo", 185 | messages=[{"role": "user", "content": "Hello!"}] 186 | ) 187 | print(f"Request ID (cache key): {request_id}") 188 | 189 | # In-memory cache (faster, but cleared when process ends) 190 | cache = InMemoryCache() 191 | 192 | # Disk cache (persistent, with optional TTL and size limits) 193 | cache = DiskCache( 194 | directory="./cache", 195 | ttl=3600, # Cache TTL in seconds 196 | size_limit=int(2e9) # 2GB size limit 197 | ) 198 | 199 | # Check if a response is cached 200 | is_cached = await cache.exists(request_id) 201 | 202 | # Get cached response if available 203 | if is_cached: 204 | response = await cache.get(request_id) 205 | ``` 206 | 207 | ### Custom Providers 208 | 209 | Create your own provider by inheriting from the base `Provider` class: 210 | 211 | ```python 212 | from fastllm import Provider 213 | from typing import Any 214 | import httpx 215 | 216 | class CustomProvider(Provider[YourResponseType]): 217 | def get_request_headers(self) -> dict[str, str]: 218 | return { 219 | "Authorization": f"Bearer {self.api_key}", 220 | "Content-Type": "application/json", 221 | } 222 | 223 | async def make_request( 224 | self, 225 | client: httpx.AsyncClient, 226 | request: dict[str, Any], 227 | timeout: float, 228 | ) -> YourResponseType: 229 | # Implement your request logic here 230 | pass 231 | ``` 232 | 233 | ### Progress Tracking 234 | 
235 | The progress bar shows: 236 | - Request completion progress 237 | - Tokens per second (prompt and completion) 238 | - Cache hit/miss statistics 239 | - Estimated time remaining 240 | - Total elapsed time 241 | 242 | ## Contributing 243 | 244 | Contributions are welcome! Please feel free to submit a Pull Request. 245 | 246 | ## License 247 | 248 | This project is licensed under the MIT License - see the LICENSE file for details. 249 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # FastLLM Documentation 2 | 3 | FastLLM is a high-performance Python library for making parallel LLM API requests with built-in caching and multiple provider support. It's designed to efficiently handle batch processing of LLM requests while providing a familiar API interface similar to OpenAI's client library. 4 | 5 | ## Table of Contents 6 | 7 | - [Overview](overview.md) 8 | - [Getting Started](getting-started.md) 9 | - [Architecture](architecture.md) 10 | - [Core Components](core-components.md) 11 | - [API Reference](api-reference.md) 12 | - [Configuration](configuration.md) 13 | - [Development Guide](development-guide.md) -------------------------------------------------------------------------------- /docs/RELEASE.md: -------------------------------------------------------------------------------- 1 | # Release Procedure 2 | 3 | This document outlines the steps to release fastllm-kit to PyPI. 4 | 5 | ## Publishing to PyPI 6 | 7 | The library is published as `fastllm-kit` on PyPI, but maintains the import name `fastllm` for user convenience. 8 | 9 | ### Prerequisites 10 | 11 | - Ensure you have `uv` installed 12 | - Have PyPI credentials set up 13 | - Set the PyPI token with `UV_PUBLISH_TOKEN` environment variable or use `--token` flag 14 | - For TestPyPI, you'll also need TestPyPI credentials 15 | 16 | ### Release Steps 17 | 18 | 1. Update the version in `pyproject.toml` 19 | 2. Update `CHANGELOG.md` with the latest changes 20 | 3. Commit all changes 21 | 22 | 4. Clean previous builds: 23 | ``` 24 | just clean 25 | ``` 26 | 27 | 5. Build the package: 28 | ``` 29 | just build 30 | ``` 31 | 32 | 6. Publish to TestPyPI first to verify everything works: 33 | ``` 34 | just publish-test 35 | ``` 36 | Or with explicit token: 37 | ``` 38 | UV_PUBLISH_TOKEN=your_testpypi_token just publish-test 39 | ``` 40 | 41 | 7. Verify installation from TestPyPI: 42 | ``` 43 | pip uninstall -y fastllm-kit 44 | pip install --index-url https://test.pypi.org/simple/ fastllm-kit 45 | ``` 46 | 47 | 8. Verify imports work correctly: 48 | ```python 49 | from fastllm import RequestBatch, RequestManager 50 | ``` 51 | 52 | 9. If everything works, publish to PyPI: 53 | ``` 54 | just publish 55 | ``` 56 | Or with explicit token: 57 | ``` 58 | UV_PUBLISH_TOKEN=your_pypi_token just publish 59 | ``` 60 | 61 | ## Usage Instructions 62 | 63 | Users will install the package with: 64 | ``` 65 | pip install fastllm-kit 66 | ``` 67 | 68 | But will import it in their code as: 69 | ```python 70 | from fastllm import RequestBatch, RequestManager 71 | ``` 72 | 73 | This allows us to maintain a clean import interface despite the PyPI name being different from the import name. 
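As a quick sanity check after installing from TestPyPI or PyPI, you can confirm that the installed distribution is `fastllm-kit` while the import name remains `fastllm`. This is a minimal sketch using only the standard library; it assumes nothing beyond the package metadata described above:

```python
# Post-install sanity check: distribution name vs. import name.
from importlib.metadata import version

import fastllm  # the import name stays `fastllm`

print("installed distribution:", version("fastllm-kit"))  # the PyPI name
print("imported package:", fastllm.__name__)
```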
-------------------------------------------------------------------------------- /docs/api-reference.md: -------------------------------------------------------------------------------- 1 | # API Reference 2 | 3 | This document provides detailed information about FastLLM's public APIs. 4 | 5 | ## RequestManager 6 | 7 | ### Class Definition 8 | 9 | ```python 10 | class RequestManager: 11 | def __init__( 12 | self, 13 | provider: Provider[ResponseT], 14 | concurrency: int = 100, 15 | timeout: float = 30.0, 16 | retry_attempts: int = 3, 17 | retry_delay: float = 1.0, 18 | show_progress: bool = True, 19 | caching_provider: Optional[CacheProvider] = None 20 | ) 21 | ``` 22 | 23 | #### Parameters 24 | 25 | - `provider`: LLM provider instance 26 | - `concurrency`: Maximum number of concurrent requests (default: 100) 27 | - `timeout`: Request timeout in seconds (default: 30.0) 28 | - `retry_attempts`: Number of retry attempts (default: 3) 29 | - `retry_delay`: Delay between retries in seconds (default: 1.0) 30 | - `show_progress`: Whether to show progress bar (default: True) 31 | - `caching_provider`: Optional cache provider instance 32 | 33 | ### Methods 34 | 35 | #### process_batch 36 | 37 | ```python 38 | def process_batch( 39 | self, 40 | batch: Union[list[dict[str, Any]], RequestBatch] 41 | ) -> list[ResponseT] 42 | ``` 43 | 44 | Process a batch of LLM requests in parallel. 45 | 46 | **Parameters:** 47 | - `batch`: Either a RequestBatch object or a list of request dictionaries 48 | 49 | **Returns:** 50 | - List of responses in the same order as the requests 51 | 52 | ## Provider 53 | 54 | ### Base Class 55 | 56 | ```python 57 | class Provider(Generic[ResponseT], ABC): 58 | def __init__( 59 | self, 60 | api_key: str, 61 | api_base: str, 62 | headers: Optional[dict[str, str]] = None, 63 | **kwargs: Any 64 | ) 65 | ``` 66 | 67 | #### Abstract Methods 68 | 69 | ```python 70 | @abstractmethod 71 | def get_request_headers(self) -> dict[str, str]: 72 | """Get headers for API requests.""" 73 | pass 74 | 75 | @abstractmethod 76 | async def make_request( 77 | self, 78 | client: httpx.AsyncClient, 79 | request: dict[str, Any], 80 | timeout: float 81 | ) -> ResponseT: 82 | """Make a request to the provider API.""" 83 | pass 84 | ``` 85 | 86 | ### OpenAI Provider 87 | 88 | ```python 89 | class OpenAIProvider(Provider[ChatCompletion]): 90 | def __init__( 91 | self, 92 | api_key: str, 93 | api_base: str = DEFAULT_API_BASE, 94 | organization: Optional[str] = None, 95 | headers: Optional[dict[str, str]] = None, 96 | **kwargs: Any 97 | ) 98 | ``` 99 | 100 | ## Cache System 101 | 102 | ### CacheProvider Interface 103 | 104 | ```python 105 | class CacheProvider: 106 | async def exists(self, key: str) -> bool: 107 | """Check if a key exists in the cache.""" 108 | pass 109 | 110 | async def get(self, key: str): 111 | """Get a value from the cache.""" 112 | pass 113 | 114 | async def put(self, key: str, value) -> None: 115 | """Put a value in the cache.""" 116 | pass 117 | 118 | async def clear(self) -> None: 119 | """Clear all items from the cache.""" 120 | pass 121 | 122 | async def close(self) -> None: 123 | """Close the cache when done.""" 124 | pass 125 | ``` 126 | 127 | ### InMemoryCache 128 | 129 | ```python 130 | class InMemoryCache(CacheProvider): 131 | def __init__(self) 132 | ``` 133 | 134 | Simple in-memory cache implementation using a dictionary. 
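A short usage sketch based on the `CacheProvider` interface above. The key and value here are placeholders for illustration; in practice the keys are the request IDs returned by `batch.chat.completions.create()`:

```python
import asyncio

from fastllm import InMemoryCache


async def main() -> None:
    cache = InMemoryCache()
    # Store and retrieve a value through the async CacheProvider interface.
    await cache.put("example-request-id", {"content": "cached response"})
    if await cache.exists("example-request-id"):
        print(await cache.get("example-request-id"))
    await cache.clear()
    await cache.close()


asyncio.run(main())
```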
135 | 136 | ### DiskCache 137 | 138 | ```python 139 | class DiskCache(CacheProvider): 140 | def __init__( 141 | self, 142 | directory: str, 143 | ttl: Optional[int] = None, 144 | **cache_options 145 | ) 146 | ``` 147 | 148 | **Parameters:** 149 | - `directory`: Directory where cache files will be stored 150 | - `ttl`: Time to live in seconds for cached items 151 | - `cache_options`: Additional options for diskcache.Cache 152 | 153 | ## Request Models 154 | 155 | ### Message 156 | 157 | ```python 158 | class Message(BaseModel): 159 | role: Literal["system", "user", "assistant", "function", "tool"] = "user" 160 | content: Optional[str] = None 161 | name: Optional[str] = None 162 | function_call: Optional[dict[str, Any]] = None 163 | tool_calls: Optional[list[dict[str, Any]]] = None 164 | ``` 165 | 166 | #### Class Methods 167 | 168 | ```python 169 | @classmethod 170 | def from_dict(cls, data: Union[str, dict[str, Any]]) -> Message: 171 | """Create a message from a string or dictionary.""" 172 | ``` 173 | 174 | ### LLMRequest 175 | 176 | ```python 177 | class LLMRequest(BaseModel): 178 | provider: str 179 | messages: list[Message] 180 | model: Optional[str] = None 181 | temperature: float = Field(default=0.7, ge=0.0, le=2.0) 182 | max_completion_tokens: Optional[int] = None 183 | top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0) 184 | presence_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0) 185 | frequency_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0) 186 | stop: Optional[list[str]] = None 187 | stream: bool = False 188 | ``` 189 | 190 | #### Class Methods 191 | 192 | ```python 193 | @classmethod 194 | def from_prompt( 195 | cls, 196 | provider: str, 197 | prompt: Union[str, dict[str, Any]], 198 | **kwargs 199 | ) -> LLMRequest: 200 | """Create a request from a single prompt.""" 201 | 202 | @classmethod 203 | def from_dict(cls, data: dict[str, Any]) -> LLMRequest: 204 | """Create a request from a dictionary.""" 205 | ``` 206 | 207 | ## RequestBatch 208 | 209 | ```python 210 | class RequestBatch(AbstractContextManager): 211 | def __init__(self) 212 | ``` 213 | 214 | ### Usage Example 215 | 216 | ```python 217 | with RequestBatch() as batch: 218 | batch.chat.completions.create( 219 | model="gpt-3.5-turbo", 220 | messages=[{"role": "user", "content": "Hello!"}] 221 | ) 222 | ``` 223 | 224 | ### Chat Completions API 225 | 226 | ```python 227 | def create( 228 | self, 229 | *, 230 | model: str, 231 | messages: list[dict[str, str]], 232 | temperature: Optional[float] = 0.7, 233 | top_p: Optional[float] = 1.0, 234 | n: Optional[int] = 1, 235 | stop: Optional[Union[str, list[str]]] = None, 236 | max_completion_tokens: Optional[int] = None, 237 | presence_penalty: Optional[float] = 0.0, 238 | frequency_penalty: Optional[float] = 0.0, 239 | logit_bias: Optional[dict[str, float]] = None, 240 | user: Optional[str] = None, 241 | response_format: Optional[dict[str, str]] = None, 242 | seed: Optional[int] = None, 243 | tools: Optional[list[dict[str, Any]]] = None, 244 | tool_choice: Optional[Union[str, dict[str, str]]] = None, 245 | **kwargs: Any 246 | ) -> str: 247 | """Add a chat completion request to the batch.""" 248 | ``` 249 | 250 | ## Progress Tracking 251 | 252 | ### TokenStats 253 | 254 | ```python 255 | @dataclass 256 | class TokenStats: 257 | prompt_tokens: int = 0 258 | completion_tokens: int = 0 259 | total_tokens: int = 0 260 | requests_completed: int = 0 261 | cache_hits: int = 0 262 | start_time: float = 0.0 263 | ``` 264 | 265 | ### ProgressTracker 
266 | 267 | ```python 268 | class ProgressTracker: 269 | def __init__( 270 | self, 271 | total_requests: int, 272 | show_progress: bool = True 273 | ) 274 | ``` 275 | 276 | ## Response Types 277 | 278 | ### ResponseWrapper 279 | 280 | ```python 281 | class ResponseWrapper(Generic[ResponseT]): 282 | def __init__( 283 | self, 284 | response: ResponseT, 285 | request_id: str, 286 | order_id: int 287 | ) -------------------------------------------------------------------------------- /docs/architecture.md: -------------------------------------------------------------------------------- 1 | # FastLLM Architecture 2 | 3 | This document provides an overview of the FastLLM architecture and design decisions. 4 | 5 | ## Project Structure 6 | 7 | ``` 8 | fastllm/ # Main package directory 9 | ├── __init__.py # Package exports 10 | ├── core.py # Core functionality (RequestBatch, RequestManager) 11 | ├── cache.py # Caching implementations 12 | └── providers/ # LLM provider implementations 13 | ├── __init__.py 14 | ├── base.py # Base Provider class 15 | └── openai.py # OpenAI API implementation 16 | ``` 17 | 18 | ## Package and Distribution 19 | 20 | The package follows these important design decisions: 21 | 22 | - **PyPI Package Name**: `fastllm-kit` (since `fastllm` was already taken on PyPI) 23 | - **Import Name**: `fastllm` (users import with `from fastllm import ...`) 24 | - **GitHub Repository**: Maintained at `github.com/rexhaif/fastllm` 25 | 26 | The import redirection is implemented using Hatchling's build configuration in `pyproject.toml`, which ensures that the package is published as `fastllm-kit` but exposes a top-level `fastllm` module for imports. 27 | 28 | ## System Architecture 29 | 30 | ``` 31 | ┌─────────────────────────────────────────────────────────┐ 32 | │ RequestManager │ 33 | ├─────────────────────────────────────────────────────────┤ 34 | │ - Manages parallel request processing │ 35 | │ - Handles concurrency and batching │ 36 | │ - Coordinates between components │ 37 | └───────────────────┬───────────────────┬─────────────────┘ 38 | │ │ 39 | ┌───────────────▼──────┐ ┌────────▼──────────┐ 40 | │ Provider │ │ CacheProvider │ 41 | ├──────────────────────┤ ├───────────────────┤ 42 | │ - API communication │ │ - Request caching │ 43 | │ - Response parsing │ │ - Cache management│ 44 | └──────────────────────┘ └───────────────────┘ 45 | ``` 46 | 47 | ## Core Components 48 | 49 | ### RequestBatch 50 | 51 | The `RequestBatch` class provides an OpenAI-compatible request interface that allows for batching multiple requests together. 52 | 53 | ### RequestManager 54 | 55 | The `RequestManager` handles parallel processing of requests with configurable concurrency, retries, and caching. 56 | 57 | ### Providers 58 | 59 | Providers implement the interface for different LLM services: 60 | 61 | - `OpenAIProvider`: Handles requests to OpenAI-compatible APIs 62 | 63 | ### Caching 64 | 65 | The caching system supports: 66 | 67 | - `InMemoryCache`: Fast, in-process caching 68 | - `DiskCache`: Persistent disk-based caching with TTL and size limits 69 | 70 | ## Design Decisions 71 | 72 | 1. **Parallel Processing**: Designed to maximize throughput by processing many requests in parallel 73 | 2. **Request Deduplication**: Automatically detects and deduplicates identical requests 74 | 3. **Response Ordering**: Maintains request ordering regardless of completion time 75 | 4. 
**Caching**: Optional caching with customizable providers 76 | 77 | ## Development and Testing 78 | 79 | - Tests are implemented using `pytest` and `pytest-asyncio` 80 | - Code formatting is handled by `ruff` and `black` 81 | - Task automation is handled by `just` 82 | - Dependency management uses `uv` 83 | 84 | ## Key Utilities 85 | 86 | - `compute_request_hash`: Consistent request hashing for caching and deduplication 87 | - `TokenStats`: Tracking token usage and rate limits 88 | - `ProgressTracker`: Visual progress bar with statistics 89 | 90 | ## Request Flow 91 | 92 | 1. Create a `RequestBatch` and add requests via APIs like `batch.chat.completions.create()` 93 | 2. Each request is directly stored in OpenAI Batch format with a `custom_id` that combines the request hash and order ID 94 | 3. Pass the batch to a `RequestManager` for parallel processing 95 | 4. The manager extracts necessary metadata from the `custom_id` and `url` fields 96 | 5. Responses are returned in the original request order 97 | 98 | ## Caching 99 | 100 | The library implements efficient caching: 101 | - Request hashing for consistent cache keys 102 | - Support for both in-memory and persistent caching 103 | - Cache hits bypass network requests 104 | 105 | ## Provider Interface 106 | 107 | The library's provider system is designed to work with the simplified OpenAI Batch format: 108 | - Providers receive only the request body and necessary metadata 109 | - No conversion between formats is required during processing 110 | - Both chat completions and embeddings use the same batch structure with different URLs 111 | - API endpoints are determined automatically based on request type within each provider implementation 112 | 113 | ## Extensions 114 | 115 | The library is designed to be easily extensible: 116 | - Support for multiple LLM providers 117 | - Custom cache implementations 118 | - Flexible request formatting 119 | 120 | ## Key Features 121 | 122 | ### 1. Parallel Processing 123 | - Async/await throughout 124 | - Efficient request batching 125 | - Concurrent API calls 126 | - Progress monitoring 127 | - Support for both chat completions and embeddings 128 | 129 | ### 2. Caching 130 | - Multiple cache backends 131 | - TTL support 132 | - Async operations 133 | - Thread-safe implementation 134 | 135 | ### 3. Rate Limiting 136 | - Token-based limits 137 | - Request frequency limits 138 | - Window-based tracking 139 | - Saturation monitoring 140 | 141 | ### 4. Error Handling 142 | - Consistent error types 143 | - Graceful degradation 144 | - Detailed error messages 145 | - Retry mechanisms 146 | 147 | ## Configuration Points 148 | 149 | The system can be configured through: 150 | 1. Provider settings 151 | - API keys and endpoints 152 | - Organization IDs 153 | - Custom headers 154 | 155 | 2. Cache settings 156 | - Backend selection 157 | - TTL configuration 158 | - Storage directory 159 | - Serialization options 160 | 161 | 3. Request parameters 162 | - Model selection 163 | - Temperature and sampling 164 | - Token limits 165 | - Response streaming 166 | 167 | 4. Rate limiting 168 | - Token rate limits 169 | - Request frequency limits 170 | - Window sizes 171 | - Saturation thresholds 172 | 173 | ## Best Practices 174 | 175 | 1. **Error Handling** 176 | - Use try-except blocks for cache operations 177 | - Handle API errors gracefully 178 | - Provide meaningful error messages 179 | - Implement proper cleanup 180 | 181 | 2. 
**Performance** 182 | - Use appropriate cache backend 183 | - Configure batch sizes 184 | - Monitor rate limits 185 | - Track token usage 186 | 187 | 3. **Security** 188 | - Secure API key handling 189 | - Safe cache storage 190 | - Input validation 191 | - Response sanitization 192 | 193 | 4. **Maintenance** 194 | - Regular cache cleanup 195 | - Monitor disk usage 196 | - Update API versions 197 | - Track deprecations 198 | 199 | ## Data Flow 200 | 201 | 1. **Request Initialization** 202 | ```python 203 | RequestBatch 204 | → Chat Completion or Embedding Request 205 | → Provider-specific Request 206 | ``` 207 | 208 | 2. **Request Processing** 209 | ```python 210 | RequestManager 211 | → Check Cache 212 | → Make API Request if needed 213 | → Update Progress 214 | → Store in Cache 215 | ``` 216 | 217 | 3. **Response Handling** 218 | ```python 219 | API Response 220 | → ResponseWrapper 221 | → Update Statistics 222 | → Return to User 223 | ``` 224 | 225 | ## Design Principles 226 | 227 | 1. **Modularity** 228 | - Clear separation of concerns 229 | - Extensible provider system 230 | - Pluggable cache providers 231 | 232 | 2. **Performance** 233 | - Efficient parallel processing 234 | - Smart resource management 235 | - Optimized caching 236 | 237 | 3. **Reliability** 238 | - Comprehensive error handling 239 | - Automatic retries 240 | - Progress tracking 241 | 242 | 4. **Developer Experience** 243 | - Familiar API patterns 244 | - Clear type hints 245 | - Comprehensive logging 246 | - Structured logging system 247 | 248 | ## Logging System 249 | 250 | The library uses Python's built-in `logging` module for structured logging: 251 | 252 | 1. **Core Components** 253 | - Each module has its own logger (`logging.getLogger(__name__)`) 254 | - Log levels used appropriately (DEBUG, INFO, WARNING, ERROR) 255 | - Critical operations and errors are logged 256 | 257 | 2. **Key Logging Areas** 258 | - Cache operations (read/write errors) 259 | - Request processing status 260 | - Rate limiting events 261 | - Error conditions and exceptions 262 | 263 | 3. **Best Practices** 264 | - Consistent log format 265 | - Meaningful context in log messages 266 | - Error traceability 267 | - Performance impact consideration 268 | 269 | ## Error Handling 270 | 271 | The system implements comprehensive error handling: 272 | - API errors 273 | - Rate limiting 274 | - Timeouts 275 | - Cache failures 276 | - Invalid requests 277 | 278 | Each component includes appropriate error handling and propagation to ensure system stability and reliability. 279 | 280 | ## Testing Strategy 281 | 282 | The test suite is organized by component: 283 | 284 | 1. **Core Tests** (`test_core.py`): 285 | - Request/Response model validation 286 | - RequestManager functionality 287 | - Token statistics tracking 288 | 289 | 2. **Cache Tests** (`test_cache.py`): 290 | - Cache implementation verification 291 | - Request hashing consistency 292 | - Concurrent access handling 293 | 294 | 3. **Provider Tests** (`test_providers.py`): 295 | - Provider interface compliance 296 | - API communication 297 | - Response parsing 298 | 299 | 4. **Integration Tests**: 300 | - End-to-end request flow 301 | - Rate limiting behavior 302 | - Error handling scenarios 303 | 304 | ## Supported APIs 305 | 306 | The library supports the following APIs: 307 | 308 | 1. **Chat Completions** 309 | - Multiple message support 310 | - Tool and function calls 311 | - Streaming responses 312 | - Temperature and top_p sampling 313 | 314 | 2. 
**Embeddings** 315 | - Single or batch text input 316 | - Dimension control 317 | - Format control (float/base64) 318 | - Efficient batch processing 319 | - Compatible with semantic search use cases 320 | -------------------------------------------------------------------------------- /docs/configuration.md: -------------------------------------------------------------------------------- 1 | # Configuration Guide 2 | 3 | This guide covers the various configuration options available in FastLLM and how to use them effectively. 4 | 5 | ## Provider Configuration 6 | 7 | ### OpenAI Provider 8 | 9 | ```python 10 | from fastllm.providers import OpenAIProvider 11 | 12 | provider = OpenAIProvider( 13 | api_key="your-api-key", 14 | api_base="https://api.openai.com/v1", # Optional: custom API endpoint 15 | organization="your-org-id", # Optional: OpenAI organization ID 16 | headers={ # Optional: custom headers 17 | "Custom-Header": "value" 18 | } 19 | ) 20 | ``` 21 | 22 | #### Configuration Options 23 | 24 | | Parameter | Type | Default | Description | 25 | |-----------|------|---------|-------------| 26 | | api_key | str | Required | Your API key | 27 | | api_base | str | OpenAI default | API endpoint URL | 28 | | organization | str | None | Organization ID | 29 | | headers | dict | None | Custom headers | 30 | 31 | ## Request Manager Configuration 32 | 33 | ```python 34 | from fastllm import RequestManager 35 | 36 | manager = RequestManager( 37 | provider=provider, 38 | concurrency=100, 39 | timeout=30.0, 40 | retry_attempts=3, 41 | retry_delay=1.0, 42 | show_progress=True, 43 | caching_provider=None 44 | ) 45 | ``` 46 | 47 | ### Performance Settings 48 | 49 | | Parameter | Type | Default | Description | 50 | |-----------|------|---------|-------------| 51 | | concurrency | int | 100 | Maximum concurrent requests | 52 | | timeout | float | 30.0 | Request timeout in seconds | 53 | | retry_attempts | int | 3 | Number of retry attempts | 54 | | retry_delay | float | 1.0 | Delay between retries | 55 | 56 | ### Progress Display 57 | 58 | ```python 59 | manager = RequestManager( 60 | provider=provider, 61 | show_progress=True # Enable rich progress display 62 | ) 63 | ``` 64 | 65 | The progress display shows: 66 | - Completion percentage 67 | - Request count 68 | - Token usage rates 69 | - Cache hit ratio 70 | - Estimated time remaining 71 | 72 | ## Cache Configuration 73 | 74 | ### In-Memory Cache 75 | 76 | ```python 77 | from fastllm.cache import InMemoryCache 78 | 79 | cache = InMemoryCache() 80 | ``` 81 | 82 | Best for: 83 | - Development and testing 84 | - Small-scale applications 85 | - Temporary caching needs 86 | 87 | ### Disk Cache 88 | 89 | ```python 90 | from fastllm.cache import DiskCache 91 | 92 | cache = DiskCache( 93 | directory="cache", # Cache directory path 94 | ttl=3600, # Cache TTL in seconds 95 | size_limit=1e9 # Cache size limit in bytes 96 | ) 97 | ``` 98 | 99 | #### Configuration Options 100 | 101 | | Parameter | Type | Default | Description | 102 | |-----------|------|---------|-------------| 103 | | directory | str | Required | Cache directory path | 104 | | ttl | int | None | Time-to-live in seconds | 105 | | size_limit | int | None | Maximum cache size | 106 | 107 | ### Cache Integration 108 | 109 | ```python 110 | manager = RequestManager( 111 | provider=provider, 112 | caching_provider=cache 113 | ) 114 | ``` 115 | 116 | ## Request Configuration 117 | 118 | ### Basic Request Settings 119 | 120 | ```python 121 | with RequestBatch() as batch: 122 | 
batch.chat.completions.create( 123 | model="gpt-3.5-turbo", 124 | messages=[{"role": "user", "content": "Hello"}], 125 | temperature=0.7, 126 | max_completion_tokens=100 127 | ) 128 | ``` 129 | 130 | ### Available Parameters 131 | 132 | | Parameter | Type | Default | Description | 133 | |-----------|------|---------|-------------| 134 | | model | str | Required | Model identifier | 135 | | messages | list | Required | Conversation messages | 136 | | temperature | float | 0.7 | Sampling temperature | 137 | | max_completion_tokens | int | None | Max tokens to generate | 138 | | top_p | float | 1.0 | Nucleus sampling parameter | 139 | | presence_penalty | float | 0.0 | Presence penalty | 140 | | frequency_penalty | float | 0.0 | Frequency penalty | 141 | | stop | list[str] | None | Stop sequences | 142 | 143 | ## Advanced Configuration 144 | 145 | ### Custom Chunk Size 146 | 147 | The RequestManager automatically calculates optimal chunk sizes, but you can influence this through the concurrency setting: 148 | 149 | ```python 150 | manager = RequestManager( 151 | provider=provider, 152 | concurrency=50 # Will affect chunk size calculation 153 | ) 154 | ``` 155 | 156 | Chunk size is calculated as: 157 | ```python 158 | chunk_size = min(concurrency * 2, 1000) 159 | ``` 160 | 161 | ### Error Handling Configuration 162 | 163 | ```python 164 | manager = RequestManager( 165 | provider=provider, 166 | retry_attempts=5, # Increase retry attempts 167 | retry_delay=2.0, # Increase delay between retries 168 | timeout=60.0 # Increase timeout 169 | ) 170 | ``` 171 | 172 | ### Custom Headers 173 | 174 | ```python 175 | provider = OpenAIProvider( 176 | api_key="your-api-key", 177 | headers={ 178 | "User-Agent": "CustomApp/1.0", 179 | "X-Custom-Header": "value" 180 | } 181 | ) 182 | ``` 183 | 184 | ## Environment Variables 185 | 186 | FastLLM respects the following environment variables: 187 | 188 | ```bash 189 | OPENAI_API_KEY=your-api-key 190 | OPENAI_ORG_ID=your-org-id 191 | ``` 192 | 193 | ## Best Practices 194 | 195 | ### Concurrency Settings 196 | 197 | - Start with lower concurrency (10-20) and adjust based on performance 198 | - Monitor token usage and API rate limits 199 | - Consider provider-specific rate limits 200 | 201 | ### Cache Configuration 202 | 203 | - Use disk cache for production environments 204 | - Set appropriate TTL based on data freshness requirements 205 | - Monitor cache size and hit ratios 206 | 207 | ### Error Handling 208 | 209 | - Configure retry attempts based on API stability 210 | - Use appropriate timeout values 211 | - Implement proper error handling in your code 212 | 213 | ### Resource Management 214 | 215 | - Close cache providers when done 216 | - Monitor memory usage 217 | - Use appropriate chunk sizes for your use case 218 | 219 | ## Performance Optimization 220 | 221 | ### Caching Strategy 222 | 223 | - Enable caching for repeated requests 224 | - Use appropriate TTL values 225 | - Monitor cache hit ratios 226 | 227 | ### Concurrency Tuning 228 | 229 | - Adjust concurrency based on: 230 | * API rate limits 231 | * System resources 232 | * Response times 233 | 234 | ### Memory Management 235 | 236 | - Use appropriate chunk sizes 237 | - Monitor memory usage 238 | - Clean up resources properly 239 | 240 | ## Monitoring and Debugging 241 | 242 | ### Progress Tracking 243 | 244 | ```python 245 | manager = RequestManager( 246 | provider=provider, 247 | show_progress=True 248 | ) 249 | ``` 250 | 251 | ### Token Usage Monitoring 252 | 253 | Track token usage through the 
TokenStats class: 254 | - Prompt tokens 255 | - Completion tokens 256 | - Token rates 257 | - Cache hit ratios -------------------------------------------------------------------------------- /docs/core-components.md: -------------------------------------------------------------------------------- 1 | # Core Components 2 | 3 | This document provides detailed information about the core components of FastLLM. 4 | 5 | ## RequestManager 6 | 7 | The RequestManager is the central component responsible for handling parallel LLM API requests. 8 | 9 | ```python 10 | RequestManager( 11 | provider: Provider[ResponseT], 12 | concurrency: int = 100, 13 | timeout: float = 30.0, 14 | retry_attempts: int = 3, 15 | retry_delay: float = 1.0, 16 | show_progress: bool = True, 17 | caching_provider: Optional[CacheProvider] = None 18 | ) 19 | ``` 20 | 21 | ### Features 22 | 23 | - **Parallel Processing**: Handles multiple requests concurrently using asyncio 24 | - **Chunked Processing**: Processes requests in optimal chunks based on concurrency 25 | - **Progress Tracking**: Real-time progress monitoring with rich statistics 26 | - **Cache Integration**: Seamless integration with caching providers 27 | - **Error Handling**: Comprehensive error handling with retries 28 | 29 | ### Key Methods 30 | 31 | - `process_batch()`: Main synchronous API for batch processing 32 | - `_process_batch_async()`: Internal async implementation 33 | - `_process_request_async()`: Individual request processing 34 | - `_calculate_chunk_size()`: Dynamic chunk size optimization 35 | 36 | ## Provider System 37 | 38 | The provider system enables integration with different LLM APIs. 39 | 40 | ### Base Provider 41 | 42 | ```python 43 | class Provider(Generic[ResponseT], ABC): 44 | def __init__( 45 | self, 46 | api_key: str, 47 | api_base: str, 48 | headers: Optional[dict[str, str]] = None, 49 | **kwargs: Any 50 | ) 51 | ``` 52 | 53 | #### Key Methods 54 | 55 | - `get_request_url()`: Constructs API endpoint URLs 56 | - `get_request_headers()`: Manages API request headers 57 | - `make_request()`: Handles API communication 58 | 59 | ### OpenAI Provider 60 | 61 | Implements OpenAI-specific functionality: 62 | 63 | ```python 64 | class OpenAIProvider(Provider[ChatCompletion]): 65 | def __init__( 66 | self, 67 | api_key: str, 68 | api_base: str = DEFAULT_API_BASE, 69 | organization: Optional[str] = None, 70 | headers: Optional[dict[str, str]] = None, 71 | **kwargs: Any 72 | ) 73 | ``` 74 | 75 | ## Caching System 76 | 77 | The caching system provides flexible caching solutions. 
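The flow matches the "Caching Flow" described under Component Interactions below: compute a cache key, look it up, and only call the provider on a miss. The helper below is a simplified illustration of that flow, not the library's internal implementation:

```python
from typing import Any, Awaitable, Callable, Optional

from fastllm.cache import CacheProvider


async def get_or_fetch(
    cache: Optional[CacheProvider],
    key: str,
    fetch: Callable[[], Awaitable[Any]],
) -> Any:
    """Return a cached response when available, otherwise fetch and store it."""
    if cache is not None and await cache.exists(key):
        return await cache.get(key)  # cache hit: skip the network call
    response = await fetch()  # cache miss: call the provider
    if cache is not None:
        await cache.put(key, response)
    return response
```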
78 | 79 | ### Cache Provider Interface 80 | 81 | ```python 82 | class CacheProvider: 83 | async def exists(self, key: str) -> bool 84 | async def get(self, key: str) 85 | async def put(self, key: str, value) -> None 86 | async def clear(self) -> None 87 | async def close(self) -> None 88 | ``` 89 | 90 | ### Implementations 91 | 92 | #### InMemoryCache 93 | 94 | Simple dictionary-based cache for non-persistent storage: 95 | 96 | ```python 97 | class InMemoryCache(CacheProvider): 98 | def __init__(self) 99 | ``` 100 | 101 | #### DiskCache 102 | 103 | Persistent disk-based cache with TTL support: 104 | 105 | ```python 106 | class DiskCache(CacheProvider): 107 | def __init__( 108 | self, 109 | directory: str, 110 | ttl: Optional[int] = None, 111 | **cache_options 112 | ) 113 | ``` 114 | 115 | ## Request/Response Models 116 | 117 | ### LLMRequest 118 | 119 | Base model for LLM requests: 120 | 121 | ```python 122 | class LLMRequest(BaseModel): 123 | provider: str 124 | messages: list[Message] 125 | model: Optional[str] = None 126 | temperature: float = Field(default=0.7, ge=0.0, le=2.0) 127 | max_completion_tokens: Optional[int] = None 128 | top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0) 129 | presence_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0) 130 | frequency_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0) 131 | stop: Optional[list[str]] = None 132 | stream: bool = False 133 | ``` 134 | 135 | ### Message 136 | 137 | Represents a single message in a conversation: 138 | 139 | ```python 140 | class Message(BaseModel): 141 | role: Literal["system", "user", "assistant", "function", "tool"] = "user" 142 | content: Optional[str] = None 143 | name: Optional[str] = None 144 | function_call: Optional[dict[str, Any]] = None 145 | tool_calls: Optional[list[dict[str, Any]]] = None 146 | ``` 147 | 148 | ### TokenStats 149 | 150 | Tracks token usage and performance metrics: 151 | 152 | ```python 153 | @dataclass 154 | class TokenStats: 155 | prompt_tokens: int = 0 156 | completion_tokens: int = 0 157 | total_tokens: int = 0 158 | requests_completed: int = 0 159 | cache_hits: int = 0 160 | start_time: float = 0.0 161 | ``` 162 | 163 | ### ProgressTracker 164 | 165 | Manages progress display and statistics: 166 | 167 | ```python 168 | class ProgressTracker: 169 | def __init__(self, total_requests: int, show_progress: bool = True) 170 | ``` 171 | 172 | Features: 173 | - Real-time progress bar 174 | - Token usage statistics 175 | - Cache hit ratio tracking 176 | - Performance metrics 177 | 178 | ## Request Batching 179 | 180 | ### RequestBatch 181 | 182 | Provides an OpenAI-like interface for batching requests: 183 | 184 | ```python 185 | class RequestBatch(AbstractContextManager): 186 | def __init__(self) 187 | ``` 188 | 189 | Usage: 190 | ```python 191 | with RequestBatch() as batch: 192 | batch.chat.completions.create( 193 | model="gpt-3.5-turbo", 194 | messages=[{"role": "user", "content": "Hello!"}] 195 | ) 196 | ``` 197 | 198 | ## Component Interactions 199 | 200 | 1. **Request Flow**: 201 | ``` 202 | RequestBatch → RequestManager → Cache Check → Provider → Response 203 | ``` 204 | 205 | 2. **Caching Flow**: 206 | ``` 207 | Request → Hash Computation → Cache Lookup → (Cache Hit/Miss) → Response 208 | ``` 209 | 210 | 3. **Progress Tracking**: 211 | ``` 212 | Request Processing → Stats Update → Progress Display → Completion 213 | ``` 214 | 215 | ## Best Practices 216 | 217 | 1. 
**Request Management**: 218 | - Use appropriate concurrency limits 219 | - Implement proper error handling 220 | - Monitor token usage 221 | 222 | 2. **Caching**: 223 | - Choose appropriate cache provider 224 | - Configure TTL based on needs 225 | - Monitor cache hit ratios 226 | 227 | 3. **Provider Implementation**: 228 | - Handle rate limits properly 229 | - Implement comprehensive error handling 230 | - Follow provider-specific best practices -------------------------------------------------------------------------------- /docs/development-guide.md: -------------------------------------------------------------------------------- 1 | # Development Guide 2 | 3 | This guide provides information for developers who want to contribute to FastLLM or extend its functionality. 4 | 5 | ## Development Setup 6 | 7 | ### Prerequisites 8 | 9 | - Python 3.9 or higher 10 | - uv for dependency management 11 | - Git for version control 12 | 13 | ### Setting Up Development Environment 14 | 15 | 1. Clone the repository: 16 | ```bash 17 | git clone https://github.com/Rexhaif/fastllm.git 18 | cd fastllm 19 | ``` 20 | 21 | 2. Install dependencies: 22 | ```bash 23 | uv pip install --with dev 24 | ``` 25 | 26 | 3. Activate virtual environment: 27 | ```bash 28 | uv venv 29 | source .venv/bin/activate # On Unix/macOS 30 | # or 31 | .venv\Scripts\activate # On Windows 32 | ``` 33 | 34 | ## Project Structure 35 | 36 | ``` 37 | fastllm/ 38 | ├── fastllm/ 39 | │ ├── __init__.py 40 | │ ├── cache.py # Caching system 41 | │ ├── cli.py # CLI interface 42 | │ ├── core.py # Core functionality 43 | │ └── providers/ # Provider implementations 44 | │ ├── __init__.py 45 | │ ├── base.py # Base provider class 46 | │ └── openai.py # OpenAI provider 47 | ├── tests/ # Test suite 48 | ├── examples/ # Example code 49 | ├── pyproject.toml # Project configuration 50 | └── README.md # Project documentation 51 | ``` 52 | 53 | ## Adding a New Provider 54 | 55 | To add support for a new LLM provider: 56 | 57 | 1. Create a new file in `fastllm/providers/`: 58 | 59 | ```python 60 | # fastllm/providers/custom_provider.py 61 | 62 | from typing import Any, Optional 63 | import httpx 64 | from fastllm.providers.base import Provider 65 | from fastllm.core import ResponseT 66 | 67 | class CustomProvider(Provider[ResponseT]): 68 | def __init__( 69 | self, 70 | api_key: str, 71 | api_base: str, 72 | headers: Optional[dict[str, str]] = None, 73 | **kwargs: Any 74 | ): 75 | super().__init__(api_key, api_base, headers, **kwargs) 76 | # Add provider-specific initialization 77 | 78 | def get_request_headers(self) -> dict[str, str]: 79 | return { 80 | "Authorization": f"Bearer {self.api_key}", 81 | "Content-Type": "application/json", 82 | **self.headers 83 | } 84 | 85 | async def make_request( 86 | self, 87 | client: httpx.AsyncClient, 88 | request: dict[str, Any], 89 | timeout: float 90 | ) -> ResponseT: 91 | # Implement provider-specific request handling 92 | url = self.get_request_url("endpoint") 93 | response = await client.post( 94 | url, 95 | headers=self.get_request_headers(), 96 | json=request, 97 | timeout=timeout 98 | ) 99 | response.raise_for_status() 100 | return response.json() 101 | ``` 102 | 103 | 2. 
Add tests for the new provider: 104 | 105 | ```python 106 | # tests/test_custom_provider.py 107 | 108 | import pytest 109 | from fastllm.providers.custom_provider import CustomProvider 110 | 111 | @pytest.fixture 112 | def provider(): 113 | return CustomProvider( 114 | api_key="test-key", 115 | api_base="https://api.example.com" 116 | ) 117 | 118 | async def test_make_request(provider): 119 | # Implement provider tests 120 | pass 121 | ``` 122 | 123 | ## Adding a New Cache Provider 124 | 125 | To implement a new cache provider: 126 | 127 | 1. Create a new cache implementation: 128 | 129 | ```python 130 | from fastllm.cache import CacheProvider 131 | 132 | class CustomCache(CacheProvider): 133 | def __init__(self, **options): 134 | # Initialize your cache 135 | 136 | async def exists(self, key: str) -> bool: 137 | # Implement key existence check 138 | 139 | async def get(self, key: str): 140 | # Implement value retrieval 141 | 142 | async def put(self, key: str, value) -> None: 143 | # Implement value storage 144 | 145 | async def clear(self) -> None: 146 | # Implement cache clearing 147 | 148 | async def close(self) -> None: 149 | # Implement cleanup 150 | ``` 151 | 152 | 2. Add tests for the new cache: 153 | 154 | ```python 155 | # tests/test_custom_cache.py 156 | 157 | import pytest 158 | from your_cache_module import CustomCache 159 | 160 | @pytest.fixture 161 | async def cache(): 162 | cache = CustomCache() 163 | yield cache 164 | await cache.close() 165 | 166 | async def test_cache_operations(cache): 167 | # Implement cache tests 168 | pass 169 | ``` 170 | 171 | ## Testing 172 | 173 | ### Running Tests 174 | 175 | ```bash 176 | # Run all tests 177 | pytest 178 | 179 | # Run with coverage 180 | pytest --cov=fastllm 181 | 182 | # Run specific test file 183 | pytest tests/test_specific.py 184 | ``` 185 | 186 | ### Writing Tests 187 | 188 | 1. Use pytest fixtures for common setup: 189 | 190 | ```python 191 | @pytest.fixture 192 | def request_manager(provider, cache): 193 | return RequestManager( 194 | provider=provider, 195 | caching_provider=cache 196 | ) 197 | ``` 198 | 199 | 2. Test async functionality: 200 | 201 | ```python 202 | @pytest.mark.asyncio 203 | async def test_async_function(): 204 | # Test implementation 205 | pass 206 | ``` 207 | 208 | 3. Use mocking when appropriate: 209 | 210 | ```python 211 | from unittest.mock import Mock, patch 212 | 213 | def test_with_mock(): 214 | with patch('module.function') as mock_func: 215 | # Test implementation 216 | pass 217 | ``` 218 | 219 | ## Code Style 220 | 221 | ### Style Guide 222 | 223 | - Follow PEP 8 guidelines 224 | - Use type hints 225 | - Document public APIs 226 | - Keep functions focused and small 227 | 228 | ### Code Formatting 229 | 230 | The project uses: 231 | - Black for code formatting 232 | - Ruff for linting 233 | - isort for import sorting 234 | 235 | ```bash 236 | # Format code 237 | black . 238 | 239 | # Sort imports 240 | isort . 241 | 242 | # Run linter 243 | ruff check . 244 | ``` 245 | 246 | ## Documentation 247 | 248 | ### Docstring Format 249 | 250 | Use Google-style docstrings: 251 | 252 | ```python 253 | def function(param1: str, param2: int) -> bool: 254 | """Short description. 255 | 256 | Longer description if needed. 
257 | 258 | Args: 259 | param1: Parameter description 260 | param2: Parameter description 261 | 262 | Returns: 263 | Description of return value 264 | 265 | Raises: 266 | ValueError: Description of when this error occurs 267 | """ 268 | ``` 269 | 270 | ### Building Documentation 271 | 272 | The project uses Markdown for documentation: 273 | 274 | 1. Place documentation in `docs/` 275 | 2. Use clear, concise language 276 | 3. Include code examples 277 | 4. Keep documentation up to date 278 | 279 | ## Performance Considerations 280 | 281 | When developing new features: 282 | 283 | 1. **Concurrency** 284 | - Use async/await properly 285 | - Avoid blocking operations 286 | - Handle resources correctly 287 | 288 | 2. **Memory Usage** 289 | - Monitor memory consumption 290 | - Clean up resources 291 | - Use appropriate data structures 292 | 293 | 3. **Caching** 294 | - Implement efficient caching 295 | - Handle cache invalidation 296 | - Consider memory vs. speed tradeoffs 297 | 298 | ## Error Handling 299 | 300 | 1. Use appropriate exception types: 301 | 302 | ```python 303 | class CustomProviderError(Exception): 304 | """Base exception for custom provider.""" 305 | pass 306 | 307 | class CustomProviderAuthError(CustomProviderError): 308 | """Authentication error for custom provider.""" 309 | pass 310 | ``` 311 | 312 | 2. Implement proper error handling: 313 | 314 | ```python 315 | async def make_request(self, ...): 316 | try: 317 | response = await self._make_api_call() 318 | except httpx.TimeoutException as e: 319 | raise CustomProviderError(f"API timeout: {e}") 320 | except httpx.HTTPError as e: 321 | raise CustomProviderError(f"HTTP error: {e}") 322 | ``` 323 | 324 | ## Contributing 325 | 326 | 1. Fork the repository 327 | 2. Create a feature branch 328 | 3. Make your changes 329 | 4. Add tests 330 | 5. Update documentation 331 | 6. Submit a pull request 332 | 333 | ### Commit Guidelines 334 | 335 | - Use clear commit messages 336 | - Reference issues when applicable 337 | - Keep commits focused and atomic 338 | 339 | ### Pull Request Process 340 | 341 | 1. Update documentation 342 | 2. Add tests 343 | 3. Ensure CI passes 344 | 4. Request review 345 | 5. Address feedback 346 | 347 | ## Release Process 348 | 349 | 1. Update version in pyproject.toml 350 | 2. Update CHANGELOG.md 351 | 3. Create release commit 352 | 4. Tag release 353 | 5. Push to repository 354 | 6. Build and publish to PyPI -------------------------------------------------------------------------------- /docs/getting-started.md: -------------------------------------------------------------------------------- 1 | # Getting Started with FastLLM 2 | 3 | This guide will help you get started with FastLLM for efficient parallel LLM API requests. 
4 | 5 | ## Installation 6 | 7 | Install FastLLM using pip: 8 | 9 | ```bash 10 | pip install fastllm 11 | ``` 12 | 13 | ## Quick Start 14 | 15 | Here's a simple example to get you started: 16 | 17 | ```python 18 | from fastllm import RequestManager 19 | from fastllm.providers import OpenAIProvider 20 | 21 | # Initialize the provider 22 | provider = OpenAIProvider( 23 | api_key="your-api-key", 24 | organization="your-org-id" # Optional 25 | ) 26 | 27 | # Create request manager 28 | manager = RequestManager( 29 | provider=provider, 30 | concurrency=10, # Number of concurrent requests 31 | show_progress=True # Show progress bar 32 | ) 33 | 34 | # Create a batch of requests 35 | from fastllm import RequestBatch 36 | 37 | with RequestBatch() as batch: 38 | # Add multiple requests to the batch 39 | batch.chat.completions.create( 40 | model="gpt-3.5-turbo", 41 | messages=[{"role": "user", "content": "Hello!"}] 42 | ) 43 | batch.chat.completions.create( 44 | model="gpt-3.5-turbo", 45 | messages=[{"role": "user", "content": "How are you?"}] 46 | ) 47 | 48 | # Process the batch 49 | responses = manager.process_batch(batch) 50 | 51 | # Work with responses 52 | for response in responses: 53 | print(response.content) 54 | ``` 55 | 56 | ## Adding Caching 57 | 58 | Enable caching to improve performance and reduce API calls: 59 | 60 | ```python 61 | from fastllm.cache import DiskCache 62 | 63 | # Initialize disk cache 64 | cache = DiskCache( 65 | directory="cache", # Cache directory 66 | ttl=3600 # Cache TTL in seconds 67 | ) 68 | 69 | # Create request manager with cache 70 | manager = RequestManager( 71 | provider=provider, 72 | concurrency=10, 73 | caching_provider=cache, 74 | show_progress=True 75 | ) 76 | ``` 77 | 78 | ## Processing Large Batches 79 | 80 | For large batches of requests: 81 | 82 | ```python 83 | # Create many requests 84 | with RequestBatch() as batch: 85 | for i in range(1000): 86 | batch.chat.completions.create( 87 | model="gpt-3.5-turbo", 88 | messages=[ 89 | {"role": "user", "content": f"Process item {i}"} 90 | ] 91 | ) 92 | 93 | # Process with automatic chunking and progress tracking 94 | responses = manager.process_batch(batch) 95 | ``` 96 | 97 | ## Custom Configuration 98 | 99 | ### Timeout and Retry Settings 100 | 101 | ```python 102 | manager = RequestManager( 103 | provider=provider, 104 | timeout=30.0, # Request timeout in seconds 105 | retry_attempts=3, # Number of retry attempts 106 | retry_delay=1.0, # Delay between retries 107 | ) 108 | ``` 109 | 110 | ### Provider Configuration 111 | 112 | ```python 113 | provider = OpenAIProvider( 114 | api_key="your-api-key", 115 | api_base="https://api.openai.com/v1", # Custom API endpoint 116 | headers={ 117 | "Custom-Header": "value" 118 | } 119 | ) 120 | ``` 121 | 122 | ## Progress Tracking 123 | 124 | FastLLM provides detailed progress information: 125 | 126 | ```python 127 | manager = RequestManager( 128 | provider=provider, 129 | show_progress=True # Enable progress bar 130 | ) 131 | 132 | # Progress will show: 133 | # - Completion percentage 134 | # - Token usage rates 135 | # - Cache hit ratio 136 | # - Estimated time remaining 137 | ``` 138 | 139 | ## Error Handling 140 | 141 | ```python 142 | try: 143 | responses = manager.process_batch(batch) 144 | except Exception as e: 145 | print(f"Error processing batch: {e}") 146 | ``` 147 | 148 | ## Advanced Usage 149 | 150 | ### Custom Message Formatting 151 | 152 | ```python 153 | from fastllm import Message 154 | 155 | # Create structured messages 156 | messages = [ 157 | 
Message(role="system", content="You are a helpful assistant"), 158 | Message(role="user", content="Hello!") 159 | ] 160 | 161 | with RequestBatch() as batch: 162 | batch.chat.completions.create( 163 | model="gpt-3.5-turbo", 164 | messages=messages 165 | ) 166 | ``` 167 | 168 | ### Working with Response Data 169 | 170 | ```python 171 | for response in responses: 172 | # Access response content 173 | print(response.content) 174 | 175 | # Access usage statistics 176 | if response.usage: 177 | print(f"Prompt tokens: {response.usage.prompt_tokens}") 178 | print(f"Completion tokens: {response.usage.completion_tokens}") 179 | print(f"Total tokens: {response.usage.total_tokens}") 180 | ``` 181 | 182 | ## Best Practices 183 | 184 | 1. **Batch Processing** 185 | - Group related requests together 186 | - Use appropriate concurrency limits 187 | - Enable progress tracking for large batches 188 | 189 | 2. **Caching** 190 | - Use disk cache for persistence 191 | - Set appropriate TTL values 192 | - Monitor cache hit ratios 193 | 194 | 3. **Error Handling** 195 | - Implement proper error handling 196 | - Use retry mechanisms 197 | - Monitor token usage 198 | 199 | 4. **Resource Management** 200 | - Close cache providers when done 201 | - Monitor memory usage 202 | - Use appropriate chunk sizes 203 | 204 | ## Next Steps 205 | 206 | - Read the [Architecture](architecture.md) documentation 207 | - Explore [Core Components](core-components.md) 208 | - Check the [API Reference](api-reference.md) -------------------------------------------------------------------------------- /docs/overview.md: -------------------------------------------------------------------------------- 1 | # FastLLM Overview 2 | 3 | FastLLM is a powerful Python library that enables efficient parallel processing of Large Language Model (LLM) API requests. It provides a robust solution for applications that need to make multiple LLM API calls simultaneously while managing resources effectively. 4 | 5 | ## Key Features 6 | 7 | ### 1. Parallel Request Processing 8 | - Efficient handling of concurrent API requests 9 | - Configurable concurrency limits 10 | - Built-in request batching and chunking 11 | - Progress tracking with detailed statistics 12 | 13 | ### 2. Intelligent Caching 14 | - Multiple cache provider options (In-memory and Disk-based) 15 | - Consistent request hashing for cache keys 16 | - Configurable TTL for cached responses 17 | - Automatic cache management 18 | 19 | ### 3. Multiple Provider Support 20 | - Modular provider architecture 21 | - Built-in OpenAI provider implementation 22 | - Extensible base classes for adding new providers 23 | - Provider-agnostic request/response models 24 | 25 | ### 4. Resource Management 26 | - Automatic rate limiting 27 | - Configurable timeout handling 28 | - Retry mechanism with exponential backoff 29 | - Connection pooling 30 | 31 | ### 5. Progress Monitoring 32 | - Real-time progress tracking 33 | - Token usage statistics 34 | - Cache hit ratio monitoring 35 | - Performance metrics (tokens/second) 36 | 37 | ### 6. Developer-Friendly Interface 38 | - OpenAI-like API design 39 | - Type hints throughout the codebase 40 | - Comprehensive error handling 41 | - Async/await support 42 | 43 | ## Use Cases 44 | 45 | FastLLM is particularly useful for: 46 | 47 | 1. **Batch Processing**: Processing large numbers of LLM requests efficiently 48 | - Data analysis pipelines 49 | - Content generation at scale 50 | - Bulk text processing 51 | 52 | 2. 
**High-Performance Applications**: Applications requiring optimal LLM API usage 53 | - Real-time applications 54 | - Interactive systems 55 | - API proxies and middleware 56 | 57 | 3. **Resource-Conscious Systems**: Systems that need to optimize API usage 58 | - Cost optimization through caching 59 | - Rate limit management 60 | - Token usage optimization 61 | 62 | 4. **Development and Testing**: Tools for LLM application development 63 | - Rapid prototyping 64 | - Testing and benchmarking 65 | - Performance optimization 66 | 67 | ## Performance Considerations 68 | 69 | FastLLM is designed with performance in mind: 70 | 71 | - **Parallel Processing**: Efficiently processes multiple requests concurrently 72 | - **Smart Batching**: Automatically chunks requests for optimal throughput 73 | - **Cache Optimization**: Reduces API calls through intelligent caching 74 | - **Resource Management**: Prevents overload through concurrency control 75 | - **Memory Efficiency**: Manages memory usage through chunked processing -------------------------------------------------------------------------------- /docs/rate-limiting.md: -------------------------------------------------------------------------------- 1 | # Rate Limiting Design 2 | 3 | ## Overview 4 | 5 | This document outlines the design for implementing a flexible rate limiting system for FastLLM providers. The system will support both request-based and token-based rate limits with human-readable time units. 6 | 7 | ## Requirements 8 | 9 | 1. Support multiple rate limits per provider 10 | 2. Handle both request-per-time-unit and tokens-per-time-unit limits 11 | 3. Support human-readable time unit specifications (e.g., "10s", "1m", "1h", "1d") 12 | 4. Track token usage from provider responses 13 | 5. Integrate with existing Provider and RequestManager architecture 14 | 15 | ## System Design 16 | 17 | ### Rate Limit Configuration 18 | 19 | Rate limits will be specified in the provider configuration using a list structure: 20 | 21 | ```python 22 | rate_limits = [ 23 | { 24 | "type": "requests", 25 | "limit": 100, 26 | "period": "1m" 27 | }, 28 | { 29 | "type": "tokens", 30 | "limit": 100000, 31 | "period": "1h" 32 | } 33 | ] 34 | ``` 35 | 36 | ### Components 37 | 38 | #### 1. RateLimitManager 39 | 40 | A new component responsible for tracking and enforcing rate limits: 41 | 42 | ```python 43 | class RateLimitManager: 44 | def __init__(self, limits: List[RateLimit]): 45 | self.limits = limits 46 | self.windows = {} # Tracks usage windows 47 | 48 | async def check_limits(self) -> bool: 49 | # Check all limits 50 | 51 | async def record_usage(self, usage: Usage): 52 | # Record request and token usage 53 | ``` 54 | 55 | #### 2. RateLimit Models 56 | 57 | ```python 58 | @dataclass 59 | class RateLimit: 60 | type: Literal["requests", "tokens"] 61 | limit: int 62 | period: str # Human-readable period 63 | 64 | @dataclass 65 | class Usage: 66 | request_count: int 67 | token_count: Optional[int] 68 | timestamp: datetime 69 | ``` 70 | 71 | #### 3. TimeWindow 72 | 73 | Handles the sliding window implementation for rate limiting: 74 | 75 | ```python 76 | class TimeWindow: 77 | def __init__(self, period: str): 78 | self.period = self._parse_period(period) 79 | self.usage_records = [] 80 | 81 | def add_usage(self, usage: Usage): 82 | # Add usage and cleanup old records 83 | 84 | def current_usage(self) -> int: 85 | # Calculate current usage in window 86 | ``` 87 | 88 | ### Integration with Existing Architecture 89 | 90 | #### 1. 
Provider Base Class Enhancement 91 | 92 | ```python 93 | class Provider: 94 | def __init__(self, rate_limits: List[Dict]): 95 | self.rate_limit_manager = RateLimitManager( 96 | [RateLimit(**limit) for limit in rate_limits] 97 | ) 98 | 99 | async def _check_rate_limits(self): 100 | # Check before making requests 101 | 102 | async def _record_usage(self, response): 103 | # Record after successful requests 104 | ``` 105 | 106 | #### 2. RequestManager Integration 107 | 108 | The RequestManager will be enhanced to: 109 | - Check rate limits before processing requests 110 | - Handle rate limit errors gracefully 111 | - Implement backoff strategies when limits are reached 112 | 113 | ### Time Unit Parsing 114 | 115 | Time unit parsing will support: 116 | - Seconds: "s", "sec", "second", "seconds" 117 | - Minutes: "m", "min", "minute", "minutes" 118 | - Hours: "h", "hr", "hour", "hours" 119 | - Days: "d", "day", "days" 120 | 121 | Example implementation: 122 | 123 | ```python 124 | def parse_time_unit(period: str) -> timedelta: 125 | match = re.match(r"(\d+)([smhd])", period) 126 | if not match: 127 | raise ValueError("Invalid time unit format") 128 | 129 | value, unit = match.groups() 130 | value = int(value) 131 | 132 | return { 133 | 's': timedelta(seconds=value), 134 | 'm': timedelta(minutes=value), 135 | 'h': timedelta(hours=value), 136 | 'd': timedelta(days=value) 137 | }[unit] 138 | ``` 139 | 140 | ## Usage Example 141 | 142 | ```python 143 | # Provider configuration 144 | openai_provider = OpenAIProvider( 145 | api_key="...", 146 | rate_limits=[ 147 | { 148 | "type": "requests", 149 | "limit": 100, 150 | "period": "1m" 151 | }, 152 | { 153 | "type": "tokens", 154 | "limit": 100000, 155 | "period": "1h" 156 | } 157 | ] 158 | ) 159 | 160 | # Usage in code 161 | async with openai_provider as provider: 162 | response = await provider.complete(prompt) 163 | # Rate limits are automatically checked and updated 164 | ``` 165 | 166 | ## Error Handling 167 | 168 | 1. Rate Limit Exceeded 169 | ```python 170 | class RateLimitExceeded(Exception): 171 | def __init__(self, limit_type: str, retry_after: float): 172 | self.limit_type = limit_type 173 | self.retry_after = retry_after 174 | ``` 175 | 176 | 2. Recovery Strategy 177 | - Implement exponential backoff 178 | - Queue requests when limits are reached 179 | - Provide retry-after information to clients 180 | 181 | ## Implementation Phases 182 | 183 | 1. Phase 1: Basic Implementation 184 | - Implement RateLimitManager 185 | - Add time unit parsing 186 | - Integrate with Provider base class 187 | 188 | 2. Phase 2: Enhanced Features 189 | - Add token tracking 190 | - Implement sliding windows 191 | - Add multiple limit support 192 | 193 | 3. Phase 3: Optimization 194 | - Add request queuing 195 | - Implement smart backoff strategies 196 | - Add monitoring and metrics 197 | 198 | ## Monitoring and Metrics 199 | 200 | The rate limiting system will expose metrics for: 201 | - Current usage per window 202 | - Remaining quota 203 | - Rate limit hits 204 | - Average usage patterns 205 | 206 | These metrics can be integrated with the existing monitoring system. 207 | 208 | ## Future Considerations 209 | 210 | 1. Distributed Rate Limiting 211 | - Support for Redis-based rate limiting 212 | - Cluster-aware rate limiting 213 | 214 | 2. Dynamic Rate Limits 215 | - Allow providers to update limits based on API responses 216 | - Support for dynamic quota adjustments 217 | 218 | 3. 
Rate Limit Optimization 219 | - Predictive rate limiting 220 | - Smart request scheduling -------------------------------------------------------------------------------- /examples/embedding_example.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Test script for embeddings API with parallel processing and caching.""" 3 | 4 | import json 5 | import os 6 | from pathlib import Path 7 | from typing import Any, List, Optional, Union, Literal 8 | import asyncio 9 | import numpy as np 10 | 11 | import typer 12 | from rich.console import Console 13 | from rich.panel import Panel 14 | from rich.table import Table 15 | from dotenv import load_dotenv 16 | 17 | from fastllm import ( 18 | RequestBatch, RequestManager, 19 | ResponseWrapper, InMemoryCache, DiskCache, 20 | OpenAIProvider 21 | ) 22 | 23 | # Default values for command options 24 | DEFAULT_MODEL = "text-embedding-3-small" 25 | DEFAULT_DIMENSIONS = 384 26 | DEFAULT_CONCURRENCY = 10 27 | DEFAULT_OUTPUT = "embedding_results.json" 28 | 29 | app = typer.Typer() 30 | 31 | load_dotenv() 32 | 33 | # Sample texts for embedding 34 | DEFAULT_TEXTS = [ 35 | f"The quick brown fox jumps over the lazy dog {i}" for i in range(100) 36 | ] 37 | 38 | 39 | def cosine_similarity(a, b): 40 | """Calculate cosine similarity between two vectors.""" 41 | return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)) 42 | 43 | 44 | def normalize_l2(x): 45 | """L2 normalize a vector or batch of vectors.""" 46 | x = np.array(x) 47 | if x.ndim == 1: 48 | norm = np.linalg.norm(x) 49 | if norm == 0: 50 | return x 51 | return x / norm 52 | else: 53 | norm = np.linalg.norm(x, 2, axis=1, keepdims=True) 54 | return np.where(norm == 0, x, x / norm) 55 | 56 | 57 | def process_response( 58 | response: ResponseWrapper, index: int 59 | ) -> dict[str, Any]: 60 | """Process an embedding response into a serializable format.""" 61 | if not response.is_embedding_response: 62 | return { 63 | "index": index, 64 | "type": "error", 65 | "error": "Not an embedding response" 66 | } 67 | 68 | return { 69 | "index": index, 70 | "type": "success", 71 | "request_id": response.request_id, 72 | "model": response.response.get("model"), 73 | "usage": response.response.get("usage"), 74 | "embedding_count": len(response.embeddings), 75 | "embedding_dimensions": [len(emb) for emb in response.embeddings], 76 | } 77 | 78 | 79 | def create_embedding_batch( 80 | texts: List[str], 81 | model: str, 82 | dimensions: Optional[int] = None, 83 | encoding_format: Optional[str] = None 84 | ) -> RequestBatch: 85 | """Create a batch of embedding requests.""" 86 | batch = RequestBatch() 87 | 88 | # Add individual text embedding requests 89 | for text in texts: 90 | batch.embeddings.create( 91 | model=model, 92 | input=text, 93 | dimensions=dimensions, 94 | encoding_format=encoding_format, 95 | ) 96 | 97 | # Add a batch request with multiple texts (last 3 texts) 98 | if len(texts) >= 3: 99 | batch.embeddings.create( 100 | model=model, 101 | input=texts[-3:], 102 | dimensions=dimensions, 103 | encoding_format=encoding_format, 104 | ) 105 | 106 | return batch 107 | 108 | 109 | def run_test( 110 | *, # Force keyword arguments 111 | api_key: str, 112 | model: str, 113 | texts: List[str], 114 | dimensions: Optional[int], 115 | encoding_format: Optional[str], 116 | concurrency: int, 117 | output: Path | str, 118 | no_progress: bool = False, 119 | cache_type: Literal["memory", "disk"] = "memory", 120 | cache_ttl: Optional[int] = None, 121 | ) -> None: 122 | 
"""Run the embedding test with given parameters.""" 123 | console = Console() 124 | 125 | # Create batch of requests 126 | batch = create_embedding_batch(texts, model, dimensions, encoding_format) 127 | 128 | # Show configuration 129 | console.print( 130 | Panel.fit( 131 | "\n".join( 132 | [ 133 | f"Model: {model}", 134 | f"Dimensions: {dimensions or 'default'}", 135 | f"Format: {encoding_format or 'default'}", 136 | f"Texts: {len(texts)}", 137 | f"Requests: {len(batch)}", 138 | f"Concurrency: {concurrency}", 139 | ] 140 | ), 141 | title="[bold blue]Embedding Test Configuration", 142 | ) 143 | ) 144 | 145 | # Create cache provider based on type 146 | if cache_type == "memory": 147 | cache_provider = InMemoryCache() 148 | else: 149 | cache_provider = DiskCache( 150 | directory="./cache", 151 | ttl=cache_ttl, 152 | size_limit=int(2e9), # 2GB size limit 153 | ) 154 | 155 | provider = OpenAIProvider( 156 | api_key=api_key, 157 | ) 158 | manager = RequestManager( 159 | provider=provider, 160 | concurrency=concurrency, 161 | show_progress=not no_progress, 162 | caching_provider=cache_provider, 163 | ) 164 | 165 | try: 166 | # First run: Process batch 167 | console.print("[bold blue]Starting first run (no cache)...[/bold blue]") 168 | responses_first = manager.process_batch(batch) 169 | 170 | # Extract and process results 171 | all_embeddings = [] 172 | results_data_first = [] 173 | 174 | for i, response in enumerate(responses_first): 175 | result = process_response(response, i) 176 | results_data_first.append(result) 177 | 178 | if response.is_embedding_response: 179 | embeddings = response.embeddings 180 | for embedding in embeddings: 181 | all_embeddings.append(embedding) 182 | 183 | # Summary table for first run 184 | console.print("\n[bold green]First Run Results:[/bold green]") 185 | table = Table(show_header=True) 186 | table.add_column("Request #") 187 | table.add_column("Model") 188 | table.add_column("Embeddings") 189 | table.add_column("Dimensions") 190 | table.add_column("Tokens") 191 | 192 | for i, result in enumerate(results_data_first): 193 | if result["type"] == "success": 194 | dims = ", ".join([str(d) for d in result["embedding_dimensions"]]) 195 | table.add_row( 196 | f"{i+1}", 197 | result["model"], 198 | str(result["embedding_count"]), 199 | dims, 200 | str(result["usage"]["prompt_tokens"]) 201 | ) 202 | 203 | console.print(table) 204 | 205 | # Compare some embeddings 206 | if len(all_embeddings) >= 2: 207 | console.print("\n[bold blue]Embedding Similarities:[/bold blue]") 208 | similarity_table = Table(show_header=True) 209 | similarity_table.add_column("Embedding Pair") 210 | similarity_table.add_column("Cosine Similarity") 211 | 212 | # Compare a few pairs (not all combinations to keep output manageable) 213 | num_comparisons = min(5, len(all_embeddings) * (len(all_embeddings) - 1) // 2) 214 | compared = set() 215 | comparison_count = 0 216 | 217 | for i in range(len(all_embeddings)): 218 | for j in range(i+1, len(all_embeddings)): 219 | if comparison_count >= num_comparisons: 220 | break 221 | if (i, j) not in compared: 222 | sim = cosine_similarity(all_embeddings[i], all_embeddings[j]) 223 | similarity_table.add_row(f"{i+1} and {j+1}", f"{sim:.4f}") 224 | compared.add((i, j)) 225 | comparison_count += 1 226 | 227 | console.print(similarity_table) 228 | 229 | # Second run: Process the same batch, expecting cached results 230 | console.print("\n[bold blue]Starting second run (cached)...[/bold blue]") 231 | responses_second = manager.process_batch(batch) 232 | 
results_data_second = [] 233 | 234 | for i, response in enumerate(responses_second): 235 | result = process_response(response, i) 236 | results_data_second.append(result) 237 | 238 | # Compare cache performance 239 | console.print( 240 | Panel.fit( 241 | "\n".join( 242 | [ 243 | f"First Run - Successful: [green]{sum(1 for r in results_data_first if r['type'] == 'success')}[/green]", 244 | f"Second Run - Successful: [green]{sum(1 for r in results_data_second if r['type'] == 'success')}[/green]", 245 | f"Total Requests: {len(batch)}", 246 | ] 247 | ), 248 | title="[bold green]Cache Performance", 249 | ) 250 | ) 251 | 252 | # Save results from both runs 253 | if output != "NO_OUTPUT": 254 | output_path = Path(output) 255 | output_path.write_text( 256 | json.dumps( 257 | { 258 | "config": { 259 | "model": model, 260 | "dimensions": dimensions, 261 | "encoding_format": encoding_format, 262 | "concurrency": concurrency, 263 | "cache_type": cache_type, 264 | "cache_ttl": cache_ttl, 265 | }, 266 | "first_run_results": results_data_first, 267 | "second_run_results": results_data_second, 268 | }, 269 | indent=2, 270 | ) 271 | ) 272 | console.print(f"\nResults saved to [bold]{output_path}[/bold]") 273 | 274 | finally: 275 | # Clean up disk cache if used 276 | if cache_type == "disk": 277 | # Run close in asyncio event loop 278 | asyncio.run(cache_provider.close()) 279 | 280 | 281 | @app.command() 282 | def main( 283 | model: str = typer.Option( 284 | DEFAULT_MODEL, 285 | "--model", 286 | "-m", 287 | help="Embedding model to use", 288 | ), 289 | dimensions: Optional[int] = typer.Option( 290 | DEFAULT_DIMENSIONS, 291 | "--dimensions", 292 | "-d", 293 | help="Dimensions for the embeddings (None uses model default)", 294 | ), 295 | encoding_format: Optional[str] = typer.Option( 296 | None, 297 | "--format", 298 | "-f", 299 | help="Encoding format (float or base64)", 300 | ), 301 | concurrency: int = typer.Option( 302 | DEFAULT_CONCURRENCY, 303 | "--concurrency", 304 | "-c", 305 | help="Concurrent requests", 306 | ), 307 | output: str = typer.Option( 308 | DEFAULT_OUTPUT, 309 | "--output", 310 | "-o", 311 | help="Output file", 312 | ), 313 | no_progress: bool = typer.Option( 314 | False, 315 | "--no-progress", 316 | help="Disable progress tracking", 317 | ), 318 | cache_type: str = typer.Option( 319 | "memory", 320 | "--cache-type", 321 | help="Cache type to use (memory or disk)", 322 | ), 323 | cache_ttl: Optional[int] = typer.Option( 324 | None, 325 | "--cache-ttl", 326 | help="Time to live in seconds for cached items (disk cache only)", 327 | ), 328 | input_file: Optional[str] = typer.Option( 329 | None, 330 | "--input-file", 331 | "-i", 332 | help="File with texts to embed (one per line)", 333 | ), 334 | ) -> None: 335 | """Run embeddings test with caching.""" 336 | # Get API key from environment variable 337 | api_key = os.environ.get("OPENAI_API_KEY") 338 | if not api_key: 339 | raise ValueError("Please set the OPENAI_API_KEY environment variable") 340 | 341 | # Load texts from file if provided, otherwise use default texts 342 | texts = DEFAULT_TEXTS 343 | if input_file: 344 | try: 345 | with open(input_file, 'r') as f: 346 | texts = [line.strip() for line in f if line.strip()] 347 | except Exception as e: 348 | typer.echo(f"Error reading input file: {e}") 349 | raise typer.Exit(code=1) 350 | 351 | run_test( 352 | api_key=api_key, 353 | model=model, 354 | texts=texts, 355 | dimensions=dimensions, 356 | encoding_format=encoding_format, 357 | concurrency=concurrency, 358 | output=output, 359 | 
no_progress=no_progress, 360 | cache_type=cache_type, 361 | cache_ttl=cache_ttl, 362 | ) 363 | 364 | 365 | if __name__ == "__main__": 366 | app() -------------------------------------------------------------------------------- /examples/notebook_test.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from fastllm.core import RequestBatch, RequestManager\n", 10 | "from fastllm.providers.openai import OpenAIProvider\n", 11 | "from fastllm.cache import DiskCache\n", 12 | "\n", 13 | "%load_ext rich" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 2, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "manager = RequestManager(\n", 23 | " provider=OpenAIProvider(\n", 24 | " api_key=\"..\",\n", 25 | " api_base=\"https://openrouter.ai/api/v1\",\n", 26 | " ),\n", 27 | " caching_provider=DiskCache(directory=\"cache\"),\n", 28 | " concurrency=10,\n", 29 | " show_progress=True,\n", 30 | ")" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 3, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "ids = []\n", 40 | "with RequestBatch() as batch:\n", 41 | " for i in range(100):\n", 42 | " ids.append(\n", 43 | " batch.chat.completions.create(\n", 44 | " model=\"meta-llama/llama-3.2-3b-instruct\",\n", 45 | " messages=[\n", 46 | " {\n", 47 | " \"role\": \"user\",\n", 48 | " \"content\": f\"Print only number, number is {i}. Do not include any other text.\",\n", 49 | " }\n", 50 | " ],\n", 51 | " max_completion_tokens=100,\n", 52 | " temperature=0.5,\n", 53 | " )\n", 54 | " )" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 4, 60 | "metadata": {}, 61 | "outputs": [ 62 | { 63 | "data": { 64 | "application/vnd.jupyter.widget-view+json": { 65 | "model_id": "7b2a3a604b714056be4a38ad673ce7e0", 66 | "version_major": 2, 67 | "version_minor": 0 68 | }, 69 | "text/plain": [ 70 | "Output()" 71 | ] 72 | }, 73 | "metadata": {}, 74 | "output_type": "display_data" 75 | }, 76 | { 77 | "data": { 78 | "text/html": [ 79 | "
\n" 80 | ], 81 | "text/plain": [] 82 | }, 83 | "metadata": {}, 84 | "output_type": "display_data" 85 | } 86 | ], 87 | "source": [ 88 | "responses = manager.process_batch(batch)" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 8, 94 | "metadata": {}, 95 | "outputs": [ 96 | { 97 | "data": { 98 | "text/html": [ 99 | "\n" 100 | ], 101 | "text/plain": [] 102 | }, 103 | "metadata": {}, 104 | "output_type": "display_data" 105 | }, 106 | { 107 | "data": { 108 | "text/plain": [ 109 | "\u001b[32m'b3cf61e94e5ae858'\u001b[0m" 110 | ] 111 | }, 112 | "execution_count": 8, 113 | "metadata": {}, 114 | "output_type": "execute_result" 115 | } 116 | ], 117 | "source": [ 118 | "ids[-1]" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": 7, 124 | "metadata": {}, 125 | "outputs": [ 126 | { 127 | "data": { 128 | "text/html": [ 129 | "\n" 130 | ], 131 | "text/plain": [] 132 | }, 133 | "metadata": {}, 134 | "output_type": "display_data" 135 | }, 136 | { 137 | "data": { 138 | "text/plain": [ 139 | "\u001b[32m'b3cf61e94e5ae858'\u001b[0m" 140 | ] 141 | }, 142 | "execution_count": 7, 143 | "metadata": {}, 144 | "output_type": "execute_result" 145 | } 146 | ], 147 | "source": [ 148 | "responses[-1].request_id" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": null, 154 | "metadata": {}, 155 | "outputs": [], 156 | "source": [] 157 | } 158 | ], 159 | "metadata": { 160 | "kernelspec": { 161 | "display_name": ".venv", 162 | "language": "python", 163 | "name": "python3" 164 | }, 165 | "language_info": { 166 | "codemirror_mode": { 167 | "name": "ipython", 168 | "version": 3 169 | }, 170 | "file_extension": ".py", 171 | "mimetype": "text/x-python", 172 | "name": "python", 173 | "nbconvert_exporter": "python", 174 | "pygments_lexer": "ipython3", 175 | "version": "3.12.7" 176 | } 177 | }, 178 | "nbformat": 4, 179 | "nbformat_minor": 2 180 | } 181 | -------------------------------------------------------------------------------- /examples/parallel_test.py: -------------------------------------------------------------------------------- 1 | """Test script for parallel request handling.""" 2 | 3 | import json 4 | import os 5 | from pathlib import Path 6 | from typing import Any, Optional, Union, Literal 7 | import asyncio 8 | 9 | import typer 10 | from openai.types.chat import ChatCompletion 11 | from rich.console import Console 12 | from rich.panel import Panel 13 | from dotenv import load_dotenv 14 | 15 | from fastllm.core import RequestBatch, RequestManager, ResponseWrapper 16 | from fastllm.providers.openai import OpenAIProvider 17 | from fastllm.cache import InMemoryCache, DiskCache 18 | 19 | # Default values for command options 20 | DEFAULT_REPEATS = 10 21 | DEFAULT_CONCURRENCY = 50 22 | DEFAULT_TEMPERATURE = 0.7 23 | DEFAULT_OUTPUT = "results.json" 24 | 25 | app = typer.Typer() 26 | 27 | load_dotenv() 28 | 29 | 30 | def process_response( 31 | response: ResponseWrapper[ChatCompletion], index: int 32 | ) -> dict[str, Any]: 33 | """Process a response into a serializable format.""" 34 | return { 35 | "index": index, 36 | "type": "success", 37 | "request_id": response.request_id, 38 | "raw_response": response.response, 39 | } 40 | 41 | 42 | def run_test( 43 | *, # Force keyword arguments 44 | api_key: str, 45 | model: str, 46 | repeats: int, 47 | concurrency: int, 48 | output: Path | str, 49 | temperature: float, 50 | max_tokens: Optional[int], 51 | no_progress: bool = False, 52 | cache_type: Literal["memory", "disk"] = "memory", 53 | cache_ttl: 
Optional[int] = None, 54 | ) -> None: 55 | """Run the test with given parameters.""" 56 | console = Console() 57 | 58 | # Create batch of requests using OpenAI-style API 59 | with RequestBatch() as batch: 60 | # Add single prompt requests 61 | for i in range(repeats): 62 | batch.chat.completions.create( 63 | model=model, 64 | messages=[ 65 | { 66 | "role": "user", 67 | "content": f"Print only number, number is {i}. Do not include any other text.", 68 | } 69 | ], 70 | temperature=temperature, 71 | max_completion_tokens=max_tokens, 72 | ) 73 | 74 | # Show configuration 75 | console.print( 76 | Panel.fit( 77 | "\n".join( 78 | [ 79 | f"Model: {model}", 80 | f"Temperature: {temperature}", 81 | f"Max Tokens: {max_tokens or 'default'}", 82 | f"Requests: {len(batch)}", 83 | f"Concurrency: {concurrency}", 84 | ] 85 | ), 86 | title="[bold blue]Test Configuration", 87 | ) 88 | ) 89 | 90 | # Create cache provider based on type 91 | if cache_type == "memory": 92 | cache_provider = InMemoryCache() 93 | else: 94 | cache_provider = DiskCache( 95 | directory="./cache", 96 | ttl=cache_ttl, 97 | size_limit=int(2e9), # 2GB size limit 98 | ) 99 | 100 | provider = OpenAIProvider( 101 | api_key=api_key, 102 | api_base="https://llm.buffedby.ai/v1", 103 | ) 104 | manager = RequestManager( 105 | provider=provider, 106 | concurrency=concurrency, 107 | show_progress=not no_progress, 108 | caching_provider=cache_provider, 109 | ) 110 | 111 | try: 112 | # First run: Process batch 113 | responses_first = manager.process_batch(batch) 114 | successful_first = 0 115 | results_data_first = [] 116 | for i, response in enumerate(responses_first): 117 | result = process_response(response, i) 118 | successful_first += 1 119 | results_data_first.append(result) 120 | if i+3 > len(responses_first): 121 | console.print(f"Response #{i+1} of {len(responses_first)}") 122 | console.print(result) 123 | 124 | console.print( 125 | Panel.fit( 126 | "\n".join( 127 | [ 128 | f"First Run - Successful: [green]{successful_first}[/green]", 129 | f"Total: {len(responses_first)} (matches {len(batch)} requests)", 130 | ] 131 | ), 132 | title="[bold green]Results - First Run", 133 | ) 134 | ) 135 | 136 | # Second run: Process the same batch, expecting cached results 137 | responses_second = manager.process_batch(batch) 138 | successful_second = 0 139 | results_data_second = [] 140 | for i, response in enumerate(responses_second): 141 | result = process_response(response, i) 142 | successful_second += 1 143 | results_data_second.append(result) 144 | if i+2 > len(responses_second): 145 | console.print(f"Response #{i+1} of {len(responses_second)}") 146 | console.print(result) 147 | console.print( 148 | Panel.fit( 149 | "\n".join( 150 | [ 151 | f"Second Run - Successful: [green]{successful_second}[/green]", 152 | f"Total: {len(responses_second)} (matches {len(batch)} requests)", 153 | ] 154 | ), 155 | title="[bold green]Results - Second Run (Cached)", 156 | ) 157 | ) 158 | 159 | # Save results from both runs 160 | if output != "NO_OUTPUT": 161 | output = Path(output) 162 | output.write_text( 163 | json.dumps( 164 | { 165 | "config": { 166 | "model": model, 167 | "temperature": temperature, 168 | "max_tokens": max_tokens, 169 | "repeats": repeats, 170 | "concurrency": concurrency, 171 | "cache_type": cache_type, 172 | "cache_ttl": cache_ttl, 173 | }, 174 | "first_run_results": results_data_first, 175 | "second_run_results": results_data_second, 176 | "first_run_summary": { 177 | "successful": successful_first, 178 | "total": len(responses_first), 179 | 
}, 180 | "second_run_summary": { 181 | "successful": successful_second, 182 | "total": len(responses_second), 183 | }, 184 | }, 185 | indent=2, 186 | ) 187 | ) 188 | finally: 189 | # Clean up disk cache if used 190 | if cache_type == "disk": 191 | # Run close in asyncio event loop 192 | asyncio.run(cache_provider.close()) 193 | 194 | 195 | @app.command() 196 | def main( 197 | model: str = typer.Option( 198 | "meta-llama/llama-3.2-3b-instruct", 199 | "--model", 200 | "-m", 201 | help="Model to use", 202 | ), 203 | repeats: int = typer.Option( 204 | DEFAULT_REPEATS, 205 | "--repeats", 206 | "-n", 207 | help="Number of repeats", 208 | ), 209 | concurrency: int = typer.Option( 210 | DEFAULT_CONCURRENCY, 211 | "--concurrency", 212 | "-c", 213 | help="Concurrent requests", 214 | ), 215 | output: str = typer.Option( 216 | DEFAULT_OUTPUT, 217 | "--output", 218 | "-o", 219 | help="Output file", 220 | ), 221 | temperature: float = typer.Option( 222 | DEFAULT_TEMPERATURE, 223 | "--temperature", 224 | "-t", 225 | help="Temperature for generation", 226 | ), 227 | max_tokens: Optional[int] = typer.Option( 228 | None, 229 | "--max-tokens", 230 | help="Maximum tokens to generate", 231 | ), 232 | no_progress: bool = typer.Option( 233 | False, 234 | "--no-progress", 235 | help="Disable progress tracking", 236 | ), 237 | cache_type: str = typer.Option( 238 | "memory", 239 | "--cache-type", 240 | help="Cache type to use (memory or disk)", 241 | ), 242 | cache_ttl: Optional[int] = typer.Option( 243 | None, 244 | "--cache-ttl", 245 | help="Time to live in seconds for cached items (disk cache only)", 246 | ), 247 | ) -> None: 248 | """Run parallel request test.""" 249 | api_key = os.environ["BB_AI_API_KEY"] 250 | 251 | run_test( 252 | api_key=api_key, 253 | model=model, 254 | repeats=repeats, 255 | concurrency=concurrency, 256 | output=output, 257 | temperature=temperature, 258 | max_tokens=max_tokens, 259 | no_progress=no_progress, 260 | cache_type=cache_type, 261 | cache_ttl=cache_ttl, 262 | ) 263 | 264 | 265 | if __name__ == "__main__": 266 | app() 267 | -------------------------------------------------------------------------------- /fastllm/__init__.py: -------------------------------------------------------------------------------- 1 | """FastLLM - High-performance parallel LLM API request tool.""" 2 | 3 | __version__ = "0.1.0" 4 | 5 | from fastllm.core import ( 6 | RequestBatch, 7 | RequestManager, 8 | ResponseWrapper, 9 | TokenStats, 10 | ) 11 | from fastllm.cache import ( 12 | CacheProvider, 13 | InMemoryCache, 14 | DiskCache, 15 | compute_request_hash, 16 | ) 17 | from fastllm.providers.base import Provider 18 | from fastllm.providers.openai import OpenAIProvider 19 | 20 | __all__ = [ 21 | # Core components 22 | "RequestBatch", 23 | "RequestManager", 24 | "ResponseWrapper", 25 | "TokenStats", 26 | 27 | # Cache components 28 | "CacheProvider", 29 | "InMemoryCache", 30 | "DiskCache", 31 | "compute_request_hash", 32 | 33 | # Provider components 34 | "Provider", 35 | "OpenAIProvider", 36 | ] 37 | -------------------------------------------------------------------------------- /fastllm/cache.py: -------------------------------------------------------------------------------- 1 | import json 2 | import xxhash 3 | import asyncio 4 | import logging 5 | from typing import Any, Dict, Optional 6 | from diskcache import Cache 7 | 8 | # Configure logging 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | class CacheProvider: 13 | """Base class for cache providers with async interface.""" 14 | 15 | async def 
exists(self, key: str) -> bool: 16 | """Check if a key exists in the cache.""" 17 | raise NotImplementedError 18 | 19 | async def get(self, key: str): 20 | """Get a value from the cache.""" 21 | raise NotImplementedError 22 | 23 | async def put(self, key: str, value) -> None: 24 | """Put a value in the cache.""" 25 | raise NotImplementedError 26 | 27 | async def clear(self) -> None: 28 | """Clear all items from the cache.""" 29 | raise NotImplementedError 30 | 31 | async def close(self) -> None: 32 | """Close the cache when done.""" 33 | pass 34 | 35 | 36 | def compute_request_hash(request: dict) -> str: 37 | """Compute a hash for a request that can be used as a cache key. 38 | 39 | Args: 40 | request: The request dictionary to hash 41 | 42 | Returns: 43 | str: A hex string hash of the request 44 | 45 | Note: 46 | - None values and empty values are removed from the request before hashing 47 | - Internal fields (_request_id, _order_id) are removed 48 | - Default values are not added if not present 49 | """ 50 | # Create a copy of the request and remove any fields that are not part of the request content 51 | temp_request = request.copy() 52 | 53 | # Remove internal tracking fields that shouldn't affect caching 54 | temp_request.pop("_request_id", None) 55 | temp_request.pop("_order_id", None) 56 | 57 | # Extract known fields and extra params 58 | known_fields = { 59 | "provider", "model", "messages", "temperature", "max_completion_tokens", 60 | "top_p", "presence_penalty", "frequency_penalty", "stop", "stream", 61 | # Embedding specific fields 62 | "type", "input", "dimensions", "encoding_format", "user" 63 | } 64 | 65 | def clean_value(v): 66 | """Remove empty values and normalize None values.""" 67 | if v is None: 68 | return None 69 | if isinstance(v, (dict, list)): 70 | return clean_dict_or_list(v) 71 | if isinstance(v, str) and not v.strip(): 72 | return None 73 | return v 74 | 75 | def clean_dict_or_list(obj): 76 | """Recursively clean dictionaries and lists.""" 77 | if isinstance(obj, dict): 78 | cleaned = {k: clean_value(v) for k, v in obj.items()} 79 | return {k: v for k, v in cleaned.items() if v is not None} 80 | if isinstance(obj, list): 81 | cleaned = [clean_value(v) for v in obj] 82 | return [v for v in cleaned if v is not None] 83 | return obj 84 | 85 | # Clean and separate core parameters and extra parameters 86 | core_params = {k: clean_value(v) for k, v in temp_request.items() if k in known_fields} 87 | extra_params = {k: clean_value(v) for k, v in temp_request.items() if k not in known_fields} 88 | 89 | # Remove None values and empty values 90 | core_params = {k: v for k, v in core_params.items() if v is not None} 91 | extra_params = {k: v for k, v in extra_params.items() if v is not None} 92 | 93 | # Create a combined dictionary with sorted extra params 94 | hash_dict = { 95 | "core": core_params, 96 | "extra": dict(sorted(extra_params.items())) # Sort extra params for consistent hashing 97 | } 98 | 99 | # Serialize with sorted keys for a consistent representation 100 | request_str = json.dumps(hash_dict, sort_keys=True, ensure_ascii=False) 101 | return xxhash.xxh64(request_str.encode("utf-8")).hexdigest() 102 | 103 | 104 | class InMemoryCache(CacheProvider): 105 | """Simple in-memory cache implementation using a dictionary.""" 106 | 107 | def __init__(self): 108 | self._cache: Dict[str, Any] = {} 109 | 110 | async def exists(self, key: str) -> bool: 111 | """Check if a key exists in the cache.""" 112 | return key in self._cache 113 | 114 | async def get(self, key: str): 
115 | """Get a value from the cache.""" 116 | if not await self.exists(key): 117 | raise KeyError(f"Cache for key {key} does not exist") 118 | return self._cache[key] 119 | 120 | async def put(self, key: str, value) -> None: 121 | """Put a value in the cache.""" 122 | self._cache[key] = value 123 | 124 | async def clear(self) -> None: 125 | """Clear all items from the cache.""" 126 | self._cache.clear() 127 | 128 | 129 | class DiskCache(CacheProvider): 130 | """Disk-based cache implementation using diskcache with async support.""" 131 | 132 | def __init__(self, directory: str, ttl: Optional[int] = None, **cache_options): 133 | """Initialize disk cache. 134 | 135 | Args: 136 | directory: Directory where cache files will be stored 137 | ttl: Time to live in seconds for cached items (None means no expiration) 138 | **cache_options: Additional options to pass to diskcache.Cache 139 | 140 | Raises: 141 | OSError: If the directory is invalid or cannot be created 142 | """ 143 | try: 144 | self._cache = Cache(directory, **cache_options) 145 | self._ttl = ttl 146 | except Exception as e: 147 | # Convert any cache initialization error to OSError 148 | raise OSError(f"Failed to initialize disk cache: {str(e)}") 149 | 150 | async def _run_in_executor(self, func, *args): 151 | """Run a blocking cache operation in the default executor.""" 152 | loop = asyncio.get_event_loop() 153 | return await loop.run_in_executor(None, func, *args) 154 | 155 | async def exists(self, key: str) -> bool: 156 | """Check if a key exists in the cache.""" 157 | try: 158 | # Use the internal __contains__ method which is faster than get 159 | return await self._run_in_executor(self._cache.__contains__, key) 160 | except Exception as e: 161 | raise OSError(f"Failed to check cache key: {str(e)}") 162 | 163 | async def get(self, key: str): 164 | """Get a value from the cache.""" 165 | try: 166 | value = await self._run_in_executor(self._cache.get, key) 167 | if value is None: 168 | # Convert None value to KeyError for consistent behavior with InMemoryCache 169 | raise KeyError(f"Cache for key {key} does not exist") 170 | return value 171 | except Exception as e: 172 | if isinstance(e, KeyError): 173 | raise 174 | raise OSError(f"Failed to get cache value: {str(e)}") 175 | 176 | async def put(self, key: str, value) -> None: 177 | """Put a value in the cache with optional TTL.""" 178 | try: 179 | await self._run_in_executor(self._cache.set, key, value, self._ttl) 180 | except Exception as e: 181 | raise OSError(f"Failed to store cache value: {str(e)}") 182 | 183 | async def clear(self) -> None: 184 | """Clear all items from the cache.""" 185 | try: 186 | await self._run_in_executor(self._cache.clear) 187 | except Exception as e: 188 | raise OSError(f"Failed to clear cache: {str(e)}") 189 | 190 | async def close(self) -> None: 191 | """Close the cache when done.""" 192 | try: 193 | await self._run_in_executor(self._cache.close) 194 | except Exception as e: 195 | raise OSError(f"Failed to close cache: {str(e)}") -------------------------------------------------------------------------------- /fastllm/core.py: -------------------------------------------------------------------------------- 1 | """Core functionality for parallel LLM API requests.""" 2 | 3 | import asyncio 4 | import time 5 | import logging 6 | from contextlib import AbstractContextManager, nullcontext 7 | from dataclasses import dataclass 8 | from typing import ( 9 | Any, 10 | Generic, 11 | Optional, 12 | TypeVar, 13 | Union, 14 | ) 15 | 16 | import httpx 17 | from 
openai.types import CompletionUsage 18 | from openai.types.chat import ChatCompletion 19 | from openai.types.chat.chat_completion import Choice, CompletionUsage 20 | from openai.types.chat.chat_completion_message import ChatCompletionMessage 21 | from pydantic import BaseModel 22 | from rich.progress import ( 23 | BarColumn, 24 | Progress, 25 | SpinnerColumn, 26 | TaskProgressColumn, 27 | TextColumn, 28 | TimeElapsedColumn, 29 | TimeRemainingColumn, 30 | MofNCompleteColumn 31 | ) 32 | 33 | from fastllm.cache import compute_request_hash 34 | 35 | # Configure logging 36 | logger = logging.getLogger(__name__) 37 | 38 | 39 | # Define a type variable for provider-specific response types 40 | ResponseT = TypeVar("ResponseT", bound=Union[ChatCompletion, Any]) 41 | 42 | DUMMY_RESPONSE = ChatCompletion( 43 | id="dummy_id", 44 | choices=[ 45 | Choice( 46 | index=0, 47 | message=ChatCompletionMessage(content="dummy_content", role="assistant"), 48 | finish_reason="stop" 49 | ) 50 | ], 51 | created=0, 52 | model="dummy_model", 53 | object="chat.completion", 54 | service_tier="default", 55 | system_fingerprint="dummy_system_fingerprint", 56 | usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0) 57 | ) 58 | 59 | 60 | class ResponseWrapper(Generic[ResponseT]): 61 | """Wrapper for provider responses that includes request ID for sorting.""" 62 | 63 | def __init__(self, response: ResponseT, request_id: str, order_id: int): 64 | self.response = response 65 | self.request_id = request_id 66 | self._order_id = order_id 67 | 68 | @property 69 | def usage(self) -> Optional[CompletionUsage]: 70 | """Get usage statistics if available.""" 71 | if isinstance(self.response, ChatCompletion): 72 | return self.response.usage 73 | elif isinstance(self.response, dict) and 'usage' in self.response: 74 | # Handle dict responses (like embeddings) 75 | usage = self.response['usage'] 76 | # Convert to CompletionUsage if not already 77 | if not isinstance(usage, CompletionUsage): 78 | return CompletionUsage( 79 | prompt_tokens=usage.get('prompt_tokens', 0), 80 | completion_tokens=usage.get('completion_tokens', 0), 81 | total_tokens=usage.get('total_tokens', 0) 82 | ) 83 | return usage 84 | return None 85 | 86 | @property 87 | def is_embedding_response(self) -> bool: 88 | """Check if this is an embedding response.""" 89 | if isinstance(self.response, dict): 90 | return 'data' in self.response and all('embedding' in item for item in self.response.get('data', [])) 91 | return False 92 | 93 | @property 94 | def embeddings(self) -> list: 95 | """Get embeddings from response if available.""" 96 | if not self.is_embedding_response: 97 | return [] 98 | 99 | if isinstance(self.response, dict) and 'data' in self.response: 100 | return [item.get('embedding', []) for item in self.response.get('data', [])] 101 | return [] 102 | 103 | 104 | @dataclass 105 | class TokenStats: 106 | """Statistics about token usage.""" 107 | 108 | prompt_tokens: int = 0 109 | completion_tokens: int = 0 110 | total_tokens: int = 0 111 | requests_completed: int = 0 112 | cache_hits: int = 0 # Track cache hits 113 | start_time: float = 0.0 114 | token_limit: Optional[int] = None # Rate limit for tokens per minute 115 | request_limit: Optional[int] = None # Rate limit for requests per minute 116 | window_tokens: int = 0 # Tokens in current rate limit window 117 | window_requests: int = 0 # Requests in current rate limit window 118 | token_limit: Optional[int] = None # Rate limit for tokens per minute 119 | request_limit: Optional[int] = None # 
Rate limit for requests per minute 120 | window_tokens: int = 0 # Tokens in current rate limit window 121 | window_requests: int = 0 # Requests in current rate limit window 122 | 123 | @property 124 | def elapsed_time(self) -> float: 125 | return time.time() - self.start_time 126 | 127 | @property 128 | def prompt_tokens_per_second(self) -> float: 129 | if self.elapsed_time == 0: 130 | return 0.0 131 | return self.prompt_tokens / self.elapsed_time 132 | 133 | @property 134 | def completion_tokens_per_second(self) -> float: 135 | if self.elapsed_time == 0: 136 | return 0.0 137 | return self.completion_tokens / self.elapsed_time 138 | 139 | @property 140 | def cache_hit_ratio(self) -> float: 141 | if self.requests_completed == 0: 142 | return 0.0 143 | return self.cache_hits / self.requests_completed 144 | 145 | @property 146 | def token_saturation(self) -> float: 147 | """Calculate token usage saturation (0.0 to 1.0).""" 148 | if not self.token_limit or self.elapsed_time == 0: 149 | return 0.0 150 | tokens_per_minute = (self.window_tokens / self.elapsed_time) * 60 151 | return tokens_per_minute / self.token_limit 152 | 153 | @property 154 | def request_saturation(self) -> float: 155 | """Calculate request rate saturation (0.0 to 1.0).""" 156 | if not self.request_limit or self.elapsed_time == 0: 157 | return 0.0 158 | requests_per_minute = (self.window_requests / self.elapsed_time) * 60 159 | return requests_per_minute / self.request_limit 160 | 161 | @property 162 | def token_saturation(self) -> float: 163 | """Calculate token usage saturation (0.0 to 1.0).""" 164 | if not self.token_limit or self.elapsed_time == 0: 165 | return 0.0 166 | tokens_per_minute = (self.window_tokens / self.elapsed_time) * 60 167 | return tokens_per_minute / self.token_limit 168 | 169 | @property 170 | def request_saturation(self) -> float: 171 | """Calculate request rate saturation (0.0 to 1.0).""" 172 | if not self.request_limit or self.elapsed_time == 0: 173 | return 0.0 174 | requests_per_minute = (self.window_requests / self.elapsed_time) * 60 175 | return requests_per_minute / self.request_limit 176 | 177 | def update(self, prompt_tokens: int, completion_tokens: int, is_cache_hit: bool = False) -> None: 178 | """Update token statistics.""" 179 | self.prompt_tokens += prompt_tokens 180 | self.completion_tokens += completion_tokens 181 | self.total_tokens += prompt_tokens + completion_tokens 182 | self.requests_completed += 1 183 | if is_cache_hit: 184 | self.cache_hits += 1 185 | else: 186 | # Only update window stats for non-cache hits 187 | self.window_tokens += prompt_tokens + completion_tokens 188 | self.window_requests += 1 189 | 190 | 191 | class ProgressTracker: 192 | """Tracks progress and token usage for batch requests.""" 193 | 194 | def __init__(self, total_requests: int, show_progress: bool = True): 195 | self.stats = TokenStats(start_time=time.time()) 196 | self.total_requests = total_requests 197 | self.show_progress = show_progress 198 | 199 | # Create progress display 200 | self.progress = Progress( 201 | SpinnerColumn(), 202 | TextColumn("[progress.description]{task.description}"), 203 | BarColumn(), 204 | TaskProgressColumn(), 205 | MofNCompleteColumn(), 206 | TimeRemainingColumn(), 207 | TimeElapsedColumn(), 208 | TextColumn("[blue]{task.fields[stats]}"), 209 | TextColumn("[yellow]{task.fields[cache]}"), 210 | disable=not show_progress, 211 | ) 212 | 213 | # Add main progress task 214 | self.task_id = self.progress.add_task( 215 | description="Processing requests", 216 | 
total=total_requests, 217 | stats="Starting...", 218 | cache="", 219 | ) 220 | 221 | def __enter__(self): 222 | """Start progress display.""" 223 | self.progress.start() 224 | return self 225 | 226 | def __exit__(self, exc_type, exc_val, exc_tb): 227 | """Stop progress display.""" 228 | self.progress.stop() 229 | 230 | def update(self, prompt_tokens: int, completion_tokens: int, is_cache_hit: bool = False): 231 | """Update progress and token statistics.""" 232 | self.stats.update(prompt_tokens, completion_tokens, is_cache_hit) 233 | 234 | # Update progress display with token rates and cache stats 235 | stats_text = ( 236 | f"[green]⬆ {self.stats.prompt_tokens_per_second:.1f}[/green] " 237 | f"[red]⬇ {self.stats.completion_tokens_per_second:.1f}[/red] t/s" 238 | ) 239 | 240 | cache_text = ( 241 | f"Cache: [green]{self.stats.cache_hit_ratio*100:.1f}%[/green] hits, " 242 | f"[yellow]{(1-self.stats.cache_hit_ratio)*100:.1f}%[/yellow] new" 243 | ) 244 | 245 | self.progress.update( 246 | self.task_id, 247 | advance=1, 248 | stats=stats_text, 249 | cache=cache_text, 250 | ) 251 | 252 | 253 | class RequestManager: 254 | """Manages parallel LLM API requests.""" 255 | 256 | def __init__( 257 | self, 258 | provider: 'Provider[ResponseT]', 259 | concurrency: int = 100, 260 | timeout: float = 30.0, 261 | retry_attempts: int = 3, 262 | retry_delay: float = 1.0, 263 | show_progress: bool = True, 264 | caching_provider: Optional['CacheProvider'] = None, 265 | return_dummy_on_error: bool = False, 266 | dummy_response: Optional[ResponseT] = DUMMY_RESPONSE 267 | ): 268 | self.provider = provider 269 | self.concurrency = concurrency 270 | self.timeout = timeout 271 | self.retry_attempts = retry_attempts 272 | self.retry_delay = retry_delay 273 | self.show_progress = show_progress 274 | self.cache = caching_provider 275 | self.return_dummy_on_error = return_dummy_on_error 276 | self.dummy_response = dummy_response 277 | 278 | def _calculate_chunk_size(self) -> int: 279 | """Calculate optimal chunk size based on concurrency. 280 | 281 | The chunk size is calculated as 2 * concurrency to allow for some overlap 282 | and better resource utilization while still maintaining reasonable memory usage. 283 | This provides a balance between creating too many tasks at once and 284 | underutilizing the available concurrency. 285 | """ 286 | return min(self.concurrency * 10, 25000) # Cap at 25000 to prevent excessive memory usage 287 | 288 | def process_batch( 289 | self, 290 | batch: Union[list[dict[str, Any]], "RequestBatch"], 291 | ) -> list[ResponseT]: 292 | """Process a batch of LLM requests in parallel. 293 | 294 | This is the main synchronous API endpoint that users should call. 295 | Internally it uses asyncio to handle requests concurrently. 296 | Works in both regular Python environments and Jupyter notebooks. 
297 | 298 | Args: 299 | batch: Either a RequestBatch object or a list of request dictionaries 300 | 301 | Returns: 302 | List of responses in the same order as the requests 303 | """ 304 | try: 305 | loop = asyncio.get_running_loop() 306 | if loop.is_running(): 307 | # We're in a Jupyter notebook or similar environment 308 | # where the loop is already running 309 | import nest_asyncio 310 | nest_asyncio.apply() 311 | return loop.run_until_complete(self._process_batch_async(batch)) 312 | except RuntimeError: 313 | # No event loop running, create a new one 314 | return asyncio.run(self._process_batch_async(batch)) 315 | 316 | async def _process_request_async( 317 | self, 318 | client: httpx.AsyncClient, 319 | request: dict[str, Any], 320 | progress: Optional[ProgressTracker] = None, 321 | ) -> ResponseWrapper[ResponseT]: 322 | """Process a single request with caching support.""" 323 | # Get order ID and request ID from request 324 | order_id = request.get('_order_id', 0) 325 | request_id = request.get('_request_id') 326 | 327 | if request_id is None: 328 | # Only compute if not already present 329 | request_id = compute_request_hash(request) 330 | request['_request_id'] = request_id 331 | 332 | # Check cache first if available 333 | if self.cache is not None: 334 | try: 335 | if await self.cache.exists(request_id): 336 | cached_response = await self.cache.get(request_id) 337 | wrapped = ResponseWrapper(cached_response, request_id, order_id) 338 | if progress and wrapped.usage: 339 | # Update progress with cache hit 340 | progress.update( 341 | wrapped.usage.prompt_tokens, 342 | wrapped.usage.completion_tokens or 0, # Handle embeddings having no completion tokens 343 | is_cache_hit=True 344 | ) 345 | return wrapped 346 | except Exception as e: 347 | logger.warning(f"Cache read error: {str(e)}") 348 | 349 | # Process request with retries 350 | for attempt in range(self.retry_attempts): 351 | try: 352 | response = await self.provider.make_request( 353 | client, 354 | request, 355 | self.timeout, 356 | ) 357 | 358 | # Create wrapper and update progress 359 | wrapped = ResponseWrapper(response, request_id, order_id) 360 | 361 | if progress: 362 | # For embeddings, usage only has prompt_tokens 363 | if isinstance(response, dict) and 'usage' in response: 364 | usage = response['usage'] 365 | prompt_tokens = usage.get('prompt_tokens', 0) 366 | # Embeddings don't have completion tokens 367 | completion_tokens = usage.get('completion_tokens', 0) 368 | progress.update( 369 | prompt_tokens, 370 | completion_tokens, 371 | is_cache_hit=False 372 | ) 373 | elif wrapped.usage: 374 | progress.update( 375 | wrapped.usage.prompt_tokens, 376 | wrapped.usage.completion_tokens or 0, # Handle embeddings having no completion tokens 377 | is_cache_hit=False 378 | ) 379 | 380 | # Cache successful response 381 | if self.cache is not None: 382 | try: 383 | await self.cache.put(request_id, response) 384 | except Exception as e: 385 | logger.warning(f"Cache write error: {str(e)}") 386 | 387 | return wrapped 388 | 389 | except Exception as e: 390 | if attempt == self.retry_attempts - 1: 391 | if progress: 392 | # Update progress even for failed requests 393 | progress.update(0, 0, is_cache_hit=False) 394 | if self.return_dummy_on_error: 395 | # no caching for failed requests 396 | return ResponseWrapper(self.dummy_response, request_id, order_id) 397 | else: 398 | raise 399 | await asyncio.sleep(self.retry_delay * (attempt + 1)) 400 | 401 | async def _process_batch_async( 402 | self, 403 | batch: Union[list[dict[str, 
Any]], "RequestBatch"], 404 | ) -> list[ResponseWrapper[ResponseT]]: 405 | """Internal async implementation of batch processing.""" 406 | # Create semaphore for this batch processing run 407 | semaphore = asyncio.Semaphore(self.concurrency) 408 | 409 | # Convert RequestBatch to list of requests if needed 410 | if isinstance(batch, RequestBatch): 411 | # Extract original requests from batch format 412 | requests = [] 413 | for batch_req in batch.requests: 414 | # Extract request_id and order_id from custom_id 415 | request_id, order_id_str = batch_req["custom_id"].split("#") 416 | order_id = int(order_id_str) 417 | 418 | # Determine type from URL 419 | req_type = "chat_completion" if batch_req["url"] == "/v1/chat/completions" else "embedding" 420 | 421 | # Extract the original request from the batch format 422 | request = { 423 | **batch_req["body"], 424 | "_request_id": request_id, 425 | "_order_id": order_id, 426 | "type": req_type 427 | } 428 | requests.append(request) 429 | else: 430 | # Handle raw request list - compute request IDs and add order IDs 431 | requests = [] 432 | for i, request in enumerate(batch): 433 | request = request.copy() # Don't modify original request 434 | if "_request_id" not in request: 435 | request["_request_id"] = compute_request_hash(request) 436 | request["_order_id"] = i 437 | requests.append(request) 438 | 439 | # Create progress tracker if enabled 440 | tracker = ( 441 | ProgressTracker(len(requests), show_progress=self.show_progress) 442 | if self.show_progress 443 | else None 444 | ) 445 | 446 | async def process_request_with_semaphore( 447 | client: httpx.AsyncClient, 448 | request: dict[str, Any], 449 | progress: Optional[ProgressTracker] = None, 450 | ) -> ResponseWrapper[ResponseT]: 451 | """Process a single request with semaphore control.""" 452 | async with semaphore: 453 | return await self._process_request_async(client, request, progress) 454 | 455 | async def process_batch_chunk( 456 | client: httpx.AsyncClient, chunk: list[dict[str, Any]] 457 | ) -> list[ResponseWrapper[ResponseT]]: 458 | """Process a chunk of requests.""" 459 | batch_tasks = [ 460 | process_request_with_semaphore(client, req, tracker) for req in chunk 461 | ] 462 | results = await asyncio.gather(*batch_tasks) 463 | return [(r._order_id, r) for r in results] 464 | 465 | # Process requests in chunks based on calculated chunk size 466 | chunk_size = self._calculate_chunk_size() 467 | all_results = [] 468 | context = tracker if tracker else nullcontext() 469 | 470 | # Create a single client for the entire batch 471 | async with httpx.AsyncClient(timeout=self.timeout) as client: 472 | with context: 473 | for batch_start in range(0, len(requests), chunk_size): 474 | batch_requests = requests[ 475 | batch_start : batch_start + chunk_size 476 | ] 477 | batch_results = await process_batch_chunk(client, batch_requests) 478 | all_results.extend(batch_results) 479 | 480 | # Sort responses by order ID and return just the responses 481 | return [r for _, r in sorted(all_results, key=lambda x: x[0])] 482 | 483 | 484 | class RequestBatch(AbstractContextManager): 485 | """A batch of requests to be processed together in OpenAI Batch format.""" 486 | 487 | def __init__(self): 488 | self.requests = [] 489 | self._next_order_id = 0 490 | 491 | def __enter__(self): 492 | return self 493 | 494 | def __exit__(self, exc_type, exc_val, exc_tb): 495 | pass 496 | 497 | def __len__(self): 498 | return len(self.requests) 499 | 500 | def _add_request(self, request: dict[str, Any]) -> str: 501 | """Add a 
request to the batch and return its request ID (cache key). 502 | 503 | Args: 504 | request: The request to add to the batch 505 | 506 | Returns: 507 | str: The request ID (cache key) for this request 508 | """ 509 | 510 | # Compute request ID for caching if not already present 511 | request_id = compute_request_hash(request) 512 | order_id = self._next_order_id 513 | self._next_order_id += 1 514 | 515 | # Determine the endpoint URL based on request type 516 | url = "/v1/chat/completions" 517 | if request.get("type") == "embedding": 518 | url = "/v1/embeddings" 519 | 520 | # Create a custom_id from request_id and order_id 521 | custom_id = f"{request_id}#{order_id}" 522 | 523 | # Create batch format request directly 524 | batch_request = { 525 | "custom_id": custom_id, 526 | "url": url, 527 | "body": {k: v for k, v in request.items() if k not in ["type"]} 528 | } 529 | 530 | # Add to batch 531 | self.requests.append(batch_request) 532 | return request_id 533 | 534 | @classmethod 535 | def merge(cls, batches: list["RequestBatch"]) -> "RequestBatch": 536 | """Merge multiple request batches into a single batch.""" 537 | merged = cls() 538 | for batch in batches: 539 | merged.requests.extend(batch.requests) 540 | return merged 541 | 542 | @property 543 | def chat(self): 544 | """Access chat completion methods.""" 545 | return self.Chat(self) 546 | 547 | @property 548 | def embeddings(self): 549 | """Access embeddings methods.""" 550 | return self.Embeddings(self) 551 | 552 | class Chat: 553 | """Chat API that mimics OpenAI's interface.""" 554 | 555 | def __init__(self, batch): 556 | self.batch = batch 557 | self.completions = self.Completions(batch) 558 | 559 | class Completions: 560 | """Chat completions API that mimics OpenAI's interface.""" 561 | 562 | def __init__(self, batch): 563 | self.batch = batch 564 | 565 | def create( 566 | self, 567 | *, 568 | model: str, 569 | messages: list[dict[str, str]], 570 | temperature: Optional[float] = 0.7, 571 | top_p: Optional[float] = 1.0, 572 | n: Optional[int] = 1, 573 | stop: Optional[Union[str, list[str]]] = None, 574 | max_completion_tokens: Optional[int] = None, 575 | presence_penalty: Optional[float] = 0.0, 576 | frequency_penalty: Optional[float] = 0.0, 577 | logit_bias: Optional[dict[str, float]] = None, 578 | user: Optional[str] = None, 579 | response_format: Optional[dict[str, str]] = None, 580 | seed: Optional[int] = None, 581 | tools: Optional[list[dict[str, Any]]] = None, 582 | tool_choice: Optional[Union[str, dict[str, str]]] = None, 583 | **kwargs: Any 584 | ) -> str: 585 | """Add a chat completion request to the batch. 
586 | 587 | Args: 588 | model: The model to use for completion 589 | messages: The messages to generate a completion for 590 | temperature: Sampling temperature (0-2) 591 | top_p: Nucleus sampling parameter (0-1) 592 | n: Number of completions to generate 593 | stop: Stop sequences to use 594 | max_completion_tokens: Maximum tokens to generate 595 | presence_penalty: Presence penalty (-2 to 2) 596 | frequency_penalty: Frequency penalty (-2 to 2) 597 | logit_bias: Token biases to use 598 | user: User identifier 599 | response_format: Format for the response 600 | seed: Random seed for reproducibility 601 | tools: List of tools available to the model 602 | tool_choice: Tool choice configuration 603 | **kwargs: Additional provider-specific parameters 604 | 605 | Returns: 606 | str: The request ID (cache key) for this request 607 | """ 608 | # Create the request body 609 | body = { 610 | "model": model, 611 | "messages": messages, 612 | "temperature": temperature, 613 | "top_p": top_p, 614 | "n": n, 615 | "stop": stop, 616 | "max_completion_tokens": max_completion_tokens, 617 | "presence_penalty": presence_penalty, 618 | "frequency_penalty": frequency_penalty, 619 | "logit_bias": logit_bias, 620 | "user": user, 621 | "response_format": response_format, 622 | "seed": seed, 623 | "tools": tools, 624 | "tool_choice": tool_choice, 625 | **kwargs, 626 | } 627 | 628 | # Remove None values to match OpenAI's behavior 629 | body = {k: v for k, v in body.items() if v is not None} 630 | 631 | # Compute request_id at creation time 632 | 633 | request_id = compute_request_hash({"type": "chat_completion", **body}) 634 | order_id = self.batch._next_order_id 635 | self.batch._next_order_id += 1 636 | 637 | # Create custom_id from request_id and order_id 638 | custom_id = f"{request_id}#{order_id}" 639 | 640 | # Create the batch request directly in OpenAI Batch format 641 | batch_request = { 642 | "custom_id": custom_id, 643 | "url": "/v1/chat/completions", 644 | "body": body 645 | } 646 | 647 | # Add to batch 648 | self.batch.requests.append(batch_request) 649 | 650 | return request_id 651 | 652 | class Embeddings: 653 | """Embeddings API that mimics OpenAI's interface.""" 654 | 655 | def __init__(self, batch): 656 | self.batch = batch 657 | 658 | def create( 659 | self, 660 | *, 661 | model: str, 662 | input: Union[str, list[str]], 663 | dimensions: Optional[int] = None, 664 | encoding_format: Optional[str] = None, 665 | user: Optional[str] = None, 666 | **kwargs: Any 667 | ) -> str: 668 | """Add an embedding request to the batch. 669 | 670 | Args: 671 | model: The model to use for embeddings (e.g., text-embedding-3-small) 672 | input: The text to embed (either a string or a list of strings) 673 | dimensions: The number of dimensions to return. Only supported with 674 | text-embedding-3 models. Defaults to the model's max dimensions. 
675 | encoding_format: The format to return the embeddings in (float or base64) 676 | user: A unique identifier for the end-user 677 | **kwargs: Additional provider-specific parameters 678 | 679 | Returns: 680 | str: The request ID (cache key) for this request 681 | """ 682 | # Create the request body 683 | body = { 684 | "model": model, 685 | "input": input, 686 | "dimensions": dimensions, 687 | "encoding_format": encoding_format, 688 | "user": user, 689 | **kwargs, 690 | } 691 | 692 | # Remove None values to match OpenAI's behavior 693 | body = {k: v for k, v in body.items() if v is not None} 694 | 695 | request_id = compute_request_hash({"type": "embedding", **body}) 696 | order_id = self.batch._next_order_id 697 | self.batch._next_order_id += 1 698 | 699 | # Create custom_id from request_id and order_id 700 | custom_id = f"{request_id}#{order_id}" 701 | 702 | # Create the batch request directly in OpenAI Batch format 703 | batch_request = { 704 | "custom_id": custom_id, 705 | "url": "/v1/embeddings", 706 | "body": body 707 | } 708 | 709 | # Add to batch 710 | self.batch.requests.append(batch_request) 711 | 712 | return request_id 713 | -------------------------------------------------------------------------------- /fastllm/providers/__init__.py: -------------------------------------------------------------------------------- 1 | """Provider implementations.""" 2 | 3 | from .base import Provider 4 | from .openai import OpenAIProvider 5 | 6 | __all__ = ["Provider", "OpenAIProvider"] 7 | -------------------------------------------------------------------------------- /fastllm/providers/base.py: -------------------------------------------------------------------------------- 1 | """Base classes for LLM providers.""" 2 | 3 | from abc import ABC, abstractmethod 4 | from typing import Any, Generic, Optional 5 | 6 | import httpx 7 | 8 | from fastllm.core import ResponseT 9 | 10 | class Provider(Generic[ResponseT], ABC): 11 | """Base class for LLM providers.""" 12 | 13 | # Internal tracking headers 14 | _REFERER = "https://github.com/Rexhaif/fastllm" 15 | _APP_NAME = "FastLLM" 16 | 17 | def __init__( 18 | self, 19 | api_key: str, 20 | api_base: str, 21 | headers: Optional[dict[str, str]] = None, 22 | **kwargs: Any, 23 | ): 24 | self.api_key = api_key 25 | self.api_base = api_base.rstrip("/") # Remove trailing slash if present 26 | self._default_headers = { 27 | "HTTP-Referer": self._REFERER, 28 | "X-Title": self._APP_NAME, 29 | } 30 | self.headers = { 31 | **self._default_headers, 32 | **(headers or {}), 33 | } 34 | 35 | def get_request_url(self, endpoint: str) -> str: 36 | """Get full URL for API endpoint.""" 37 | return f"{self.api_base}/{endpoint.lstrip('/')}" 38 | 39 | @abstractmethod 40 | def get_request_headers(self) -> dict[str, str]: 41 | """Get headers for API requests.""" 42 | pass 43 | 44 | @abstractmethod 45 | async def make_request( 46 | self, 47 | client: httpx.AsyncClient, 48 | request: dict[str, Any], 49 | timeout: float, 50 | ) -> ResponseT: 51 | """Make a request to the provider API.""" 52 | pass -------------------------------------------------------------------------------- /fastllm/providers/openai.py: -------------------------------------------------------------------------------- 1 | """OpenAI API provider implementation.""" 2 | 3 | from typing import Any, Optional, cast, Type 4 | 5 | import httpx 6 | from openai.types.chat import ChatCompletion 7 | from openai.types import CreateEmbeddingResponse 8 | 9 | from fastllm.providers.base import Provider 10 | 11 | 
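# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this module): the Provider base class shown
# above only requires subclasses to implement `get_request_headers()` and
# `make_request()`. A minimal custom provider for an OpenAI-compatible HTTP API
# could therefore look roughly like the class below. The class name, the
# `chat/completions` endpoint path, and the raw-dict return type are
# assumptions made for the sketch, not part of the fastllm API.
class RawJSONProvider(Provider[dict]):
    def get_request_headers(self) -> dict[str, str]:
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
            **self.headers,
        }

    async def make_request(self, client, request, timeout) -> dict:
        # Drop fastllm-internal bookkeeping keys before sending the payload.
        payload = {
            k: v
            for k, v in request.items()
            if not k.startswith("_") and k != "type"
        }
        response = await client.post(
            self.get_request_url("chat/completions"),
            headers=self.get_request_headers(),
            json=payload,
            timeout=timeout,
        )
        response.raise_for_status()
        return response.json()
# Construction would mirror OpenAIProvider below, e.g.
# RawJSONProvider(api_key="...", api_base="https://example.invalid/v1").
# ---------------------------------------------------------------------------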
DEFAULT_API_BASE = "https://api.openai.com/v1" 12 | 13 | 14 | class OpenAIProvider(Provider[ChatCompletion]): 15 | """OpenAI provider.""" 16 | 17 | def __init__( 18 | self, 19 | api_key: str, 20 | api_base: str = DEFAULT_API_BASE, 21 | organization: Optional[str] = None, 22 | headers: Optional[dict[str, str]] = None, 23 | **kwargs: Any, 24 | ): 25 | super().__init__(api_key, api_base, headers, **kwargs) 26 | self.organization = organization 27 | 28 | def get_request_headers(self) -> dict[str, str]: 29 | """Get headers for OpenAI API requests.""" 30 | headers = { 31 | "Authorization": f"Bearer {self.api_key}", 32 | "Content-Type": "application/json", 33 | **self.headers, 34 | } 35 | if self.organization: 36 | headers["OpenAI-Organization"] = self.organization 37 | return headers 38 | 39 | async def make_request( 40 | self, 41 | client: httpx.AsyncClient, 42 | request: dict[str, Any], 43 | timeout: float, 44 | ) -> ChatCompletion | CreateEmbeddingResponse: 45 | """Make a request to the OpenAI API.""" 46 | # Determine request type from the request or infer from content 47 | if isinstance(request, dict): 48 | # Extract request type from the request data 49 | request_type = request.get("type") 50 | if request_type is None: 51 | # Infer type based on content 52 | if "messages" in request: 53 | request_type = "chat_completion" 54 | elif "input" in request: 55 | request_type = "embedding" 56 | else: 57 | request_type = "chat_completion" # Default 58 | else: 59 | # Handle unexpected input 60 | raise ValueError(f"Unexpected request type: {type(request)}") 61 | 62 | # Determine API path based on request type 63 | if request_type == "embedding": 64 | api_path = "embeddings" 65 | else: 66 | api_path = "chat/completions" 67 | 68 | url = self.get_request_url(api_path) 69 | payload = self._prepare_payload(request, request_type) 70 | 71 | response = await client.post( 72 | url, 73 | headers=self.get_request_headers(), 74 | json=payload, 75 | timeout=timeout, 76 | ) 77 | response.raise_for_status() 78 | data = response.json() 79 | 80 | if request_type == "embedding": 81 | return CreateEmbeddingResponse(**data) 82 | else: 83 | return ChatCompletion(**data) 84 | 85 | 86 | 87 | def _prepare_payload(self, request: dict[str, Any], request_type: str) -> dict[str, Any]: 88 | """Prepare the API payload from the request data.""" 89 | # Extract known fields and extra params 90 | known_fields = { 91 | "provider", "model", "messages", "temperature", "max_completion_tokens", 92 | "top_p", "presence_penalty", "frequency_penalty", "stop", "stream", 93 | "type", "input", "dimensions", "encoding_format", "user" 94 | } 95 | 96 | # Start with a copy of the request 97 | payload = {k: v for k, v in request.items() if k not in ["provider", "type", "_order_id", "_request_id"]} 98 | 99 | # Handle embedding requests 100 | if request_type == "embedding": 101 | # Ensure required fields are present 102 | if "model" not in payload: 103 | raise ValueError("Model is required for embedding requests") 104 | if "input" not in payload: 105 | raise ValueError("Input is required for embedding requests") 106 | 107 | # Keep only relevant fields for embeddings 108 | embedding_fields = {"model", "input", "dimensions", "encoding_format", "user"} 109 | return {k: v for k, v in payload.items() if k in embedding_fields} 110 | 111 | # Handle chat completion requests 112 | if "model" not in payload: 113 | raise ValueError("Model is required for chat completion requests") 114 | if "messages" not in payload: 115 | raise ValueError("Messages are required 
for chat completion requests") 116 | 117 | # Map max_completion_tokens to max_tokens if present 118 | if "max_completion_tokens" in payload: 119 | payload["max_tokens"] = payload.pop("max_completion_tokens") 120 | 121 | # Remove any None values 122 | return {k: v for k, v in payload.items() if v is not None} 123 | -------------------------------------------------------------------------------- /justfile: -------------------------------------------------------------------------------- 1 | set shell := ["bash", "-c"] 2 | set dotenv-load := true 3 | 4 | # List all available commands 5 | default: 6 | @just --list 7 | 8 | # Install all dependencies using uv 9 | install: 10 | uv venv 11 | uv pip install -e . 12 | uv pip install -e ".[dev]" 13 | 14 | # Run tests 15 | test: 16 | uv run pytest tests/ -v 17 | 18 | # Format code using ruff 19 | format: 20 | uv run ruff format . 21 | uv run ruff check . --fix 22 | 23 | # Run linting checks 24 | lint: 25 | uv run ruff check . 26 | 27 | # Clean up cache files 28 | clean: 29 | rm -rf .pytest_cache 30 | rm -rf .coverage 31 | rm -rf .ruff_cache 32 | rm -rf dist 33 | rm -rf build 34 | rm -rf *.egg-info 35 | find . -type d -name __pycache__ -exec rm -rf {} + 36 | 37 | live_test_completions: 38 | uv run python examples/parallel_test.py --model openai/gpt-4o-mini --repeats 100 --concurrency 75 --cache-type memory --output NO_OUTPUT 39 | 40 | # Build the package for distribution 41 | build: 42 | uv build --no-sources 43 | 44 | # Build and publish to PyPI 45 | publish: clean build 46 | UV_PUBLISH_TOKEN=$UV_PUBLISH_TOKEN_PROD uv publish 47 | 48 | # Build and publish to TestPyPI 49 | publish-test: clean build 50 | UV_PUBLISH_TOKEN=$UV_PUBLISH_TOKEN_TEST uv publish --index testpypi 51 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "fastllm-kit" 3 | version = "0.1.9" 4 | description = "High-performance parallel LLM API request tool with caching and multiple provider support" 5 | authors = [ 6 | {name = "Daniil Larionov", email = "rexhaif.io@gmail.com"} 7 | ] 8 | readme = "README.md" 9 | license = {text = "MIT"} 10 | requires-python = ">=3.9" 11 | keywords = ["llm", "ai", "openai", "parallel", "caching", "api"] 12 | classifiers = [ 13 | "Development Status :: 4 - Beta", 14 | "Intended Audience :: Developers", 15 | "License :: OSI Approved :: MIT License", 16 | "Programming Language :: Python :: 3.9", 17 | "Programming Language :: Python :: 3.10", 18 | "Programming Language :: Python :: 3.11", 19 | "Topic :: Software Development :: Libraries :: Python Modules", 20 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 21 | ] 22 | dependencies = [ 23 | "httpx>=0.27.2", 24 | "pydantic>=2.10.6", 25 | "rich>=13.9.4", 26 | "diskcache>=5.6.3", 27 | "anyio>=4.8.0", 28 | "typing_extensions>=4.12.2", 29 | "tqdm>=4.67.1", 30 | "nest-asyncio>=1.6.0", 31 | "openai>=1.61.0", 32 | "xxhash>=3.0.0", 33 | ] 34 | 35 | [project.optional-dependencies] 36 | dev = [ 37 | "ruff>=0.3.7", 38 | "pytest>=8.3.4", 39 | "pytest-asyncio>=0.23.8", 40 | "pytest-cov>=4.1.0", 41 | "black>=24.10.0", 42 | "coverage>=7.6.10", 43 | "ipywidgets>=8.1.5", 44 | "dotenv>=0.9.9", 45 | "typer>=0.15.2", 46 | "numpy>=2.0.2", 47 | ] 48 | 49 | [tool.ruff] 50 | line-length = 88 51 | target-version = "py39" 52 | select = ["E", "F", "B", "I", "N", "UP", "PL", "RUF"] 53 | ignore = [] 54 | 55 | [tool.ruff.format] 56 | quote-style = "double" 57 | 
indent-style = "space" 58 | skip-magic-trailing-comma = false 59 | line-ending = "auto" 60 | 61 | [tool.pytest.ini_options] 62 | asyncio_mode = "strict" 63 | asyncio_default_fixture_loop_scope = "function" 64 | 65 | [build-system] 66 | requires = ["hatchling"] 67 | build-backend = "hatchling.build" 68 | 69 | [dependency-groups] 70 | dev = [ 71 | "pytest>=8.3.4", 72 | "pytest-asyncio>=0.25.3", 73 | ] 74 | 75 | [tool.hatch.build.targets.wheel] 76 | packages = ["fastllm"] 77 | 78 | [tool.hatch.metadata] 79 | allow-direct-references = true 80 | 81 | [[tool.uv.index]] 82 | name = "testpypi" 83 | url = "https://test.pypi.org/simple/" 84 | publish-url = "https://test.pypi.org/legacy/" 85 | explicit = true 86 | -------------------------------------------------------------------------------- /tests/test_asyncio.py: -------------------------------------------------------------------------------- 1 | """Tests for async functionality.""" 2 | 3 | import asyncio 4 | import time 5 | 6 | import pytest 7 | from openai.types.completion_usage import CompletionUsage 8 | 9 | from fastllm.core import RequestManager 10 | 11 | # Constants for testing 12 | MAX_DURATION = 0.5 # Maximum expected duration for concurrent execution 13 | EXPECTED_RESPONSES = 5 # Expected number of responses 14 | EXPECTED_SUCCESSES = 2 # Expected number of successful responses 15 | EXPECTED_FAILURES = 1 # Expected number of failed responses 16 | 17 | 18 | class DummyRequestManager(RequestManager): 19 | async def _make_provider_request(self, client, request): 20 | await asyncio.sleep(0.1) # Simulate network delay 21 | # Extract message content from request 22 | message_content = request.get("messages", [{}])[0].get("content", "No content") 23 | response_dict = { 24 | "request_id": id(request), # Add unique request ID 25 | "content": f"Response to: {message_content}", 26 | "finish_reason": "dummy_end", 27 | "provider": "dummy", 28 | "raw_response": {"dummy_key": "dummy_value"}, 29 | "usage": { 30 | "prompt_tokens": 10, 31 | "completion_tokens": 5, 32 | "total_tokens": 15 33 | } 34 | } 35 | return response_dict 36 | 37 | 38 | @pytest.mark.asyncio 39 | async def test_dummy_manager_single_request(): 40 | manager = DummyRequestManager(provider="dummy") 41 | request = { 42 | "provider": "dummy", 43 | "messages": [{"role": "user", "content": "Hello async!"}], 44 | "model": "dummy-model" 45 | } 46 | response = await manager._make_provider_request(None, request) 47 | assert response["content"] == "Response to: Hello async!" 
48 | assert response["finish_reason"] == "dummy_end" 49 | assert response["usage"]["total_tokens"] == 15 50 | 51 | 52 | @pytest.mark.asyncio 53 | async def test_dummy_manager_concurrent_requests(): 54 | manager = DummyRequestManager(provider="dummy") 55 | requests = [ 56 | { 57 | "provider": "dummy", 58 | "messages": [{"role": "user", "content": f"Message {i}"}], 59 | "model": "dummy-model" 60 | } 61 | for i in range(5) 62 | ] 63 | 64 | start = time.perf_counter() 65 | responses = await asyncio.gather( 66 | *[manager._make_provider_request(None, req) for req in requests] 67 | ) 68 | end = time.perf_counter() 69 | 70 | # All requests should complete successfully 71 | assert len(responses) == 5 72 | for i, response in enumerate(responses): 73 | assert response["content"] == f"Response to: Message {i}" 74 | assert response["finish_reason"] == "dummy_end" 75 | assert response["usage"]["total_tokens"] == 15 76 | 77 | # Requests should be processed concurrently 78 | # Total time should be less than sequential time (5 * 0.1s) 79 | assert end - start < 0.5 # Allow some overhead 80 | 81 | 82 | class FailingDummyManager(RequestManager): 83 | async def _make_provider_request(self, client, request): 84 | message_content = request.get("messages", [{}])[0].get("content", "") 85 | if "fail" in message_content.lower(): 86 | raise Exception("Provider failure") 87 | await asyncio.sleep(0.1) 88 | response_dict = { 89 | "request_id": id(request), # Add unique request ID 90 | "content": f"Response to: {message_content}", 91 | "finish_reason": "dummy_end", 92 | "provider": "dummy", 93 | "raw_response": {"dummy_key": "dummy_value"}, 94 | "usage": { 95 | "prompt_tokens": 10, 96 | "completion_tokens": 5, 97 | "total_tokens": 15 98 | } 99 | } 100 | return response_dict 101 | 102 | 103 | @pytest.mark.asyncio 104 | async def test_dummy_manager_request_failure(): 105 | manager = FailingDummyManager(provider="dummy") 106 | request = { 107 | "provider": "dummy", 108 | "messages": [{"role": "user", "content": "fail this request"}], 109 | "model": "dummy-model" 110 | } 111 | with pytest.raises(Exception) as exc_info: 112 | await manager._make_provider_request(None, request) 113 | assert "Provider failure" in str(exc_info.value) 114 | 115 | 116 | @pytest.mark.asyncio 117 | async def test_gather_with_mixed_success_and_failure(): 118 | manager = FailingDummyManager(provider="dummy") 119 | requests = [ 120 | { 121 | "provider": "dummy", 122 | "messages": [{"role": "user", "content": "Message 1"}], 123 | "model": "dummy-model" 124 | }, 125 | { 126 | "provider": "dummy", 127 | "messages": [{"role": "user", "content": "fail this one"}], 128 | "model": "dummy-model" 129 | }, 130 | { 131 | "provider": "dummy", 132 | "messages": [{"role": "user", "content": "Message 3"}], 133 | "model": "dummy-model" 134 | }, 135 | ] 136 | 137 | responses = await asyncio.gather( 138 | *[manager._make_provider_request(None, req) for req in requests], 139 | return_exceptions=True, 140 | ) 141 | successes = [resp for resp in responses if not isinstance(resp, Exception)] 142 | failures = [resp for resp in responses if isinstance(resp, Exception)] 143 | 144 | assert len(successes) == 2 145 | assert len(failures) == 1 146 | assert "Provider failure" in str(failures[0]) 147 | 148 | 149 | @pytest.mark.asyncio 150 | async def test_task_scheduling_order(): 151 | manager = DummyRequestManager(provider="dummy") 152 | requests = [ 153 | { 154 | "provider": "dummy", 155 | "messages": [{"role": "user", "content": f"Message {i}"}], 156 | "model": "dummy-model" 157 
| } 158 | for i in range(3) 159 | ] 160 | 161 | # Create tasks but don't await them yet 162 | tasks = [manager._make_provider_request(None, req) for req in requests] 163 | 164 | # Schedule tasks in reverse order 165 | responses = [] 166 | for task in reversed(tasks): 167 | responses.append(await task) 168 | 169 | # Despite scheduling in reverse order, responses should match request order 170 | for i, response in enumerate(responses): 171 | assert f"Message {2-i}" in response["content"] 172 | 173 | 174 | @pytest.mark.asyncio 175 | async def test_task_cancellation(): 176 | manager = DummyRequestManager(provider="dummy") 177 | request = { 178 | "provider": "dummy", 179 | "messages": [{"role": "user", "content": "Cancel me"}], 180 | "model": "dummy-model" 181 | } 182 | 183 | # Start the task 184 | task = asyncio.create_task(manager._make_provider_request(None, request)) 185 | 186 | # Cancel it immediately 187 | task.cancel() 188 | 189 | with pytest.raises(asyncio.CancelledError): 190 | await task 191 | -------------------------------------------------------------------------------- /tests/test_cache.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import os 3 | 4 | import pytest 5 | 6 | from fastllm.cache import DiskCache, InMemoryCache, compute_request_hash 7 | 8 | 9 | @pytest.mark.asyncio 10 | async def test_inmemory_cache_put_get_exists_and_clear(): 11 | cache = InMemoryCache() 12 | key = "test_key" 13 | value = {"foo": "bar"} 14 | 15 | # Initially, key should not exist 16 | assert not await cache.exists(key) 17 | 18 | # Put value 19 | await cache.put(key, value) 20 | 21 | # Check existence 22 | assert await cache.exists(key) 23 | 24 | # Get the value back 25 | retrieved_val = await cache.get(key) 26 | assert retrieved_val == value 27 | 28 | # Clear cache 29 | await cache.clear() 30 | assert not await cache.exists(key) 31 | with pytest.raises(KeyError): 32 | await cache.get(key) 33 | 34 | 35 | @pytest.mark.asyncio 36 | async def test_disk_cache_put_get_exists_and_clear(tmp_path): 37 | # Create a temporary directory for the disk cache 38 | cache_dir = os.path.join(tmp_path, "disk_cache") 39 | os.makedirs(cache_dir, exist_ok=True) 40 | cache = DiskCache(directory=cache_dir) 41 | key = "disk_test" 42 | value = {"alpha": 123} 43 | 44 | # Initially, key should not exist 45 | assert not await cache.exists(key) 46 | 47 | # Put the value 48 | await cache.put(key, value) 49 | 50 | # Check existence 51 | assert await cache.exists(key) 52 | 53 | # Retrieve the value 54 | result = await cache.get(key) 55 | assert result == value 56 | 57 | # Clear the cache and verify 58 | await cache.clear() 59 | assert not await cache.exists(key) 60 | with pytest.raises(KeyError): 61 | await cache.get(key) 62 | 63 | # Close the disk cache 64 | await cache.close() 65 | 66 | 67 | @pytest.mark.asyncio 68 | async def test_disk_cache_ttl(tmp_path): 69 | # Create temporary directory for disk cache with TTL 70 | cache_dir = os.path.join(tmp_path, "disk_cache_ttl") 71 | os.makedirs(cache_dir, exist_ok=True) 72 | # Set TTL to 1 second 73 | cache = DiskCache(directory=cache_dir, ttl=1) 74 | key = "disk_ttl" 75 | value = "temporary" 76 | await cache.put(key, value) 77 | 78 | # Immediately, key should exist 79 | assert await cache.exists(key) 80 | 81 | # Wait for TTL expiry 82 | await asyncio.sleep(1.1) 83 | 84 | # After TTL expiration, key should be expired (get raises KeyError) 85 | with pytest.raises(KeyError): 86 | await cache.get(key) 87 | 88 | await cache.clear() 89 
| await cache.close() 90 | 91 | 92 | def test_compute_request_hash_consistency(): 93 | req1 = { 94 | "provider": "test", 95 | "model": "dummy", 96 | "messages": [{"role": "user", "content": "Hello"}], 97 | "temperature": 0.5, 98 | "_request_id": "ignore_me", 99 | "extra_param": "value", 100 | } 101 | req2 = { 102 | "provider": "test", 103 | "model": "dummy", 104 | "messages": [{"role": "user", "content": "Hello"}], 105 | "temperature": 0.5, 106 | "extra_param": "value", 107 | "_order_id": "should_be_removed", 108 | } 109 | 110 | hash1 = compute_request_hash(req1) 111 | hash2 = compute_request_hash(req2) 112 | 113 | # The hashes should be identical because _request_id and _order_id are removed 114 | assert hash1 == hash2 115 | 116 | 117 | @pytest.mark.asyncio 118 | async def test_disk_cache_invalid_directory(tmp_path): 119 | """Test DiskCache behavior with invalid directory.""" 120 | # Try to create cache in a non-existent directory that can't be created 121 | # (using a file as a directory should raise OSError) 122 | file_path = os.path.join(tmp_path, "file") 123 | with open(file_path, "w") as f: 124 | f.write("test") 125 | 126 | with pytest.raises(OSError): 127 | DiskCache(directory=file_path) 128 | 129 | 130 | @pytest.mark.asyncio 131 | async def test_disk_cache_concurrent_access(tmp_path): 132 | """Test concurrent access to DiskCache.""" 133 | cache_dir = os.path.join(tmp_path, "concurrent_cache") 134 | os.makedirs(cache_dir, exist_ok=True) 135 | cache = DiskCache(directory=cache_dir) 136 | 137 | # Create multiple concurrent operations 138 | async def concurrent_operation(key, value): 139 | await cache.put(key, value) 140 | assert await cache.exists(key) 141 | result = await cache.get(key) 142 | assert result == value 143 | 144 | # Run multiple operations concurrently 145 | tasks = [ 146 | concurrent_operation(f"key_{i}", f"value_{i}") 147 | for i in range(5) 148 | ] 149 | await asyncio.gather(*tasks) 150 | 151 | # Verify all values are still accessible 152 | for i in range(5): 153 | assert await cache.get(f"key_{i}") == f"value_{i}" 154 | 155 | await cache.clear() 156 | await cache.close() 157 | 158 | 159 | @pytest.mark.asyncio 160 | async def test_inmemory_cache_large_values(): 161 | """Test InMemoryCache with large values.""" 162 | cache = InMemoryCache() 163 | key = "large_value" 164 | # Create a large value (1MB string) 165 | large_value = "x" * (1024 * 1024) 166 | 167 | await cache.put(key, large_value) 168 | result = await cache.get(key) 169 | assert result == large_value 170 | 171 | await cache.clear() 172 | 173 | 174 | def test_compute_request_hash_edge_cases(): 175 | """Test compute_request_hash with edge cases.""" 176 | # Empty request 177 | empty_req = {} 178 | empty_hash = compute_request_hash(empty_req) 179 | assert empty_hash # Should return a hash even for empty dict 180 | 181 | # Request with only internal fields 182 | internal_req = { 183 | "_request_id": "123", 184 | "_order_id": "456", 185 | } 186 | internal_hash = compute_request_hash(internal_req) 187 | assert internal_hash == compute_request_hash({}) # Should be same as empty 188 | 189 | # Request with nested structures 190 | nested_req = { 191 | "provider": "test", 192 | "messages": [ 193 | {"role": "system", "content": "You are a bot"}, 194 | {"role": "user", "content": "Hi"}, 195 | {"role": "assistant", "content": "Hello!"}, 196 | ], 197 | "options": { 198 | "temperature": 0.7, 199 | "max_tokens": 100, 200 | "stop": [".", "!"], 201 | }, 202 | } 203 | nested_hash = compute_request_hash(nested_req) 204 | assert 
nested_hash # Should handle nested structures 205 | 206 | # Verify order independence 207 | reordered_req = { 208 | "messages": [ 209 | {"role": "system", "content": "You are a bot"}, 210 | {"role": "user", "content": "Hi"}, 211 | {"role": "assistant", "content": "Hello!"}, 212 | ], 213 | "provider": "test", 214 | "options": { 215 | "max_tokens": 100, 216 | "temperature": 0.7, 217 | "stop": [".", "!"], 218 | }, 219 | } 220 | reordered_hash = compute_request_hash(reordered_req) 221 | assert nested_hash == reordered_hash # Hash should be independent of field order 222 | -------------------------------------------------------------------------------- /tests/test_core.py: -------------------------------------------------------------------------------- 1 | """Tests for core functionality.""" 2 | 3 | import time 4 | from unittest import mock 5 | 6 | 7 | 8 | 9 | import pytest 10 | 11 | from fastllm.core import ( 12 | RequestBatch, 13 | RequestManager, 14 | ResponseWrapper, 15 | TokenStats, 16 | ProgressTracker, 17 | ) 18 | from fastllm.cache import InMemoryCache, compute_request_hash 19 | 20 | # Constants for testing 21 | DEFAULT_CHUNK_SIZE = 20 22 | DEFAULT_MAX_CHUNK_SIZE = 1000 23 | DEFAULT_PROMPT_TOKENS = 10 24 | DEFAULT_COMPLETION_TOKENS = 5 25 | DEFAULT_TOTAL_TOKENS = DEFAULT_PROMPT_TOKENS + DEFAULT_COMPLETION_TOKENS 26 | DEFAULT_RETRY_ATTEMPTS = 2 27 | DEFAULT_CONCURRENCY = 5 28 | DEFAULT_TIMEOUT = 1.0 29 | DEFAULT_RETRY_DELAY = 0.0 30 | 31 | # TokenStats Tests 32 | def test_token_stats(): 33 | """Test basic TokenStats functionality.""" 34 | ts = TokenStats(start_time=time.time() - 2) # started 2 seconds ago 35 | assert ts.cache_hit_ratio == 0.0 36 | ts.update(DEFAULT_PROMPT_TOKENS, DEFAULT_COMPLETION_TOKENS, is_cache_hit=True) 37 | assert ts.prompt_tokens == DEFAULT_PROMPT_TOKENS 38 | assert ts.completion_tokens == DEFAULT_COMPLETION_TOKENS 39 | assert ts.total_tokens == DEFAULT_TOTAL_TOKENS 40 | assert ts.requests_completed == 1 41 | assert ts.cache_hits == 1 42 | assert ts.prompt_tokens_per_second > 0 43 | assert ts.completion_tokens_per_second > 0 44 | 45 | 46 | def test_token_stats_rate_limits(): 47 | """Test rate limit tracking in TokenStats.""" 48 | current_time = time.time() 49 | stats = TokenStats( 50 | start_time=current_time, 51 | token_limit=1000, # 1000 tokens per minute 52 | request_limit=100, # 100 requests per minute 53 | ) 54 | 55 | # Test initial state 56 | assert stats.token_saturation == 0.0 57 | assert stats.request_saturation == 0.0 58 | 59 | # Update with some usage (non-cache hits) 60 | stats.update(50, 50, is_cache_hit=False) # 100 tokens total 61 | stats.update(25, 25, is_cache_hit=False) # 50 tokens total 62 | 63 | # Cache hits should not affect rate limit tracking 64 | stats.update(100, 100, is_cache_hit=True) 65 | assert stats.window_tokens == 150 # Still 150 from non-cache hits 66 | assert stats.window_requests == 2 # Still 2 from non-cache hits 67 | 68 | # RequestManager Tests 69 | class DummyProvider: 70 | async def make_request(self, client, request, timeout): 71 | return { 72 | "content": "Test response", 73 | "finish_reason": "stop", 74 | "usage": { 75 | "prompt_tokens": 10, 76 | "completion_tokens": 5, 77 | "total_tokens": 15 78 | } 79 | } 80 | 81 | 82 | @pytest.mark.asyncio 83 | async def test_request_manager(): 84 | """Test basic RequestManager functionality.""" 85 | provider = DummyProvider() 86 | manager = RequestManager(provider=provider) 87 | 88 | # Create a request dictionary 89 | request = { 90 | "provider": 
"dummy", 91 | "messages": [{"role": "user", "content": "Test message"}], 92 | "model": "dummy-model" 93 | } 94 | 95 | # Test make_provider_request 96 | response = await manager._make_provider_request(None, request) 97 | assert response["content"] == "Test response" 98 | assert response["finish_reason"] == "stop" 99 | assert response["usage"]["total_tokens"] == 15 100 | 101 | 102 | class FailingProvider: 103 | async def make_request(self, client, request, timeout): 104 | raise Exception("Provider error") 105 | 106 | 107 | @pytest.mark.asyncio 108 | async def test_request_manager_failure(): 109 | """Test RequestManager with failing provider.""" 110 | provider = FailingProvider() 111 | manager = RequestManager(provider=provider) 112 | request = { 113 | "provider": "dummy", 114 | "messages": [{"role": "user", "content": "Test message"}], 115 | "model": "dummy-model" 116 | } 117 | 118 | with pytest.raises(Exception) as exc_info: 119 | await manager._make_provider_request(None, request) 120 | assert "Provider error" in str(exc_info.value) 121 | 122 | 123 | def test_progress_tracker_update(): 124 | """Test ProgressTracker updates.""" 125 | tracker = ProgressTracker(total_requests=10, show_progress=False) 126 | tracker.update(DEFAULT_PROMPT_TOKENS, DEFAULT_COMPLETION_TOKENS) 127 | assert tracker.stats.prompt_tokens == DEFAULT_PROMPT_TOKENS 128 | assert tracker.stats.completion_tokens == DEFAULT_COMPLETION_TOKENS 129 | assert tracker.stats.total_tokens == DEFAULT_TOTAL_TOKENS 130 | 131 | 132 | def test_progress_tracker_context_manager(): 133 | """Test ProgressTracker as context manager.""" 134 | with ProgressTracker(total_requests=10, show_progress=False) as tracker: 135 | tracker.update(DEFAULT_PROMPT_TOKENS, DEFAULT_COMPLETION_TOKENS) 136 | assert tracker.stats.prompt_tokens == DEFAULT_PROMPT_TOKENS 137 | 138 | 139 | def test_progress_tracker_with_limits(): 140 | """Test ProgressTracker with rate limits.""" 141 | tracker = ProgressTracker(total_requests=10, show_progress=False) 142 | tracker.stats.token_limit = 1000 143 | tracker.stats.request_limit = 100 144 | tracker.update(DEFAULT_PROMPT_TOKENS, DEFAULT_COMPLETION_TOKENS) 145 | assert tracker.stats.token_saturation > 0 146 | assert tracker.stats.request_saturation > 0 147 | 148 | 149 | class CachingProvider: 150 | """Test provider that always returns the same response.""" 151 | 152 | def __init__(self): 153 | self.call_count = 0 154 | 155 | async def make_request(self, client, request, timeout): 156 | self.call_count += 1 157 | return { 158 | "content": "Cached response", 159 | "finish_reason": "stop", 160 | "usage": { 161 | "prompt_tokens": 10, 162 | "completion_tokens": 5, 163 | "total_tokens": 15 164 | } 165 | } 166 | 167 | 168 | @pytest.mark.asyncio 169 | async def test_request_manager_caching(): 170 | """Test RequestManager with caching.""" 171 | provider = CachingProvider() 172 | cache = InMemoryCache() 173 | manager = RequestManager( 174 | provider=provider, 175 | caching_provider=cache, 176 | show_progress=False, 177 | ) 178 | 179 | # Create identical requests 180 | request1 = { 181 | "provider": "dummy", 182 | "messages": [{"role": "user", "content": "Test message"}], 183 | "model": "dummy-model" 184 | } 185 | 186 | # Add request_id 187 | request_id = compute_request_hash(request1) 188 | request1["_request_id"] = request_id 189 | 190 | # First request should hit the provider 191 | response1 = await manager._process_request_async(None, request1, None) 192 | assert provider.call_count == 1 193 | 194 | # Second request with the same 
content should hit the cache 195 | request2 = request1.copy() 196 | response2 = await manager._process_request_async(None, request2, None) 197 | assert provider.call_count == 1 # Should not increase 198 | 199 | # Different request should hit the provider 200 | request3 = { 201 | "provider": "dummy", 202 | "messages": [{"role": "user", "content": "Different message"}], 203 | "model": "dummy-model" 204 | } 205 | request3["_request_id"] = compute_request_hash(request3) 206 | 207 | response3 = await manager._process_request_async(None, request3, None) 208 | assert provider.call_count == 2 209 | 210 | 211 | @pytest.mark.asyncio 212 | async def test_request_manager_cache_errors(): 213 | """Test RequestManager handles cache errors gracefully.""" 214 | 215 | class ErrorCache(InMemoryCache): 216 | async def exists(self, key: str) -> bool: 217 | raise Exception("Cache error") 218 | 219 | async def get(self, key: str): 220 | raise Exception("Cache error") 221 | 222 | async def put(self, key: str, value) -> None: 223 | raise Exception("Cache error") 224 | 225 | provider = DummyProvider() 226 | manager = RequestManager( 227 | provider=provider, 228 | caching_provider=ErrorCache(), 229 | show_progress=False, 230 | ) 231 | 232 | # Create a request 233 | request = { 234 | "provider": "dummy", 235 | "messages": [{"role": "user", "content": "Test message"}], 236 | "model": "dummy-model", 237 | "_request_id": "test-id", 238 | } 239 | 240 | # Request should succeed despite cache errors 241 | response = await manager._process_request_async(None, request, None) 242 | # Access response content from the wrapped response 243 | assert response.response["content"] == "Test response" 244 | 245 | 246 | @pytest.mark.asyncio 247 | async def test_request_manager_failed_response_not_cached(): 248 | """Test that failed responses are not cached.""" 249 | 250 | class FailingProvider: 251 | def __init__(self, fail_first=True): 252 | self.call_count = 0 253 | self.fail_first = fail_first 254 | self.has_failed = False 255 | 256 | async def make_request(self, client, request, timeout): 257 | self.call_count += 1 258 | 259 | if self.fail_first and not self.has_failed: 260 | self.has_failed = True 261 | raise Exception("Provider failure") 262 | 263 | return { 264 | "content": "Success response", 265 | "finish_reason": "stop", 266 | "usage": { 267 | "prompt_tokens": 10, 268 | "completion_tokens": 5, 269 | "total_tokens": 15 270 | } 271 | } 272 | 273 | provider = FailingProvider() 274 | cache = InMemoryCache() 275 | manager = RequestManager( 276 | provider=provider, 277 | caching_provider=cache, 278 | retry_attempts=1, # Only retry once 279 | show_progress=False, 280 | ) 281 | 282 | # Create a request 283 | request = { 284 | "provider": "dummy", 285 | "messages": [{"role": "user", "content": "Test message"}], 286 | "model": "dummy-model", 287 | "_request_id": "test-id", 288 | } 289 | 290 | # First attempt should fail 291 | with pytest.raises(Exception): 292 | await manager._process_request_async(None, request, None) 293 | 294 | # Verify the call was made 295 | assert provider.call_count == 1 296 | 297 | # Second attempt with same request should hit the provider again 298 | # since the failed response should not be cached 299 | try: 300 | await manager._process_request_async(None, request, None) 301 | except: 302 | pass 303 | 304 | assert provider.call_count == 2 305 | -------------------------------------------------------------------------------- /tests/test_openai.py: 
-------------------------------------------------------------------------------- 1 | import httpx 2 | import pytest 3 | 4 | from fastllm.providers.openai import OpenAIProvider 5 | 6 | # Constants for testing 7 | DEFAULT_TEMPERATURE = 0.9 8 | DEFAULT_TEMPERATURE_ALT = 0.6 9 | HTTP_OK_MIN = 200 10 | HTTP_OK_MAX = 300 11 | 12 | 13 | def test_prepare_payload_from_simple_request(): 14 | # Test preparing a simple request payload 15 | provider = OpenAIProvider(api_key="testkey") 16 | request = { 17 | "model": "gpt-3.5-turbo", 18 | "messages": [{"role": "user", "content": "Test message"}], 19 | "temperature": DEFAULT_TEMPERATURE 20 | } 21 | payload = provider._prepare_payload(request, "chat_completion") 22 | assert payload["model"] == "gpt-3.5-turbo" 23 | assert "messages" in payload 24 | messages = payload["messages"] 25 | assert isinstance(messages, list) 26 | assert messages[0]["role"] == "user" 27 | assert messages[0]["content"] == "Test message" 28 | assert payload["temperature"] == DEFAULT_TEMPERATURE 29 | 30 | 31 | def test_prepare_payload_from_system_message(): 32 | # Test preparing a payload with a system message 33 | provider = OpenAIProvider(api_key="testkey") 34 | request = { 35 | "model": "gpt-3.5-turbo", 36 | "messages": [{"role": "system", "content": "System message"}], 37 | } 38 | payload = provider._prepare_payload(request, "chat_completion") 39 | assert payload["model"] == "gpt-3.5-turbo" 40 | messages = payload["messages"] 41 | assert len(messages) == 1 42 | assert messages[0]["role"] == "system" 43 | assert messages[0]["content"] == "System message" 44 | 45 | 46 | def test_prepare_payload_omits_none_values(): 47 | # Test that None values are omitted from the payload 48 | provider = OpenAIProvider(api_key="testkey") 49 | request = { 50 | "model": "gpt-3.5-turbo", 51 | "messages": [{"role": "user", "content": "Ignore None"}], 52 | "top_p": None, 53 | "stop": None 54 | } 55 | payload = provider._prepare_payload(request, "chat_completion") 56 | # top_p and stop should not be in payload if they are None 57 | assert "top_p" not in payload 58 | assert "stop" not in payload 59 | 60 | 61 | def test_prepare_payload_omits_internal_tracking_ids(): 62 | # Test that internal tracking ids are never sent to providers 63 | provider = OpenAIProvider(api_key="testkey") 64 | request = { 65 | "model": "gpt-3.5-turbo", 66 | "messages": [{"role": "user", "content": "Hello world"}], 67 | "_order_id": 123, 68 | "_request_id": "abcd1234" 69 | } 70 | payload = provider._prepare_payload(request, "chat_completion") 71 | # _order_id and _request_id should not be in payload 72 | assert "_order_id" not in payload 73 | assert "_request_id" not in payload 74 | 75 | 76 | def test_prepare_payload_with_extra_params(): 77 | # Test that extra parameters are included in the payload 78 | provider = OpenAIProvider(api_key="testkey") 79 | request = { 80 | "model": "gpt-3.5-turbo", 81 | "messages": [{"role": "user", "content": "Extra params"}], 82 | "custom_param": "custom_value" 83 | } 84 | payload = provider._prepare_payload(request, "chat_completion") 85 | assert "custom_param" in payload 86 | assert payload["custom_param"] == "custom_value" 87 | 88 | 89 | def test_openai_provider_get_request_url(): 90 | # Test that the OpenAIProvider constructs the correct request URL 91 | provider = OpenAIProvider(api_key="testkey", api_base="https://api.openai.com") 92 | url = provider.get_request_url("completions") 93 | assert url == "https://api.openai.com/completions" 94 | 95 | 96 | def 
test_openai_provider_get_request_headers(): 97 | provider = OpenAIProvider( 98 | api_key="testkey", 99 | api_base="https://api.openai.com", 100 | organization="org-123", 101 | headers={"X-Custom": "custom-value"}, 102 | ) 103 | headers = provider.get_request_headers() 104 | assert headers["Authorization"] == "Bearer testkey" 105 | assert headers["Content-Type"] == "application/json" 106 | assert headers["OpenAI-Organization"] == "org-123" 107 | assert headers["X-Custom"] == "custom-value" 108 | 109 | 110 | def test_prepare_payload_for_embeddings(): 111 | # Test preparing a payload for embeddings 112 | provider = OpenAIProvider(api_key="testkey") 113 | request = { 114 | "model": "text-embedding-ada-002", 115 | "input": "Test input", 116 | "dimensions": 1536, 117 | "user": "test-user" 118 | } 119 | payload = provider._prepare_payload(request, "embedding") 120 | assert payload["model"] == "text-embedding-ada-002" 121 | assert payload["input"] == "Test input" 122 | assert payload["dimensions"] == 1536 123 | assert payload["user"] == "test-user" 124 | 125 | 126 | def test_prepare_payload_for_embeddings_with_array_input(): 127 | # Test preparing a payload for embeddings with array input 128 | provider = OpenAIProvider(api_key="testkey") 129 | request = { 130 | "model": "text-embedding-ada-002", 131 | "input": ["Test input 1", "Test input 2"], 132 | "encoding_format": "float" 133 | } 134 | payload = provider._prepare_payload(request, "embedding") 135 | assert payload["model"] == "text-embedding-ada-002" 136 | assert isinstance(payload["input"], list) 137 | assert len(payload["input"]) == 2 138 | assert payload["input"][0] == "Test input 1" 139 | assert payload["input"][1] == "Test input 2" 140 | assert payload["encoding_format"] == "float" 141 | 142 | 143 | def test_map_max_completion_tokens(): 144 | # Test that max_completion_tokens is properly mapped to max_tokens 145 | provider = OpenAIProvider(api_key="testkey") 146 | request = { 147 | "model": "gpt-3.5-turbo", 148 | "messages": [{"role": "user", "content": "Test message"}], 149 | "max_completion_tokens": 100 150 | } 151 | payload = provider._prepare_payload(request, "chat_completion") 152 | assert "max_tokens" in payload 153 | assert payload["max_tokens"] == 100 154 | assert "max_completion_tokens" not in payload 155 | 156 | 157 | class FakeResponse: 158 | def __init__(self, json_data, status_code=HTTP_OK_MIN): 159 | self._json_data = json_data 160 | self.status_code = status_code 161 | 162 | def raise_for_status(self): 163 | if not (HTTP_OK_MIN <= self.status_code < HTTP_OK_MAX): 164 | raise httpx.HTTPStatusError("Error", request=None, response=self) 165 | 166 | def json(self): 167 | return self._json_data 168 | 169 | 170 | class FakeAsyncClient: 171 | async def post(self, url, headers, json, timeout): 172 | # Return appropriate fake response based on request type 173 | if "embeddings" in url: 174 | # Return a fake embeddings response 175 | fake_json = { 176 | "object": "list", 177 | "data": [ 178 | { 179 | "object": "embedding", 180 | "embedding": [0.1, 0.2, 0.3], 181 | "index": 0 182 | } 183 | ], 184 | "model": "text-embedding-ada-002", 185 | "usage": {"prompt_tokens": 8, "total_tokens": 8} 186 | } 187 | else: 188 | # Return a fake chat completion response 189 | fake_json = { 190 | "id": "chatcmpl-xyz", 191 | "object": "chat.completion", 192 | "model": "gpt-3.5-turbo", 193 | "created": 1690000000, 194 | "choices": [ 195 | { 196 | "index": 0, 197 | "message": {"role": "assistant", "content": "Test reply"}, 198 | "finish_reason": "stop", 
199 | } 200 | ], 201 | "usage": {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3}, 202 | } 203 | return FakeResponse(fake_json) 204 | 205 | 206 | @pytest.mark.asyncio 207 | async def test_openai_provider_make_request(): 208 | provider = OpenAIProvider(api_key="testkey", api_base="https://api.openai.com") 209 | fake_client = FakeAsyncClient() 210 | request_data = { 211 | "model": "gpt-3.5-turbo", 212 | "messages": [{"role": "user", "content": "Tell me a joke."}], 213 | "temperature": 0.5, 214 | } 215 | # Pass a dict directly to make_request 216 | result = await provider.make_request(fake_client, request_data, timeout=1.0) 217 | # Check that the result has the expected fake response data 218 | assert result.id == "chatcmpl-xyz" 219 | assert result.object == "chat.completion" 220 | assert isinstance(result.choices, list) 221 | # Access the content via attributes 222 | assert result.choices[0].message.content == "Test reply" 223 | 224 | 225 | @pytest.mark.asyncio 226 | async def test_openai_provider_make_embedding_request(): 227 | provider = OpenAIProvider(api_key="testkey", api_base="https://api.openai.com") 228 | fake_client = FakeAsyncClient() 229 | request_data = { 230 | "model": "text-embedding-ada-002", 231 | "input": "Sample text for embedding", 232 | "type": "embedding" # Add type indicator for path determination 233 | } 234 | # Pass an embedding request to make_request 235 | result = await provider.make_request( 236 | fake_client, 237 | request_data, 238 | timeout=1.0 239 | ) 240 | # For embeddings, we expect a dict since it doesn't parse as ChatCompletion 241 | assert isinstance(result, dict) 242 | assert result["object"] == "list" 243 | assert "data" in result 244 | assert isinstance(result["data"], list) 245 | assert len(result["data"]) > 0 246 | assert "embedding" in result["data"][0] 247 | assert isinstance(result["data"][0]["embedding"], list) 248 | 249 | 250 | class FakeAsyncClientError: 251 | async def post(self, url, headers, json, timeout): 252 | # Return a fake response with an error status code 253 | return FakeResponse({"error": "Bad Request"}, status_code=400) 254 | 255 | 256 | @pytest.mark.asyncio 257 | async def test_openai_provider_make_request_error(): 258 | provider = OpenAIProvider(api_key="testkey", api_base="https://api.openai.com") 259 | fake_client_error = FakeAsyncClientError() 260 | request_data = { 261 | "model": "gpt-3.5-turbo", 262 | "messages": [{"role": "user", "content": "This will error."}], 263 | "temperature": 0.5, 264 | } 265 | with pytest.raises(httpx.HTTPStatusError): 266 | await provider.make_request(fake_client_error, request_data, timeout=1.0) 267 | 268 | 269 | def test_embedding_request_validation(): 270 | # Test validation for embedding requests 271 | provider = OpenAIProvider(api_key="testkey") 272 | 273 | # Test missing model 274 | with pytest.raises(ValueError): 275 | provider._prepare_payload({"input": "test"}, "embedding") 276 | 277 | # Test missing input 278 | with pytest.raises(ValueError): 279 | provider._prepare_payload({"model": "text-embedding-ada-002"}, "embedding") 280 | 281 | # Test valid minimal embedding request 282 | payload = provider._prepare_payload({ 283 | "model": "text-embedding-ada-002", 284 | "input": "test" 285 | }, "embedding") 286 | assert payload == {"model": "text-embedding-ada-002", "input": "test"} 287 | 288 | 289 | def test_embedding_request_type_detection(): 290 | # Test automatic detection of embedding requests 291 | provider = OpenAIProvider(api_key="testkey") 292 | fake_client = 
FakeAsyncClient() 293 | 294 | # Request with input field should be detected as embedding 295 | request = { 296 | "model": "text-embedding-ada-002", 297 | "input": "test", 298 | } 299 | 300 | # Check that the request is correctly identified as embedding type 301 | payload = provider._prepare_payload(request, "embedding") 302 | assert payload == {"model": "text-embedding-ada-002", "input": "test"} 303 | -------------------------------------------------------------------------------- /tests/test_progress_tracker.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | from fastllm.core import ProgressTracker 4 | 5 | # Constants for testing 6 | INITIAL_PROMPT_TOKENS = 10 7 | INITIAL_COMPLETION_TOKENS = 20 8 | INITIAL_TOTAL_TOKENS = INITIAL_PROMPT_TOKENS + INITIAL_COMPLETION_TOKENS 9 | 10 | UPDATE_PROMPT_TOKENS = 5 11 | UPDATE_COMPLETION_TOKENS = 5 12 | FINAL_PROMPT_TOKENS = INITIAL_PROMPT_TOKENS + UPDATE_PROMPT_TOKENS 13 | FINAL_COMPLETION_TOKENS = INITIAL_COMPLETION_TOKENS + UPDATE_COMPLETION_TOKENS 14 | FINAL_TOTAL_TOKENS = FINAL_PROMPT_TOKENS + FINAL_COMPLETION_TOKENS 15 | 16 | REQUESTS_COMPLETED = 2 17 | 18 | 19 | def test_progress_tracker(): 20 | tracker = ProgressTracker(total_requests=1, show_progress=False) 21 | 22 | # Initial update 23 | tracker.update(INITIAL_PROMPT_TOKENS, INITIAL_COMPLETION_TOKENS, False) 24 | 25 | # Assert that the token stats are updated 26 | assert tracker.stats.prompt_tokens == INITIAL_PROMPT_TOKENS 27 | assert tracker.stats.completion_tokens == INITIAL_COMPLETION_TOKENS 28 | assert tracker.stats.total_tokens == INITIAL_TOTAL_TOKENS 29 | 30 | # Test cache hit scenario 31 | tracker.update(UPDATE_PROMPT_TOKENS, UPDATE_COMPLETION_TOKENS, True) 32 | assert tracker.stats.prompt_tokens == FINAL_PROMPT_TOKENS 33 | assert tracker.stats.completion_tokens == FINAL_COMPLETION_TOKENS 34 | assert tracker.stats.total_tokens == FINAL_TOTAL_TOKENS 35 | 36 | # The cache hit count should be incremented by 1 37 | assert tracker.stats.cache_hits == 1 38 | 39 | # Test that requests_completed is incremented correctly 40 | # Initially, it was 0, then two updates 41 | assert tracker.stats.requests_completed == REQUESTS_COMPLETED 42 | 43 | 44 | def test_progress_tracker_update(): 45 | # Create a ProgressTracker with a fixed total_requests and disable progress display 46 | tracker = ProgressTracker(total_requests=5, show_progress=False) 47 | # Simulate 1 second elapsed 48 | tracker.stats.start_time = time.time() - 1 49 | 50 | # Update tracker with some token counts 51 | tracker.update(INITIAL_PROMPT_TOKENS, INITIAL_COMPLETION_TOKENS, False) 52 | 53 | # Assert that the token stats are updated 54 | assert tracker.stats.prompt_tokens == INITIAL_PROMPT_TOKENS 55 | assert tracker.stats.completion_tokens == INITIAL_COMPLETION_TOKENS 56 | assert tracker.stats.total_tokens == INITIAL_TOTAL_TOKENS 57 | 58 | # Test cache hit scenario 59 | tracker.update(UPDATE_PROMPT_TOKENS, UPDATE_COMPLETION_TOKENS, True) 60 | assert tracker.stats.prompt_tokens == FINAL_PROMPT_TOKENS 61 | assert tracker.stats.completion_tokens == FINAL_COMPLETION_TOKENS 62 | assert tracker.stats.total_tokens == FINAL_TOTAL_TOKENS 63 | # The cache hit count should be incremented by 1 64 | assert tracker.stats.cache_hits == 1 65 | 66 | # Test that requests_completed is incremented correctly 67 | # Initially, it was 0, then two updates 68 | assert tracker.stats.requests_completed == REQUESTS_COMPLETED 69 | 
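# Illustrative sketch (not one of the tests above): the same TokenStats fields
# exercised through ProgressTracker here can be read back directly to estimate
# throughput and rate-limit headroom. The limit values below are made-up
# numbers for the example, not fastllm defaults.
def example_stats_readout():
    from fastllm.core import TokenStats

    stats = TokenStats(
        start_time=time.time() - 60,  # pretend one minute has elapsed
        token_limit=90_000,           # assumed tokens-per-minute budget
        request_limit=3_500,          # assumed requests-per-minute budget
    )
    stats.update(1_200, 300, is_cache_hit=False)  # one real provider call
    stats.update(1_200, 300, is_cache_hit=True)   # one cache hit; no window usage
    return {
        "cache_hit_ratio": stats.cache_hit_ratio,
        "prompt_tokens_per_second": stats.prompt_tokens_per_second,
        "token_saturation": stats.token_saturation,
        "request_saturation": stats.request_saturation,
    }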
-------------------------------------------------------------------------------- /tests/test_providers.py: -------------------------------------------------------------------------------- 1 | from fastllm.providers.base import Provider 2 | from fastllm.providers.openai import OpenAIProvider 3 | from openai.types.chat import ChatCompletion 4 | from openai.types.chat.chat_completion import Choice, ChatCompletionMessage 5 | from openai.types.completion_usage import CompletionUsage 6 | 7 | 8 | class DummyProvider(Provider): 9 | def __init__( 10 | self, 11 | api_key="dummy", 12 | api_base="https://api.example.com", 13 | organization=None, 14 | headers=None, 15 | ): 16 | super().__init__(api_key, api_base, headers) 17 | self.organization = organization 18 | 19 | def get_request_headers(self): 20 | headers = { 21 | "Authorization": f"Bearer {self.api_key}", 22 | "Content-Type": "application/json", 23 | } 24 | if self.organization: 25 | headers["OpenAI-Organization"] = self.organization 26 | return headers 27 | 28 | async def _make_actual_request(self, client, request, timeout): 29 | # Simulate a dummy response 30 | return ChatCompletion( 31 | id="chatcmpl-dummy", 32 | object="chat.completion", 33 | created=1234567890, 34 | model=request.get("model", "dummy-model"), 35 | choices=[ 36 | Choice( 37 | index=0, 38 | message=ChatCompletionMessage( 39 | role="assistant", 40 | content="This is a dummy response", 41 | ), 42 | finish_reason="stop", 43 | ) 44 | ], 45 | usage=CompletionUsage( 46 | prompt_tokens=10, 47 | completion_tokens=10, 48 | total_tokens=20, 49 | ), 50 | ) 51 | 52 | async def make_request(self, client, request, timeout): 53 | return await self._make_actual_request(client, request, timeout) 54 | 55 | 56 | def test_dummy_provider_get_request_url(): 57 | provider = DummyProvider(api_key="testkey", api_base="https://api.test.com") 58 | url = provider.get_request_url("endpoint") 59 | assert url == "https://api.test.com/endpoint" 60 | 61 | 62 | def test_dummy_provider_get_request_headers(): 63 | provider = DummyProvider(api_key="testkey") 64 | headers = provider.get_request_headers() 65 | assert headers["Authorization"] == "Bearer testkey" 66 | assert headers["Content-Type"] == "application/json" 67 | 68 | 69 | def test_openai_provider_get_request_headers_org(): 70 | provider = OpenAIProvider(api_key="testkey", organization="org123") 71 | headers = provider.get_request_headers() 72 | assert headers["Authorization"] == "Bearer testkey" 73 | assert headers["Content-Type"] == "application/json" 74 | assert headers.get("OpenAI-Organization") == "org123" 75 | 76 | 77 | def test_openai_provider_prepare_chat_completion_payload(): 78 | # Test conversion from a simple request to payload 79 | provider = OpenAIProvider(api_key="testkey") 80 | request = { 81 | "model": "gpt-dummy", 82 | "messages": [{"role": "user", "content": "Hello world!"}] 83 | } 84 | payload = provider._prepare_payload(request, "chat_completion") 85 | 86 | # Verify that the model and messages are set correctly 87 | assert payload["model"] == "gpt-dummy" 88 | assert "messages" in payload 89 | assert isinstance(payload["messages"], list) 90 | assert payload["messages"][0]["role"] == "user" 91 | assert payload["messages"][0]["content"] == "Hello world!" 
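# Illustrative sketch (not one of the tests above): the DummyProvider defined in
# this file can also be driven through the normal batch pipeline, which is a
# handy way to exercise RequestManager end to end without network access. The
# concurrency value and the assertion are arbitrary choices for the example.
def example_offline_batch_run():
    from fastllm.cache import InMemoryCache
    from fastllm.core import RequestBatch, RequestManager

    manager = RequestManager(
        provider=DummyProvider(),        # uses the dummy key/base URL defaults above
        concurrency=2,
        show_progress=False,
        caching_provider=InMemoryCache(),
    )
    with RequestBatch() as batch:
        for text in ("ping", "pong"):
            batch.chat.completions.create(
                model="dummy-model",
                messages=[{"role": "user", "content": text}],
            )
    responses = manager.process_batch(batch)
    assert len(responses) == len(batch)  # responses come back in request order
    return responses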
    # Test conversion from embedding request to payload
    provider = OpenAIProvider(api_key="testkey")
    request = {
        "model": "text-embedding-3-small",
        "input": ["Sample text 1", "Sample text 2"]
    }
    payload = provider._prepare_payload(request, "embedding")

    # Verify model and input are set correctly
    assert payload["model"] == "text-embedding-3-small"
    assert payload["input"] == ["Sample text 1", "Sample text 2"]

--------------------------------------------------------------------------------
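The payload tests above fix a simple pass-through contract: chat-completion requests keep `model` and `messages`, embedding requests keep `model` and `input`, and the request type is passed explicitly as the second argument. As an illustration of that contract only, not the actual code in `fastllm/providers/openai.py` (which may normalise additional fields such as sampling defaults), a `_prepare_payload` satisfying these assertions could look like this:

```python
def _prepare_payload(request: dict, request_type: str) -> dict:
    # Sketch of the behaviour asserted in the tests above; the real provider
    # may copy or normalise more fields before sending the request.
    if request_type == "embedding":
        return {"model": request["model"], "input": request["input"]}
    # Anything else is treated as a chat completion here.
    return {"model": request["model"], "messages": request["messages"]}
```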
/tests/test_request_batch.py:
--------------------------------------------------------------------------------
"""Tests for request batching functionality."""

import pytest
from fastllm.core import RequestBatch
from fastllm.cache import compute_request_hash

# Constants for testing
EXPECTED_REQUESTS = 2
FIRST_REQUEST_ID = 0
SECOND_REQUEST_ID = 1

# Constants for batch addition testing
BATCH_SIZE_ONE = 1
BATCH_SIZE_TWO = 2
BATCH_SIZE_THREE = 3
FIRST_BATCH_START_ID = 0
SECOND_BATCH_START_ID = 1
THIRD_BATCH_START_ID = 2

# Constants for multiple additions testing
INITIAL_BATCH_SIZE = 1
FINAL_BATCH_SIZE = 3
FIRST_MULTIPLE_ID = 0
SECOND_MULTIPLE_ID = 1
THIRD_MULTIPLE_ID = 2


def test_request_batch():
    """Test basic request batch functionality and request_id generation."""
    batch = RequestBatch()
    request_id = batch.chat.completions.create(
        model="dummy-model",
        messages=[{"role": "user", "content": "Hi"}],
    )

    # Verify request was added
    assert len(batch.requests) == 1

    # Verify OpenAI Batch format
    assert "custom_id" in batch.requests[0]
    assert "url" in batch.requests[0]
    assert "body" in batch.requests[0]

    # Verify custom_id format and extract request_id and order_id
    custom_id_parts = batch.requests[0]["custom_id"].split("#")
    assert len(custom_id_parts) == 2
    extracted_request_id, order_id_str = custom_id_parts
    assert extracted_request_id == request_id
    assert order_id_str == "0"

    # Verify URL indicates chat completion
    assert batch.requests[0]["url"] == "/v1/chat/completions"

    # Verify request_id is computed correctly
    # Include all fields that affect the hash
    expected_request = {"type": "chat_completion", **batch.requests[0]["body"]}
    assert request_id == compute_request_hash(expected_request)


def test_request_batch_merge():
    """Test merging request batches and request_id preservation."""
    # Create first batch
    batch1 = RequestBatch()
    request_id1 = batch1.chat.completions.create(
        model="dummy-model",
        messages=[{"role": "user", "content": "Hi"}],
    )
    assert len(batch1.requests) == 1
    assert batch1.requests[0]["custom_id"].split("#")[0] == request_id1

    # Create second batch
    batch2 = RequestBatch()
    request_id2 = batch2.chat.completions.create(
        model="dummy-model",
        messages=[{"role": "user", "content": "Hello"}],
    )
    request_id3 = batch2.chat.completions.create(
        model="dummy-model",
        messages=[{"role": "user", "content": "Hey"}],
    )
    assert len(batch2.requests) == 2
    assert batch2.requests[0]["custom_id"].split("#")[0] == request_id2
    assert batch2.requests[1]["custom_id"].split("#")[0] == request_id3

    # Test merging batches
    merged_batch = RequestBatch.merge([batch1, batch2])
    assert len(merged_batch.requests) == 3

    # Verify request_ids are preserved after merge
    assert merged_batch.requests[0]["custom_id"].split("#")[0] == request_id1
    assert merged_batch.requests[1]["custom_id"].split("#")[0] == request_id2
    assert merged_batch.requests[2]["custom_id"].split("#")[0] == request_id3


def test_request_batch_multiple_merges():
    """Test merging multiple request batches and request_id preservation."""
    # Create first batch
    batch1 = RequestBatch()
    request_id1 = batch1.chat.completions.create(
        model="dummy-model",
        messages=[{"role": "user", "content": "Hi"}],
    )
    assert len(batch1.requests) == 1
    assert batch1.requests[0]["custom_id"].split("#")[0] == request_id1

    # Create second batch
    batch2 = RequestBatch()
    request_id2 = batch2.chat.completions.create(
        model="dummy-model",
        messages=[{"role": "user", "content": "Hello"}],
    )
    assert batch2.requests[0]["custom_id"].split("#")[0] == request_id2

    # Create third batch
    batch3 = RequestBatch()
    request_id3 = batch3.chat.completions.create(
        model="dummy-model",
        messages=[{"role": "user", "content": "Hey"}],
    )
    assert batch3.requests[0]["custom_id"].split("#")[0] == request_id3

    # Test merging multiple batches
    final_batch = RequestBatch.merge([batch1, batch2, batch3])
    assert len(final_batch.requests) == 3

    # Verify request_ids are preserved after merge
    assert final_batch.requests[0]["custom_id"].split("#")[0] == request_id1
    assert final_batch.requests[1]["custom_id"].split("#")[0] == request_id2
    assert final_batch.requests[2]["custom_id"].split("#")[0] == request_id3


def test_request_id_consistency():
    """Test that identical requests get the same request_id."""
    batch = RequestBatch()

    # Create two identical requests
    request_id1 = batch.chat.completions.create(
        model="dummy-model",
        messages=[{"role": "user", "content": "Hi"}],
    )

    # Create a new batch to avoid order_id interference
    batch2 = RequestBatch()
    request_id2 = batch2.chat.completions.create(
        model="dummy-model",
        messages=[{"role": "user", "content": "Hi"}],
    )

    # Verify that identical requests get the same request_id
    assert request_id1 == request_id2

    # Create a different request
    request_id3 = batch2.chat.completions.create(
        model="dummy-model",
        messages=[{"role": "user", "content": "Different"}],
    )

    # Verify that different requests get different request_ids
    assert request_id1 != request_id3


def test_request_id_with_none_values():
    """Test that None values are properly handled in request_id computation."""
    batch = RequestBatch()

    # Create request with some None values
    request_id = batch.chat.completions.create(
        model="dummy-model",
        messages=[{"role": "user", "content": "Hi"}],
        temperature=None,  # Should be replaced with default
        top_p=None,  # Should be replaced with default
    )

    # Get the actual request body that was created
    body1 = batch.requests[0]["body"]

    # Create identical request without None values
    batch2 = RequestBatch()
    request_id2 = batch2.chat.completions.create(
        model="dummy-model",
        messages=[{"role": "user", "content": "Hi"}],
    )

    # Get the second actual request body
    body2 = batch2.requests[0]["body"]

    # Both requests should have the same content
    assert body1 == body2

    # Both request IDs should be identical since None values are replaced with defaults
    assert request_id == request_id2


def test_embeddings_request_format():
    """Test the format of embeddings requests."""
    batch = RequestBatch()
    request_id = batch.embeddings.create(
        model="text-embedding-ada-002",
        input="Hello world",
    )

    # Verify request was added
    assert len(batch.requests) == 1

    # Verify OpenAI Batch format
    assert "custom_id" in batch.requests[0]
    assert batch.requests[0]["url"] == "/v1/embeddings"
    assert "body" in batch.requests[0]

    # Verify custom_id format and extract request_id and order_id
    custom_id_parts = batch.requests[0]["custom_id"].split("#")
    assert len(custom_id_parts) == 2
    extracted_request_id, order_id_str = custom_id_parts
    assert extracted_request_id == request_id
    assert order_id_str == "0"

    # Verify body content
    assert batch.requests[0]["body"]["model"] == "text-embedding-ada-002"
    assert batch.requests[0]["body"]["input"] == "Hello world"

--------------------------------------------------------------------------------
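Taken together, these batch tests describe the contract the request IDs rely on: `create()` returns a deterministic id derived from the request body (so identical requests share a cache key), and every queued entry is an OpenAI-Batch-style record whose `custom_id` has the form `<request_id>#<order_id>`. The sketch below is not part of the test suite; it exercises only the behaviour asserted above.

```python
from fastllm.cache import compute_request_hash
from fastllm.core import RequestBatch

chat_batch = RequestBatch()
chat_id = chat_batch.chat.completions.create(
    model="dummy-model",
    messages=[{"role": "user", "content": "Hi"}],
)

emb_batch = RequestBatch()
emb_id = emb_batch.embeddings.create(
    model="text-embedding-ada-002",
    input="Hello world",
)

# Merging keeps each request's id and relative order.
merged = RequestBatch.merge([chat_batch, emb_batch])
assert merged.requests[0]["custom_id"].split("#")[0] == chat_id
assert merged.requests[1]["custom_id"].split("#")[0] == emb_id

# The id is the hash of the request body plus its type, so it can be
# recomputed later (e.g. for cache lookups) without keeping the batch around.
assert chat_id == compute_request_hash(
    {"type": "chat_completion", **chat_batch.requests[0]["body"]}
)
```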