├── justllms
│   ├── data
│   │   └── __init__.py
│   ├── __version__.py
│   ├── routing
│   │   ├── __init__.py
│   │   └── router.py
│   ├── sxs
│   │   ├── __init__.py
│   │   ├── core
│   │   │   ├── __init__.py
│   │   │   └── executor.py
│   │   └── models
│   │       ├── __init__.py
│   │       └── comparison.py
│   ├── config
│   │   ├── __init__.py
│   │   └── config.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── validators.py
│   │   └── token_counter.py
│   ├── core
│   │   ├── __init__.py
│   │   ├── models.py
│   │   ├── streaming.py
│   │   └── openai_base.py
│   ├── exceptions
│   │   ├── __init__.py
│   │   └── exceptions.py
│   ├── tools
│   │   ├── adapters
│   │   │   ├── __init__.py
│   │   │   ├── azure.py
│   │   │   ├── base.py
│   │   │   ├── anthropic.py
│   │   │   ├── openai.py
│   │   │   └── google.py
│   │   ├── __init__.py
│   │   ├── native
│   │   │   ├── __init__.py
│   │   │   ├── google_tools.py
│   │   │   └── manager.py
│   │   ├── google.py
│   │   ├── models.py
│   │   ├── decorators.py
│   │   ├── registry.py
│   │   ├── utils.py
│   │   └── executor.py
│   ├── __init__.py
│   ├── providers
│   │   ├── __init__.py
│   │   ├── deepseek.py
│   │   ├── openai.py
│   │   ├── anthropic.py
│   │   └── grok.py
│   └── cli.py
├── MANIFEST.in
├── .gitignore
├── LICENSE
├── .github
│   ├── workflows
│   │   ├── ci.yml
│   │   ├── pr-labeler.yml
│   │   └── release.yml
│   └── labeler.yml
└── pyproject.toml

/justllms/data/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/justllms/__version__.py:
--------------------------------------------------------------------------------
1 | __version__ = "2.1.8"
2 | 
--------------------------------------------------------------------------------
/justllms/routing/__init__.py:
--------------------------------------------------------------------------------
1 | from justllms.routing.router import Router
2 | 
3 | __all__ = [
4 |     "Router",
5 | ]
6 | 
--------------------------------------------------------------------------------
/justllms/sxs/__init__.py:
--------------------------------------------------------------------------------
1 | from justllms.sxs.cli import run_interactive_sxs
2 | 
3 | __all__ = ["run_interactive_sxs"]
4 | 
--------------------------------------------------------------------------------
/justllms/sxs/core/__init__.py:
--------------------------------------------------------------------------------
1 | from justllms.sxs.core.executor import ParallelExecutor
2 | 
3 | __all__ = ["ParallelExecutor"]
4 | 
--------------------------------------------------------------------------------
/justllms/config/__init__.py:
--------------------------------------------------------------------------------
1 | from justllms.config.config import Config, load_config
2 | 
3 | __all__ = [
4 |     "Config",
5 |     "load_config",
6 | ]
7 | 
--------------------------------------------------------------------------------
/justllms/sxs/models/__init__.py:
--------------------------------------------------------------------------------
1 | from justllms.sxs.models.comparison import ModelResponse, ResponseStatus
2 | 
3 | __all__ = ["ModelResponse", "ResponseStatus"]
4 | 
--------------------------------------------------------------------------------
/justllms/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from justllms.utils.token_counter import TokenCounter, count_tokens
2 | from justllms.utils.validators import validate_messages
3 | 
4 | __all__ = ["TokenCounter", "count_tokens", "validate_messages"]
5 | 
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include README.md
2 | include LICENSE
3 | include pyproject.toml
4 | recursive-exclude * __pycache__
5 | recursive-exclude * *.py[co]
6 | recursive-exclude * .DS_Store
7 | recursive-exclude examples *
8 | recursive-exclude docs *
9 | recursive-exclude tests *
--------------------------------------------------------------------------------
/justllms/core/__init__.py:
--------------------------------------------------------------------------------
1 | from justllms.core.base import BaseProvider, BaseResponse
2 | from justllms.core.client import Client
3 | from justllms.core.completion import Completion, CompletionResponse
4 | from justllms.core.models import Message, Role, Usage
5 | 
6 | __all__ = [
7 |     "BaseProvider",
8 |     "BaseResponse",
9 |     "Client",
10 |     "Completion",
11 |     "CompletionResponse",
12 |     "Message",
13 |     "Role",
14 |     "Usage",
15 | ]
16 | 
--------------------------------------------------------------------------------
/justllms/exceptions/__init__.py:
--------------------------------------------------------------------------------
1 | from justllms.exceptions.exceptions import (
2 |     AuthenticationError,
3 |     ConfigurationError,
4 |     JustLLMsError,
5 |     ProviderError,
6 |     RateLimitError,
7 |     TimeoutError,
8 |     ValidationError,
9 | )
10 | 
11 | __all__ = [
12 |     "JustLLMsError",
13 |     "ProviderError",
14 |     "ValidationError",
15 |     "RateLimitError",
16 |     "TimeoutError",
17 |     "AuthenticationError",
18 |     "ConfigurationError",
19 | ]
20 | 
--------------------------------------------------------------------------------
/justllms/tools/adapters/__init__.py:
--------------------------------------------------------------------------------
1 | from justllms.tools.adapters.anthropic import AnthropicToolAdapter
2 | from justllms.tools.adapters.azure import AzureToolAdapter
3 | from justllms.tools.adapters.base import BaseToolAdapter
4 | from justllms.tools.adapters.google import GoogleToolAdapter
5 | from justllms.tools.adapters.openai import OpenAIToolAdapter
6 | 
7 | __all__ = [
8 |     "BaseToolAdapter",
9 |     "OpenAIToolAdapter",
10 |     "AnthropicToolAdapter",
11 |     "GoogleToolAdapter",
12 |     "AzureToolAdapter",
13 | ]
14 | 
--------------------------------------------------------------------------------
/justllms/tools/adapters/azure.py:
--------------------------------------------------------------------------------
1 | from justllms.tools.adapters.openai import OpenAIToolAdapter
2 | 
3 | 
4 | class AzureToolAdapter(OpenAIToolAdapter):
5 |     """Azure OpenAI uses the same function calling format as OpenAI.
6 | 
7 |     Since Azure OpenAI Service provides OpenAI's models through Azure,
8 |     they use identical API formats for function calling. This adapter
9 |     simply inherits all functionality from the OpenAI adapter.
10 | """ 11 | 12 | pass # All functionality inherited from OpenAIToolAdapter 13 | -------------------------------------------------------------------------------- /justllms/sxs/models/comparison.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from enum import Enum 3 | from typing import Optional 4 | 5 | 6 | class ResponseStatus(Enum): 7 | """Status of a model response.""" 8 | 9 | COMPLETED = "completed" 10 | ERROR = "error" 11 | 12 | 13 | @dataclass 14 | class ModelResponse: 15 | """Response from a single model.""" 16 | 17 | provider: str 18 | model: str 19 | content: str 20 | status: ResponseStatus 21 | latency: float 22 | tokens: int 23 | cost: float 24 | error: Optional[str] = None 25 | -------------------------------------------------------------------------------- /justllms/tools/__init__.py: -------------------------------------------------------------------------------- 1 | from justllms.tools.decorators import tool, tool_from_callable 2 | from justllms.tools.google import GoogleCodeExecution, GoogleSearch 3 | from justllms.tools.models import Tool, ToolCall, ToolExecutionEntry, ToolResult 4 | from justllms.tools.registry import GlobalToolRegistry, ToolRegistry 5 | 6 | __all__ = [ 7 | "tool", 8 | "tool_from_callable", 9 | "Tool", 10 | "ToolCall", 11 | "ToolResult", 12 | "ToolExecutionEntry", 13 | "ToolRegistry", 14 | "GlobalToolRegistry", 15 | "GoogleSearch", 16 | "GoogleCodeExecution", 17 | ] 18 | -------------------------------------------------------------------------------- /justllms/tools/native/__init__.py: -------------------------------------------------------------------------------- 1 | from justllms.tools.native.google_tools import ( 2 | GOOGLE_NATIVE_TOOLS, 3 | GoogleCodeExecution, 4 | GoogleNativeTool, 5 | GoogleSearch, 6 | get_google_native_tool, 7 | ) 8 | from justllms.tools.native.manager import ( 9 | GoogleNativeToolManager, 10 | NativeToolManager, 11 | create_native_tool_manager, 12 | ) 13 | 14 | __all__ = [ 15 | "GoogleNativeTool", 16 | "GoogleSearch", 17 | "GoogleCodeExecution", 18 | "GOOGLE_NATIVE_TOOLS", 19 | "get_google_native_tool", 20 | "NativeToolManager", 21 | "GoogleNativeToolManager", 22 | "create_native_tool_manager", 23 | ] 24 | -------------------------------------------------------------------------------- /justllms/__init__.py: -------------------------------------------------------------------------------- 1 | from justllms.__version__ import __version__ 2 | from justllms.core.client import Client 3 | from justllms.core.completion import Completion, CompletionResponse 4 | from justllms.core.models import Message, Role 5 | from justllms.exceptions import JustLLMsError, ProviderError 6 | from justllms.tools import GoogleCodeExecution, GoogleSearch, Tool, ToolRegistry, tool 7 | 8 | JustLLM = Client 9 | 10 | __all__ = [ 11 | "__version__", 12 | "JustLLM", 13 | "Client", 14 | "Completion", 15 | "CompletionResponse", 16 | "Message", 17 | "Role", 18 | "JustLLMsError", 19 | "ProviderError", 20 | "tool", 21 | "Tool", 22 | "ToolRegistry", 23 | "GoogleSearch", 24 | "GoogleCodeExecution", 25 | ] 26 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.py[cod] 3 | *$py.class 4 | 5 | .Python 6 | build/ 7 | develop-eggs/ 8 | dist/ 9 | downloads/ 10 | eggs/ 11 | .eggs/ 12 | lib/ 13 | lib64/ 14 | parts/ 15 | sdist/ 16 | var/ 17 | wheels/ 18 | 
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | *.py[cod]
3 | *$py.class
4 | 
5 | .Python
6 | build/
7 | develop-eggs/
8 | dist/
9 | downloads/
10 | eggs/
11 | .eggs/
12 | lib/
13 | lib64/
14 | parts/
15 | sdist/
16 | var/
17 | wheels/
18 | share/python-wheels/
19 | *.egg-info/
20 | .installed.cfg
21 | *.egg
22 | MANIFEST
23 | *.manifest
24 | *.spec
25 | 
26 | 
27 | 
28 | # Environments
29 | .env
30 | .venv
31 | env/
32 | venv/
33 | .claude
34 | .vscode/
35 | 
36 | # JustLLMs specific
37 | *.csv
38 | *.pdf
39 | test_*.py
40 | demo_*.py
41 | debug_*.py
42 | simple_*.py
43 | comprehensive_test.py
44 | CLAUDE.md
45 | 
46 | # Screenshots and temporary files
47 | *.png
48 | *.jpg
49 | *.jpeg
50 | *.gif
51 | temp/
52 | tmp/
53 | .DS_Store
54 | 
55 | # API keys and secrets
56 | .env.local
57 | config.json
58 | secrets.json
59 | 
60 | # Cache directories
61 | redis_cache/
62 | disk_cache/
63 | 
64 | # Test results and reports
65 | test_results/
66 | analytics_reports/
67 | conversation_exports/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2025 JustLLMs
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
-------------------------------------------------------------------------------- /justllms/providers/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional, Type 2 | 3 | from justllms.core.base import BaseProvider 4 | 5 | _PROVIDERS: Dict[str, Type[BaseProvider]] = {} 6 | 7 | 8 | def register_provider(name: str, provider_class: Type[BaseProvider]) -> None: 9 | """Register a provider class.""" 10 | _PROVIDERS[name.lower()] = provider_class 11 | 12 | 13 | def get_provider_class(name: str) -> Optional[Type[BaseProvider]]: 14 | """Get a provider class by name.""" 15 | return _PROVIDERS.get(name.lower()) 16 | 17 | 18 | def list_available_providers() -> List[str]: 19 | """List all available provider names.""" 20 | return list(_PROVIDERS.keys()) 21 | 22 | 23 | try: 24 | from justllms.providers.openai import OpenAIProvider 25 | 26 | register_provider("openai", OpenAIProvider) 27 | except ImportError: 28 | pass 29 | 30 | try: 31 | from justllms.providers.azure_openai import AzureOpenAIProvider 32 | 33 | register_provider("azure_openai", AzureOpenAIProvider) 34 | except ImportError: 35 | pass 36 | 37 | try: 38 | from justllms.providers.anthropic import AnthropicProvider 39 | 40 | register_provider("anthropic", AnthropicProvider) 41 | register_provider("claude", AnthropicProvider) 42 | except ImportError: 43 | pass 44 | 45 | try: 46 | from justllms.providers.google import GoogleProvider 47 | 48 | register_provider("google", GoogleProvider) 49 | register_provider("gemini", GoogleProvider) 50 | except ImportError: 51 | pass 52 | 53 | try: 54 | from justllms.providers.grok import GrokProvider 55 | 56 | register_provider("grok", GrokProvider) 57 | register_provider("xai", GrokProvider) 58 | except ImportError: 59 | pass 60 | 61 | try: 62 | from justllms.providers.deepseek import DeepSeekProvider 63 | 64 | register_provider("deepseek", DeepSeekProvider) 65 | except ImportError: 66 | pass 67 | 68 | try: 69 | from justllms.providers.ollama import OllamaProvider 70 | 71 | register_provider("ollama", OllamaProvider) 72 | except ImportError: 73 | pass 74 | 75 | 76 | __all__ = [ 77 | "register_provider", 78 | "get_provider_class", 79 | "list_available_providers", 80 | ] 81 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | types: [opened, synchronize, reopened, ready_for_review] 9 | merge_group: 10 | 11 | permissions: 12 | contents: read 13 | pull-requests: read 14 | 15 | jobs: 16 | lint: 17 | name: Lint Code 18 | runs-on: ubuntu-latest 19 | steps: 20 | - uses: actions/checkout@v4 21 | 22 | - name: Set up Python 23 | uses: actions/setup-python@v5 24 | with: 25 | python-version: '3.11' 26 | 27 | - name: Install dependencies 28 | run: | 29 | python -m pip install --upgrade pip 30 | pip install ruff black 31 | 32 | - name: Run Ruff 33 | run: ruff check justllms/ 34 | 35 | - name: Check Black formatting 36 | run: black --check justllms/ 37 | 38 | type-check: 39 | name: Type Check 40 | runs-on: ubuntu-latest 41 | steps: 42 | - uses: actions/checkout@v4 43 | 44 | - name: Set up Python 45 | uses: actions/setup-python@v5 46 | with: 47 | python-version: '3.11' 48 | 49 | - name: Install dependencies 50 | run: | 51 | python -m pip install --upgrade pip 52 | pip install mypy types-PyYAML types-requests 53 | 
pip install -e . 54 | 55 | - name: Run mypy 56 | run: mypy justllms/ --ignore-missing-imports 57 | 58 | test-build: 59 | name: Test Build - Python ${{ matrix.python-version }} 60 | runs-on: ubuntu-latest 61 | strategy: 62 | fail-fast: false 63 | matrix: 64 | python-version: ['3.10', '3.11', '3.12', '3.13'] 65 | 66 | steps: 67 | - uses: actions/checkout@v4 68 | 69 | - name: Set up Python ${{ matrix.python-version }} 70 | uses: actions/setup-python@v5 71 | with: 72 | python-version: ${{ matrix.python-version }} 73 | 74 | - name: Install build dependencies 75 | run: | 76 | python -m pip install --upgrade pip 77 | pip install build 78 | 79 | - name: Build package 80 | run: python -m build 81 | 82 | - name: Verify package 83 | run: | 84 | pip install dist/*.whl 85 | python -c "import justllms; print(f'JustLLMs {justllms.__version__} imported successfully')" 86 | 87 | - name: Check package contents 88 | run: | 89 | pip show -f justllms -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "justllms" 7 | version = "2.1.8" 8 | description = "Production-focused Python library for multi-provider LLM management with unified API" 9 | readme = "README.md" 10 | license = {text = "MIT"} 11 | authors = [ 12 | {name = "darshan harihar", email = "darshanharihar2950@gmail.com"} 13 | ] 14 | keywords = ["llm", "ai", "openai", "anthropic", "gemini", "gateway", "proxy"] 15 | classifiers = [ 16 | "Development Status :: 4 - Beta", 17 | "Intended Audience :: Developers", 18 | "License :: OSI Approved :: MIT License", 19 | "Programming Language :: Python :: 3", 20 | "Programming Language :: Python :: 3.8", 21 | "Programming Language :: Python :: 3.9", 22 | "Programming Language :: Python :: 3.10", 23 | "Programming Language :: Python :: 3.11", 24 | "Programming Language :: Python :: 3.12", 25 | "Topic :: Software Development :: Libraries :: Python Modules", 26 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 27 | ] 28 | requires-python = ">=3.8" 29 | dependencies = [ 30 | "httpx>=0.25.0", 31 | "pydantic>=2.0.0", 32 | "tenacity>=8.0.0", 33 | "tiktoken>=0.5.0", 34 | "python-dotenv>=1.0.0", 35 | "rich>=13.0.0", 36 | "PyYAML>=6.0.0", 37 | "click>=8.0.0", 38 | "questionary>=2.0.0", 39 | ] 40 | 41 | [project.optional-dependencies] 42 | dev = [ 43 | "pytest>=7.0.0", 44 | "pytest-asyncio>=0.21.0", 45 | "pytest-cov>=4.0.0", 46 | "black>=23.0.0", 47 | "ruff>=0.1.0", 48 | "mypy>=1.0.0", 49 | "pre-commit>=3.0.0", 50 | ] 51 | docs = [ 52 | "sphinx>=6.0.0", 53 | "sphinx-rtd-theme>=1.3.0", 54 | "sphinx-autodoc-typehints>=1.24.0", 55 | "myst-parser>=2.0.0", 56 | ] 57 | providers = [ 58 | "openai>=1.0.0", 59 | "anthropic>=0.18.0", 60 | "google-generativeai>=0.3.0", 61 | "cohere>=4.0.0", 62 | "replicate>=0.15.0", 63 | ] 64 | analytics = [ 65 | "reportlab>=4.0.0", 66 | "matplotlib>=3.5.0", 67 | "pandas>=1.3.0", 68 | ] 69 | 70 | [project.scripts] 71 | justllms = "justllms.cli:main" 72 | 73 | [project.urls] 74 | Homepage = "https://github.com/just-llms/justllms" 75 | Documentation = "https://github.com/just-llms/justllms#readme" 76 | Repository = "https://github.com/just-llms/justllms" 77 | "Bug Tracker" = "https://github.com/just-llms/justllms/issues" 78 | 79 | [tool.setuptools.packages.find] 80 | include = ["justllms*"] 81 | exclude = ["tests*", "docs*", 
"examples*"] 82 | 83 | [tool.black] 84 | line-length = 100 85 | target-version = ["py38", "py39", "py310", "py311", "py312"] 86 | 87 | [tool.ruff] 88 | line-length = 100 89 | target-version = "py38" 90 | 91 | [tool.ruff.lint] 92 | select = ["E", "F", "I", "UP", "B", "SIM", "C90"] 93 | ignore = ["E501", "C901"] 94 | 95 | [tool.mypy] 96 | python_version = "3.10" 97 | warn_return_any = true 98 | warn_unused_configs = true 99 | disallow_untyped_defs = true 100 | exclude = [".venv/", "build/", "dist/"] 101 | -------------------------------------------------------------------------------- /justllms/core/models.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from typing import Any, Dict, List, Optional, Union 3 | 4 | from pydantic import BaseModel, ConfigDict, Field 5 | 6 | 7 | class Role(str, Enum): 8 | """Message role enumeration.""" 9 | 10 | SYSTEM = "system" 11 | USER = "user" 12 | ASSISTANT = "assistant" 13 | FUNCTION = "function" 14 | TOOL = "tool" 15 | 16 | 17 | class Message(BaseModel): 18 | """Unified message format for all providers.""" 19 | 20 | model_config = ConfigDict(extra="allow") 21 | 22 | role: Role 23 | content: Union[str, List[Dict[str, Any]]] 24 | name: Optional[str] = None 25 | function_call: Optional[Dict[str, Any]] = None 26 | tool_calls: Optional[List[Dict[str, Any]]] = None 27 | tool_call_id: Optional[str] = None # Required for OpenAI/Azure tool results 28 | 29 | 30 | class Usage(BaseModel): 31 | """Token usage statistics.""" 32 | 33 | prompt_tokens: int 34 | completion_tokens: int 35 | total_tokens: int 36 | estimated_cost: Optional[float] = None 37 | 38 | 39 | class Choice(BaseModel): 40 | """Response choice.""" 41 | 42 | index: int 43 | message: Message 44 | finish_reason: Optional[str] = None 45 | logprobs: Optional[Dict[str, Any]] = None 46 | 47 | 48 | class ModelInfo(BaseModel): 49 | """Information about a model.""" 50 | 51 | name: str 52 | provider: str 53 | max_tokens: Optional[int] = None 54 | max_context_length: Optional[int] = None 55 | supports_functions: bool = False 56 | supports_vision: bool = False 57 | 58 | cost_per_1k_prompt_tokens: Optional[float] = None 59 | cost_per_1k_completion_tokens: Optional[float] = None 60 | latency_ms_per_token: Optional[float] = None 61 | tags: List[str] = Field(default_factory=list) 62 | 63 | 64 | class ProviderConfig(BaseModel): 65 | """Configuration for a provider instance. 66 | 67 | Unified configuration model that combines settings from config files 68 | and runtime provider requirements. Supports both application config 69 | fields and provider-specific settings. 
70 | """ 71 | 72 | model_config = ConfigDict(extra="allow") 73 | 74 | name: str 75 | api_key: Optional[str] = None 76 | enabled: bool = True 77 | api_base: Optional[str] = None 78 | base_url: Optional[str] = None 79 | api_version: Optional[str] = None 80 | organization: Optional[str] = None 81 | timeout: Optional[float] = None 82 | max_retries: int = 3 83 | retry_delay: float = 1.0 84 | rate_limit: Optional[int] = None 85 | headers: Dict[str, str] = Field(default_factory=dict) 86 | deployment_mapping: Dict[str, str] = Field(default_factory=dict) 87 | 88 | # Tool-related configuration 89 | native_tools: Optional[Dict[str, Any]] = None 90 | """Configuration for provider-native tools (e.g., Google Search for Gemini).""" 91 | 92 | def model_post_init(self, __context: Any) -> None: 93 | """Handle alternative field names and normalize configuration.""" 94 | if self.base_url and not self.api_base: 95 | self.api_base = self.base_url 96 | elif self.api_base and not self.base_url: 97 | self.base_url = self.api_base 98 | -------------------------------------------------------------------------------- /justllms/providers/deepseek.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | from justllms.core.base import BaseResponse 4 | from justllms.core.models import ModelInfo 5 | from justllms.core.openai_base import BaseOpenAIChatProvider 6 | 7 | 8 | class DeepSeekResponse(BaseResponse): 9 | pass 10 | 11 | 12 | class DeepSeekProvider(BaseOpenAIChatProvider): 13 | """DeepSeek provider implementation.""" 14 | 15 | MODELS = { 16 | "deepseek-chat": ModelInfo( 17 | name="deepseek-chat", 18 | provider="deepseek", 19 | max_tokens=8192, 20 | max_context_length=65536, 21 | supports_functions=True, 22 | supports_vision=False, 23 | cost_per_1k_prompt_tokens=0.27, 24 | cost_per_1k_completion_tokens=1.10, 25 | tags=["chat", "general-purpose", "json-output", "function-calling"], 26 | ), 27 | "deepseek-chat-cached": ModelInfo( 28 | name="deepseek-chat", 29 | provider="deepseek", 30 | max_tokens=8192, 31 | max_context_length=65536, 32 | supports_functions=True, 33 | supports_vision=False, 34 | cost_per_1k_prompt_tokens=0.07, 35 | cost_per_1k_completion_tokens=1.10, 36 | tags=["chat", "cached", "discount", "general-purpose"], 37 | ), 38 | "deepseek-reasoner": ModelInfo( 39 | name="deepseek-reasoner", 40 | provider="deepseek", 41 | max_tokens=65536, 42 | max_context_length=65536, 43 | supports_functions=True, 44 | supports_vision=False, 45 | cost_per_1k_prompt_tokens=0.55, 46 | cost_per_1k_completion_tokens=2.19, 47 | tags=["reasoning", "analysis", "complex-tasks", "json-output", "advanced"], 48 | ), 49 | "deepseek-reasoner-cached": ModelInfo( 50 | name="deepseek-reasoner", 51 | provider="deepseek", 52 | max_tokens=65536, 53 | max_context_length=65536, 54 | supports_functions=True, 55 | supports_vision=False, 56 | cost_per_1k_prompt_tokens=0.14, 57 | cost_per_1k_completion_tokens=2.19, 58 | tags=["reasoning", "cached", "discount", "advanced"], 59 | ), 60 | } 61 | 62 | @property 63 | def name(self) -> str: 64 | return "deepseek" 65 | 66 | def get_available_models(self) -> Dict[str, ModelInfo]: 67 | return self.MODELS.copy() 68 | 69 | def _get_api_endpoint(self) -> str: 70 | """Get DeepSeek chat completions endpoint.""" 71 | base_url = self.config.api_base or "https://api.deepseek.com" 72 | return f"{base_url}/chat/completions" 73 | 74 | def _get_request_headers(self) -> Dict[str, str]: 75 | """Generate HTTP headers for DeepSeek API requests.""" 76 | # Start 
with base headers if they exist 77 | headers = {} 78 | headers.update( 79 | { 80 | "Authorization": f"Bearer {self.config.api_key}", 81 | "Content-Type": "application/json", 82 | } 83 | ) 84 | # Add any custom headers from config 85 | if self.config.headers: 86 | headers.update(self.config.headers) 87 | return headers 88 | -------------------------------------------------------------------------------- /.github/labeler.yml: -------------------------------------------------------------------------------- 1 | # Configuration for GitHub labeler action 2 | # This file defines which labels to apply based on file paths 3 | 4 | # Documentation changes 5 | documentation: 6 | - changed-files: 7 | - any-glob-to-any-file: 8 | - '**/*.md' 9 | - 'docs/**' 10 | - 'README.md' 11 | - 'CHANGELOG.md' 12 | - 'CONTRIBUTING.md' 13 | 14 | # Core functionality 15 | core: 16 | - changed-files: 17 | - any-glob-to-any-file: 18 | - 'justllms/core/**' 19 | - 'justllms/__init__.py' 20 | - 'justllms/__version__.py' 21 | 22 | # Test changes 23 | tests: 24 | - changed-files: 25 | - any-glob-to-any-file: 26 | - 'tests/**' 27 | - '**/*test*.py' 28 | - '**/*_test.py' 29 | - '**/test_*.py' 30 | - 'pytest.ini' 31 | - '.coveragerc' 32 | 33 | # Dependencies 34 | dependencies: 35 | - changed-files: 36 | - any-glob-to-any-file: 37 | - 'pyproject.toml' 38 | - 'requirements*.txt' 39 | - 'setup.py' 40 | - 'setup.cfg' 41 | - 'Pipfile' 42 | - 'Pipfile.lock' 43 | - 'poetry.lock' 44 | 45 | # CI/CD changes 46 | ci: 47 | - changed-files: 48 | - any-glob-to-any-file: 49 | - '.github/workflows/**' 50 | - '.github/actions/**' 51 | - '.github/*.yml' 52 | - '.github/*.yaml' 53 | 54 | # Configuration changes 55 | config: 56 | - changed-files: 57 | - any-glob-to-any-file: 58 | - '*.json' 59 | - '*.yaml' 60 | - '*.yml' 61 | - '*.toml' 62 | - '*.ini' 63 | - '*.cfg' 64 | - '.env*' 65 | - '.*rc' 66 | - '.*config.*' 67 | - '.github/**/*.yml' 68 | - '.github/**/*.yaml' 69 | 70 | # Provider-specific changes 71 | provider: 72 | - changed-files: 73 | - any-glob-to-any-file: 74 | - 'justllms/providers/**' 75 | - 'justllms/core/providers/**' 76 | 77 | # LLM/Model changes 78 | llm: 79 | - changed-files: 80 | - any-glob-to-any-file: 81 | - 'justllms/models/**' 82 | - 'justllms/core/models/**' 83 | - '**/model_*.py' 84 | - '**/*_model.py' 85 | 86 | # Routing changes 87 | routing: 88 | - changed-files: 89 | - any-glob-to-any-file: 90 | - 'justllms/routing/**' 91 | - 'justllms/core/routing/**' 92 | - '**/router*.py' 93 | 94 | # Analytics and monitoring changes 95 | analytics: 96 | - changed-files: 97 | - any-glob-to-any-file: 98 | - 'justllms/analytics/**' 99 | - 'justllms/monitoring/**' 100 | - '**/analytics*.py' 101 | - '**/dashboard*.py' 102 | - '**/metrics*.py' 103 | 104 | # Conversation management changes 105 | conversations: 106 | - changed-files: 107 | - any-glob-to-any-file: 108 | - 'justllms/conversations/**' 109 | - '**/conversation*.py' 110 | - '**/storage*.py' 111 | 112 | # Cache-related changes 113 | caching: 114 | - changed-files: 115 | - any-glob-to-any-file: 116 | - 'justllms/cache/**' 117 | - 'justllms/core/cache/**' 118 | - '**/cache*.py' 119 | 120 | # Examples 121 | examples: 122 | - changed-files: 123 | - any-glob-to-any-file: 124 | - 'examples/**' 125 | - 'notebooks/**' 126 | - '*.ipynb' 127 | 128 | # Security-related changes 129 | security: 130 | - changed-files: 131 | - any-glob-to-any-file: 132 | - '**/auth*.py' 133 | - '**/security*.py' 134 | - '**/token*.py' 135 | - '**/crypt*.py' 136 | - '.github/dependabot.yml' 
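--------------------------------------------------------------------------------

Looking back at justllms/providers/deepseek.py above: each ModelInfo entry carries per-1k-token pricing. A sketch of how those fields combine into a request-cost estimate — estimate_cost is a hypothetical helper written for illustration, not a function the library exposes:

```python
# Sketch only: derives a cost from the MODELS table shown in deepseek.py.
from justllms.providers.deepseek import DeepSeekProvider

def estimate_cost(model_key: str, prompt_tokens: int, completion_tokens: int) -> float:
    """Hypothetical helper: price a request from static ModelInfo pricing."""
    info = DeepSeekProvider.MODELS[model_key]
    prompt_cost = (prompt_tokens / 1000) * (info.cost_per_1k_prompt_tokens or 0.0)
    completion_cost = (completion_tokens / 1000) * (info.cost_per_1k_completion_tokens or 0.0)
    return prompt_cost + completion_cost

# e.g. 2,000 prompt tokens + 500 completion tokens on deepseek-chat:
print(f"${estimate_cost('deepseek-chat', 2000, 500):.4f}")
```

The same pattern would apply to any provider, since MODELS and the cost_per_1k_* fields come from the shared ModelInfo model in justllms/core/models.py.
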
-------------------------------------------------------------------------------- /justllms/tools/google.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional 2 | 3 | 4 | class GoogleNativeTool: 5 | """Base class for Google native tools.""" 6 | 7 | def __init__(self, config: Optional[Dict[str, Any]] = None) -> None: 8 | """Initialize a Google native tool. 9 | 10 | Args: 11 | config: Optional configuration for the tool. 12 | """ 13 | self.config = config or {} 14 | self._is_native = True 15 | self._provider = "google" 16 | 17 | def is_native_tool(self) -> bool: 18 | """Check if this is a native tool.""" 19 | return True 20 | 21 | def get_provider(self) -> str: 22 | """Get the provider this tool belongs to.""" 23 | return "google" 24 | 25 | def to_api_format(self) -> Dict[str, Any]: 26 | """Convert to Gemini API format. 27 | 28 | Should be overridden by subclasses. 29 | """ 30 | raise NotImplementedError 31 | 32 | 33 | class GoogleSearch(GoogleNativeTool): 34 | """Google Search native tool for Gemini. 35 | 36 | Enables server-side web search capabilities. 37 | 38 | Example: 39 | from justllms.tools.google import GoogleSearch 40 | 41 | response = client.completion.create( 42 | messages=[{"role": "user", "content": "What are the latest AI developments?"}], 43 | tools=[GoogleSearch()], 44 | provider="google" 45 | ) 46 | """ 47 | 48 | def __init__(self, config: Optional[Dict[str, Any]] = None) -> None: 49 | """Initialize Google Search tool. 50 | 51 | Args: 52 | config: Optional configuration including: 53 | - dynamic_retrieval_config: Configuration for dynamic retrieval 54 | - mode: "MODE_DYNAMIC" or "MODE_UNSPECIFIED" 55 | - dynamic_threshold: Threshold for retrieval (0.0-1.0) 56 | """ 57 | super().__init__(config) 58 | self.name = "google_search" 59 | 60 | def to_api_format(self) -> Dict[str, Any]: 61 | """Convert to Gemini API format. 62 | 63 | Returns: 64 | Dict in format: {"google_search": {}} 65 | or {"google_search": {"dynamic_retrieval_config": {...}}} 66 | """ 67 | if self.config: 68 | return {"google_search": self.config} 69 | return {"google_search": {}} 70 | 71 | 72 | class GoogleCodeExecution(GoogleNativeTool): 73 | """Google Code Execution native tool for Gemini. 74 | 75 | Enables server-side Python code execution in a sandbox. 76 | 77 | Example: 78 | from justllms.tools.google import GoogleCodeExecution 79 | 80 | response = client.completion.create( 81 | messages=[{"role": "user", "content": "Calculate fibonacci sequence up to 100"}], 82 | tools=[GoogleCodeExecution()], 83 | provider="google" 84 | ) 85 | """ 86 | 87 | def __init__(self, config: Optional[Dict[str, Any]] = None) -> None: 88 | """Initialize Code Execution tool. 89 | 90 | Args: 91 | config: Optional configuration (currently unused for code execution). 92 | """ 93 | super().__init__(config) 94 | self.name = "code_execution" 95 | 96 | def to_api_format(self) -> Dict[str, Any]: 97 | """Convert to Gemini API format. 98 | 99 | Returns: 100 | Dict in format: {"code_execution": {}} 101 | """ 102 | if self.config: 103 | return {"code_execution": self.config} 104 | return {"code_execution": {}} 105 | -------------------------------------------------------------------------------- /justllms/exceptions/exceptions.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional 2 | 3 | 4 | class JustLLMsError(Exception): 5 | """Base exception for all JustLLMs library errors. 
6 | 7 | Provides common error handling patterns with structured error information 8 | including error codes and additional context details. 9 | 10 | Args: 11 | message: Human-readable error description. 12 | code: Optional error code for programmatic error handling. 13 | details: Optional dictionary with additional error context. 14 | """ 15 | 16 | def __init__( 17 | self, 18 | message: str, 19 | code: Optional[str] = None, 20 | details: Optional[Dict[str, Any]] = None, 21 | ): 22 | super().__init__(message) 23 | self.message = message 24 | self.code = code 25 | self.details = details or {} 26 | 27 | 28 | class ProviderError(JustLLMsError): 29 | """Error originating from an LLM provider API. 30 | 31 | Represents failures in communication with or responses from LLM provider 32 | APIs, including HTTP errors, API errors, and malformed responses. 33 | 34 | Args: 35 | message: Error description from provider or library. 36 | provider: Name of the provider that generated the error. 37 | status_code: HTTP status code if applicable. 38 | response_body: Raw response body from the provider API. 39 | **kwargs: Additional arguments passed to parent JustLLMsError. 40 | """ 41 | 42 | def __init__( 43 | self, 44 | message: str, 45 | provider: Optional[str] = None, 46 | status_code: Optional[int] = None, 47 | response_body: Optional[str] = None, 48 | **kwargs: Any, 49 | ): 50 | super().__init__(message, **kwargs) 51 | self.provider = provider 52 | self.status_code = status_code 53 | self.response_body = response_body 54 | 55 | 56 | class ValidationError(JustLLMsError): 57 | """Error during input validation.""" 58 | 59 | def __init__( 60 | self, 61 | message: str, 62 | field: Optional[str] = None, 63 | value: Any = None, 64 | **kwargs: Any, 65 | ): 66 | super().__init__(message, **kwargs) 67 | self.field = field 68 | self.value = value 69 | 70 | 71 | class RateLimitError(ProviderError): 72 | """Rate limit exceeded error.""" 73 | 74 | def __init__( 75 | self, 76 | message: str, 77 | retry_after: Optional[int] = None, 78 | **kwargs: Any, 79 | ): 80 | super().__init__(message, **kwargs) 81 | self.retry_after = retry_after 82 | 83 | 84 | class TimeoutError(ProviderError): 85 | """Request timeout error.""" 86 | 87 | def __init__( 88 | self, 89 | message: str, 90 | timeout_seconds: Optional[float] = None, 91 | **kwargs: Any, 92 | ): 93 | super().__init__(message, **kwargs) 94 | self.timeout_seconds = timeout_seconds 95 | 96 | 97 | class AuthenticationError(ProviderError): 98 | """Authentication/authorization error.""" 99 | 100 | def __init__( 101 | self, 102 | message: str, 103 | required_auth: Optional[str] = None, 104 | **kwargs: Any, 105 | ): 106 | super().__init__(message, **kwargs) 107 | self.required_auth = required_auth 108 | 109 | 110 | class ConfigurationError(JustLLMsError): 111 | """Configuration error.""" 112 | 113 | def __init__( 114 | self, 115 | message: str, 116 | config_key: Optional[str] = None, 117 | config_value: Any = None, 118 | **kwargs: Any, 119 | ): 120 | super().__init__(message, **kwargs) 121 | self.config_key = config_key 122 | self.config_value = config_value 123 | -------------------------------------------------------------------------------- /justllms/config/config.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import json 3 | import os 4 | from pathlib import Path 5 | from typing import Any, Dict, Optional, Union 6 | 7 | import yaml 8 | from pydantic import BaseModel, ConfigDict, Field 9 | 10 | 11 | class 
ConfigProviderSettings(BaseModel): 12 | model_config = ConfigDict(extra="allow") 13 | name: str 14 | api_key: Optional[str] = None 15 | enabled: bool = True 16 | base_url: Optional[str] = None 17 | timeout: Optional[int] = None 18 | max_retries: int = 3 19 | rate_limit: Optional[int] = None 20 | deployment_mapping: Dict[str, str] = Field(default_factory=dict) 21 | 22 | 23 | class RoutingConfig(BaseModel): 24 | """Configuration for provider and model fallbacks.""" 25 | 26 | model_config = ConfigDict(extra="allow") 27 | 28 | fallback_provider: Optional[str] = None 29 | fallback_model: Optional[str] = None 30 | 31 | """max execution time per tool""" 32 | tool_timeout: float = 120.0 33 | 34 | """max number of tool execution rounds""" 35 | max_tool_iterations: int = 10 36 | 37 | """whether to automatically execute tools by default""" 38 | execute_tools_by_default: bool = True 39 | 40 | 41 | class Config(BaseModel): 42 | """Configuration class for multi-provider LLM client.""" 43 | 44 | model_config = ConfigDict(extra="allow") 45 | 46 | providers: Dict[str, Dict[str, Any]] = Field(default_factory=dict) 47 | routing: RoutingConfig = Field(default_factory=RoutingConfig) 48 | 49 | @classmethod 50 | def from_file(cls, path: Union[str, Path]) -> "Config": 51 | """Load configuration from a file.""" 52 | path = Path(path) 53 | 54 | if not path.exists(): 55 | raise FileNotFoundError(f"Configuration file not found: {path}") 56 | 57 | with open(path) as f: 58 | if path.suffix in [".yaml", ".yml"]: 59 | data = yaml.safe_load(f) 60 | elif path.suffix == ".json": 61 | data = json.load(f) 62 | else: 63 | raise ValueError(f"Unsupported configuration file format: {path.suffix}") 64 | 65 | return cls(**data) 66 | 67 | @classmethod 68 | def from_env(cls) -> "Config": 69 | """Create configuration from environment variables.""" 70 | providers = {} 71 | 72 | # Common provider environment variables 73 | provider_keys = { 74 | "openai": "OPENAI_API_KEY", 75 | "anthropic": "ANTHROPIC_API_KEY", 76 | "google": "GOOGLE_API_KEY", 77 | "azure_openai": "AZURE_OPENAI_API_KEY", 78 | "deepseek": "DEEPSEEK_API_KEY", 79 | "grok": ("XAI_API_KEY", "GROK_API_KEY"), # Support both for backwards compatibility 80 | } 81 | 82 | for provider_name, env_key in provider_keys.items(): 83 | if isinstance(env_key, tuple): 84 | api_key = None 85 | for key in env_key: 86 | api_key = os.getenv(key) 87 | if api_key: 88 | break 89 | else: 90 | # Type guard: env_key is a string here 91 | api_key = os.getenv(env_key) # type: ignore[call-overload] 92 | 93 | if api_key: 94 | providers[provider_name] = {"api_key": api_key} 95 | 96 | ollama_base = os.getenv("OLLAMA_API_BASE") or os.getenv("OLLAMA_HOST") 97 | ollama_enabled = os.getenv("OLLAMA_ENABLED", "").lower() in {"1", "true", "yes"} 98 | if ollama_base or ollama_enabled: 99 | provider_entry: Dict[str, Any] = {"enabled": True} 100 | if ollama_base: 101 | provider_entry["base_url"] = ollama_base 102 | 103 | headers_json = os.getenv("OLLAMA_HEADERS_JSON") 104 | if headers_json: 105 | with contextlib.suppress(json.JSONDecodeError): 106 | provider_entry["headers"] = json.loads(headers_json) 107 | 108 | providers["ollama"] = provider_entry 109 | 110 | return cls(providers=providers, routing=RoutingConfig()) 111 | 112 | 113 | def load_config( 114 | config_path: Optional[str] = None, 115 | use_defaults: bool = True, 116 | use_env: bool = True, 117 | ) -> Config: 118 | """Load configuration from various sources.""" 119 | if config_path: 120 | return Config.from_file(config_path) 121 | 122 | # Try to find 
config file 123 | config_files = ["justllms.yaml", "justllms.yml", "justllms.json"] 124 | for config_file in config_files: 125 | if Path(config_file).exists(): 126 | return Config.from_file(config_file) 127 | 128 | if use_env: 129 | return Config.from_env() 130 | 131 | if use_defaults: 132 | return Config() 133 | 134 | raise FileNotFoundError("No configuration file found and environment variables not available") 135 | -------------------------------------------------------------------------------- /justllms/sxs/core/executor.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | from concurrent.futures import ThreadPoolExecutor, as_completed 4 | from typing import Any, Callable, Dict, List, Optional, Tuple 5 | 6 | from justllms.sxs.models import ModelResponse, ResponseStatus 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | class ParallelExecutor: 12 | """Execute multiple model calls in parallel.""" 13 | 14 | def __init__(self, client: Any) -> None: 15 | """Initialize the executor. 16 | 17 | Args: 18 | client: JustLLM client instance 19 | """ 20 | self.client = client 21 | 22 | def execute_comparison( 23 | self, 24 | prompt: str, 25 | models: List[Tuple[str, str]], 26 | on_model_complete: Optional[Callable] = None, 27 | temperature: float = 0.7, 28 | max_tokens: Optional[int] = None, 29 | ) -> Dict[str, ModelResponse]: 30 | """Execute all models in parallel. 31 | 32 | Args: 33 | prompt: The prompt to send to all models 34 | models: List of (provider, model) tuples 35 | on_model_complete: Callback when a model completes 36 | temperature: Temperature for generation 37 | max_tokens: Maximum tokens to generate 38 | 39 | Returns: 40 | Dictionary mapping model_id to ModelResponse 41 | """ 42 | results: Dict[str, ModelResponse] = {} 43 | 44 | if not models: 45 | return results 46 | 47 | def call_model(provider: str, model: str) -> Tuple[str, ModelResponse]: 48 | """Call a single model.""" 49 | model_id = f"{provider}/{model}" 50 | start = time.time() 51 | 52 | try: 53 | # Call the model 54 | response = self.client.completion.create( 55 | messages=[{"role": "user", "content": prompt}], 56 | provider=provider, 57 | model=model, 58 | temperature=temperature, 59 | max_tokens=max_tokens, 60 | ) 61 | 62 | # Create successful response 63 | result = ModelResponse( 64 | provider=provider, 65 | model=model, 66 | content=response.content, 67 | status=ResponseStatus.COMPLETED, 68 | latency=time.time() - start, 69 | tokens=response.usage.total_tokens if response.usage else 0, 70 | cost=response.usage.estimated_cost if response.usage else 0.0, 71 | ) 72 | 73 | except Exception as e: 74 | # Create error response 75 | logger.error(f"Error calling {model_id}: {e}") 76 | result = ModelResponse( 77 | provider=provider, 78 | model=model, 79 | content="", 80 | status=ResponseStatus.ERROR, 81 | latency=time.time() - start, 82 | tokens=0, 83 | cost=0.0, 84 | error=str(e), 85 | ) 86 | 87 | # Call callback if provided 88 | if on_model_complete: 89 | try: 90 | on_model_complete(model_id, result) 91 | except Exception as e: 92 | logger.error(f"Error in callback for {model_id}: {e}") 93 | 94 | return model_id, result 95 | 96 | with ThreadPoolExecutor(max_workers=min(len(models), 10)) as executor: 97 | futures = [executor.submit(call_model, provider, model) for provider, model in models] 98 | 99 | try: 100 | for future in as_completed(futures): 101 | try: 102 | model_id, result = future.result() 103 | results[model_id] = result 104 | except 
Exception as e: 105 | logger.error(f"Error processing future: {e}") 106 | finally: 107 | for future in futures: 108 | if not future.done(): 109 | future.cancel() 110 | 111 | for provider, model in models: 112 | model_id = f"{provider}/{model}" 113 | if model_id not in results: 114 | results[model_id] = ModelResponse( 115 | provider=provider, 116 | model=model, 117 | content="", 118 | status=ResponseStatus.ERROR, 119 | latency=0.0, 120 | tokens=0, 121 | cost=0.0, 122 | error="Request failed or was cancelled", 123 | ) 124 | 125 | return results 126 | -------------------------------------------------------------------------------- /justllms/utils/validators.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Union 2 | 3 | from justllms.core.models import Message, Role 4 | from justllms.exceptions import ValidationError 5 | 6 | 7 | def validate_messages( # noqa: C901 8 | messages: Union[List[Dict[str, Any]], List[Message]], 9 | ) -> List[Message]: 10 | """Validate and normalize message inputs for LLM requests. 11 | 12 | Performs comprehensive validation of message structure, content, roles, 13 | and conversation flow to ensure compatibility with provider APIs. 14 | Converts dictionary inputs to Message objects. 15 | 16 | Args: 17 | messages: List of messages as dictionaries or Message objects. 18 | Each message must have 'role' and 'content' fields. 19 | 20 | Returns: 21 | List[Message]: Validated and normalized Message objects ready for 22 | provider consumption. 23 | 24 | Raises: 25 | ValidationError: If messages are invalid, empty, malformed, or missing 26 | required fields. Includes specific error descriptions 27 | and field references for debugging. 28 | """ 29 | if not messages: 30 | raise ValidationError("Messages list cannot be empty") 31 | 32 | if not isinstance(messages, list): 33 | raise ValidationError("Messages must be a list") 34 | 35 | validated_messages = [] 36 | 37 | for i, msg in enumerate(messages): 38 | if isinstance(msg, Message): 39 | validated_messages.append(msg) 40 | elif isinstance(msg, dict): 41 | # Validate required fields 42 | if "role" not in msg: 43 | raise ValidationError(f"Message {i} missing required field 'role'") 44 | 45 | if "content" not in msg: 46 | raise ValidationError(f"Message {i} missing required field 'content'") 47 | 48 | # Validate role 49 | role = msg["role"] 50 | if isinstance(role, str): 51 | try: 52 | role = Role(role.lower()) 53 | except ValueError as e: 54 | valid_roles = [r.value for r in Role] 55 | raise ValidationError( 56 | f"Message {i} has invalid role '{role}'. 
" 57 | f"Valid roles are: {', '.join(valid_roles)}" 58 | ) from e 59 | elif not isinstance(role, Role): 60 | raise ValidationError(f"Message {i} role must be a string or Role enum") 61 | 62 | # Validate content 63 | content = msg["content"] 64 | if not isinstance(content, (str, list)): 65 | raise ValidationError(f"Message {i} content must be a string or list") 66 | 67 | if isinstance(content, str) and not content.strip(): 68 | raise ValidationError(f"Message {i} content cannot be empty") 69 | 70 | if isinstance(content, list): 71 | if not content: 72 | raise ValidationError(f"Message {i} content list cannot be empty") 73 | 74 | # Validate multimodal content 75 | for j, item in enumerate(content): 76 | if not isinstance(item, dict): 77 | raise ValidationError(f"Message {i} content item {j} must be a dictionary") 78 | 79 | if "type" not in item: 80 | raise ValidationError(f"Message {i} content item {j} missing 'type' field") 81 | 82 | item_type = item["type"] 83 | if item_type == "text": 84 | if "text" not in item: 85 | raise ValidationError( 86 | f"Message {i} content item {j} of type 'text' missing 'text' field" 87 | ) 88 | elif item_type == "image": 89 | if "image" not in item and "image_url" not in item: 90 | raise ValidationError( 91 | f"Message {i} content item {j} of type 'image' " 92 | "missing 'image' or 'image_url' field" 93 | ) 94 | else: 95 | # Allow other types but don't validate 96 | pass 97 | 98 | # Create Message object 99 | try: 100 | validated_messages.append(Message(**msg)) 101 | except Exception as e: 102 | raise ValidationError(f"Message {i} validation failed: {str(e)}") from e 103 | else: 104 | raise ValidationError(f"Message {i} must be a dict or Message object, got {type(msg)}") 105 | 106 | # Additional validations 107 | if not any(msg.role == Role.USER for msg in validated_messages): 108 | raise ValidationError("Messages must contain at least one user message") 109 | 110 | # Check message order (system messages should be first) 111 | system_indices = [i for i, msg in enumerate(validated_messages) if msg.role == Role.SYSTEM] 112 | 113 | if system_indices and any(i > 0 for i in system_indices): 114 | # Allow system messages after position 0 but warn 115 | pass 116 | 117 | return validated_messages 118 | -------------------------------------------------------------------------------- /.github/workflows/pr-labeler.yml: -------------------------------------------------------------------------------- 1 | name: PR Labeler 2 | 3 | on: 4 | pull_request: 5 | types: [opened, edited, synchronize] 6 | 7 | permissions: 8 | contents: read 9 | pull-requests: write 10 | issues: write 11 | 12 | jobs: 13 | label: 14 | name: Auto Label PR 15 | runs-on: ubuntu-latest 16 | 17 | steps: 18 | - uses: actions/checkout@v4 19 | 20 | - name: Label PR based on files changed 21 | uses: actions/labeler@v5 22 | with: 23 | repo-token: "${{ secrets.GITHUB_TOKEN }}" 24 | configuration-path: .github/labeler.yml 25 | sync-labels: false 26 | 27 | - name: Label PR based on title 28 | uses: actions/github-script@v7 29 | with: 30 | github-token: ${{ secrets.GITHUB_TOKEN }} 31 | script: | 32 | const title = context.payload.pull_request.title.toLowerCase(); 33 | const labels = []; 34 | 35 | // Add labels based on PR title keywords 36 | if (title.includes('feat:') || title.includes('feature')) { 37 | labels.push('enhancement'); 38 | } 39 | if (title.includes('fix:') || title.includes('bug')) { 40 | labels.push('bug'); 41 | } 42 | if (title.includes('docs:') || title.includes('documentation')) { 43 | 
labels.push('documentation'); 44 | } 45 | if (title.includes('test:') || title.includes('tests')) { 46 | labels.push('tests'); 47 | } 48 | if (title.includes('chore:') || title.includes('ci:')) { 49 | labels.push('chore'); 50 | } 51 | if (title.includes('refactor:')) { 52 | labels.push('refactor'); 53 | } 54 | if (title.includes('perf:') || title.includes('performance')) { 55 | labels.push('performance'); 56 | } 57 | if (title.includes('breaking') || title.includes('!:')) { 58 | labels.push('breaking-change'); 59 | } 60 | if (title.includes('deps:') || title.includes('dependencies')) { 61 | labels.push('dependencies'); 62 | } 63 | 64 | // Add labels based on PR body content 65 | const body = context.payload.pull_request.body?.toLowerCase() || ''; 66 | 67 | if (body.includes('llm') || body.includes('model')) { 68 | labels.push('llm'); 69 | } 70 | if (body.includes('provider')) { 71 | labels.push('provider'); 72 | } 73 | if (body.includes('routing')) { 74 | labels.push('routing'); 75 | } 76 | if (body.includes('cache') || body.includes('caching')) { 77 | labels.push('caching'); 78 | } 79 | if (body.includes('analytics') || body.includes('dashboard')) { 80 | labels.push('analytics'); 81 | } 82 | 83 | // Apply the labels if any were identified 84 | if (labels.length > 0) { 85 | await github.rest.issues.addLabels({ 86 | owner: context.repo.owner, 87 | repo: context.repo.repo, 88 | issue_number: context.issue.number, 89 | labels: labels 90 | }); 91 | } 92 | 93 | - name: Label PR size 94 | uses: actions/github-script@v7 95 | with: 96 | github-token: ${{ secrets.GITHUB_TOKEN }} 97 | script: | 98 | const pr = context.payload.pull_request; 99 | const additions = pr.additions; 100 | const deletions = pr.deletions; 101 | const changes = additions + deletions; 102 | 103 | let sizeLabel = ''; 104 | 105 | if (changes < 10) { 106 | sizeLabel = 'size/XS'; 107 | } else if (changes < 50) { 108 | sizeLabel = 'size/S'; 109 | } else if (changes < 200) { 110 | sizeLabel = 'size/M'; 111 | } else if (changes < 500) { 112 | sizeLabel = 'size/L'; 113 | } else if (changes < 1000) { 114 | sizeLabel = 'size/XL'; 115 | } else { 116 | sizeLabel = 'size/XXL'; 117 | } 118 | 119 | // Remove existing size labels 120 | const { data: currentLabels } = await github.rest.issues.listLabelsOnIssue({ 121 | owner: context.repo.owner, 122 | repo: context.repo.repo, 123 | issue_number: context.issue.number 124 | }); 125 | 126 | const sizeLabels = currentLabels 127 | .filter(label => label.name.startsWith('size/')) 128 | .map(label => label.name); 129 | 130 | if (sizeLabels.length > 0) { 131 | for (const label of sizeLabels) { 132 | await github.rest.issues.removeLabel({ 133 | owner: context.repo.owner, 134 | repo: context.repo.repo, 135 | issue_number: context.issue.number, 136 | name: label 137 | }); 138 | } 139 | } 140 | 141 | // Add the new size label 142 | await github.rest.issues.addLabels({ 143 | owner: context.repo.owner, 144 | repo: context.repo.repo, 145 | issue_number: context.issue.number, 146 | labels: [sizeLabel] 147 | }); -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | push: 5 | tags: 6 | - 'v*' 7 | 8 | permissions: 9 | contents: write 10 | id-token: write 11 | 12 | jobs: 13 | build: 14 | name: Build Distribution 15 | runs-on: ubuntu-latest 16 | 17 | steps: 18 | - uses: actions/checkout@v4 19 | 20 | - name: Set up Python 21 | uses: 
actions/setup-python@v5 22 | with: 23 | python-version: '3.11' 24 | 25 | - name: Install build dependencies 26 | run: | 27 | python -m pip install --upgrade pip 28 | pip install build twine 29 | 30 | - name: Build package 31 | run: python -m build 32 | 33 | - name: Check distribution 34 | run: twine check dist/* 35 | 36 | - name: Store distributions 37 | uses: actions/upload-artifact@v4 38 | with: 39 | name: python-package-distributions 40 | path: dist/ 41 | 42 | publish-testpypi: 43 | name: Publish to TestPyPI 44 | needs: build 45 | runs-on: ubuntu-latest 46 | environment: 47 | name: testpypi 48 | url: https://test.pypi.org/p/justllms 49 | 50 | steps: 51 | - name: Download distributions 52 | uses: actions/download-artifact@v4 53 | with: 54 | name: python-package-distributions 55 | path: dist/ 56 | 57 | - name: Publish to TestPyPI 58 | uses: pypa/gh-action-pypi-publish@release/v1 59 | with: 60 | repository-url: https://test.pypi.org/legacy/ 61 | password: ${{ secrets.TEST_PYPI_API_TOKEN }} 62 | 63 | - name: Wait for TestPyPI propagation 64 | run: sleep 180 65 | 66 | verify-testpypi: 67 | name: Verify TestPyPI Installation 68 | needs: publish-testpypi 69 | runs-on: ubuntu-latest 70 | 71 | steps: 72 | - name: Set up Python 73 | uses: actions/setup-python@v5 74 | with: 75 | python-version: '3.11' 76 | 77 | - name: Extract version 78 | id: version 79 | run: echo "VERSION=${GITHUB_REF#refs/tags/v}" >> $GITHUB_OUTPUT 80 | 81 | - name: Install from TestPyPI 82 | run: | 83 | pip install --index-url https://test.pypi.org/simple/ \ 84 | --extra-index-url https://pypi.org/simple/ \ 85 | justllms==${{ steps.version.outputs.VERSION }} 86 | 87 | - name: Verify installation 88 | run: | 89 | python -c "import justllms; print(f'Installed version: {justllms.__version__}')" 90 | justllms --help || echo "CLI check completed" 91 | 92 | - name: Test basic functionality 93 | run: | 94 | python -c "from justllms import JustLLM; print('✅ Import successful')" 95 | 96 | publish-pypi: 97 | name: Publish to PyPI 98 | needs: verify-testpypi 99 | runs-on: ubuntu-latest 100 | environment: 101 | name: pypi 102 | url: https://pypi.org/project/justllms 103 | 104 | steps: 105 | - name: Download distributions 106 | uses: actions/download-artifact@v4 107 | with: 108 | name: python-package-distributions 109 | path: dist/ 110 | 111 | - name: Publish to PyPI 112 | uses: pypa/gh-action-pypi-publish@release/v1 113 | with: 114 | password: ${{ secrets.PYPI_API_TOKEN }} 115 | 116 | github-release: 117 | name: Create GitHub Release 118 | needs: publish-pypi 119 | runs-on: ubuntu-latest 120 | permissions: 121 | contents: write 122 | 123 | steps: 124 | - uses: actions/checkout@v4 125 | 126 | - name: Extract version 127 | id: version 128 | run: echo "VERSION=${GITHUB_REF#refs/tags/v}" >> $GITHUB_OUTPUT 129 | 130 | - name: Extract tag message 131 | id: tag_message 132 | run: | 133 | TAG_MESSAGE=$(git tag -l --format='%(contents)' ${GITHUB_REF#refs/tags/}) 134 | echo "MESSAGE<> $GITHUB_OUTPUT 135 | echo "$TAG_MESSAGE" >> $GITHUB_OUTPUT 136 | echo "EOF" >> $GITHUB_OUTPUT 137 | 138 | - name: Download distributions 139 | uses: actions/download-artifact@v4 140 | with: 141 | name: python-package-distributions 142 | path: dist/ 143 | 144 | - name: Create Release 145 | uses: softprops/action-gh-release@v1 146 | with: 147 | name: v${{ steps.version.outputs.VERSION }} 148 | body: | 149 | ## JustLLMs v${{ steps.version.outputs.VERSION }} 150 | 151 | ${{ steps.tag_message.outputs.MESSAGE }} 152 | 153 | ### Installation 154 | ```bash 155 | pip install 
justllms==${{ steps.version.outputs.VERSION }} 156 | ``` 157 | 158 | --- 159 | *View on PyPI: https://pypi.org/project/justllms/${{ steps.version.outputs.VERSION }}/* 160 | files: dist/* 161 | draft: false 162 | prerelease: false 163 | generate_release_notes: true 164 | 165 | verify-release: 166 | name: Verify Release 167 | needs: publish-pypi 168 | runs-on: ubuntu-latest 169 | 170 | steps: 171 | - name: Wait for PyPI availability 172 | run: sleep 180 173 | 174 | - name: Extract version 175 | id: version 176 | run: echo "VERSION=${GITHUB_REF#refs/tags/v}" >> $GITHUB_OUTPUT 177 | 178 | - name: Test installation from PyPI 179 | run: | 180 | python -m pip install --upgrade pip 181 | pip install justllms==${{ steps.version.outputs.VERSION }} 182 | python -c "import justllms; assert justllms.__version__ == '${{ steps.version.outputs.VERSION }}'" 183 | echo "✅ Release verified successfully!" 184 | -------------------------------------------------------------------------------- /justllms/tools/adapters/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Any, Dict, List, Optional, Union 3 | 4 | from justllms.core.models import Message, Role 5 | from justllms.tools.models import Tool, ToolCall, ToolResult 6 | 7 | 8 | class BaseToolAdapter(ABC): 9 | """Abstract base class for provider-specific tool format conversion. 10 | 11 | Each provider (OpenAI, Anthropic, Google, etc.) has its own format 12 | for tool/function definitions and tool call responses. This adapter 13 | provides a unified interface for conversion between JustLLMs Tool 14 | objects and provider-specific formats. 15 | """ 16 | 17 | @abstractmethod 18 | def format_tools_for_api(self, tools: List[Tool]) -> List[Dict[str, Any]]: 19 | """Convert Tool objects to provider's API format. 20 | 21 | Args: 22 | tools: List of Tool instances to convert. 23 | 24 | Returns: 25 | List of tool definitions in provider-specific format. 26 | 27 | Examples: 28 | OpenAI format: 29 | [{"type": "function", "function": {...}}] 30 | 31 | Anthropic format: 32 | [{"name": "...", "description": "...", "input_schema": {...}}] 33 | """ 34 | pass 35 | 36 | @abstractmethod 37 | def format_tool_choice( 38 | self, tool_choice: Optional[Union[str, Dict[str, Any]]] 39 | ) -> Optional[Any]: 40 | """Normalize tool_choice parameter for provider. 41 | 42 | Args: 43 | tool_choice: Tool selection strategy. Can be: 44 | - "auto": Let model decide 45 | - "none": Don't use tools 46 | - "required": Must use a tool (OpenAI) 47 | - Dict with specific tool name 48 | - None: Use provider default 49 | 50 | Returns: 51 | Provider-specific tool_choice format. 52 | """ 53 | pass 54 | 55 | @abstractmethod 56 | def extract_tool_calls(self, response: Dict[str, Any]) -> List[ToolCall]: 57 | """Extract tool calls from provider's response. 58 | 59 | Args: 60 | response: Raw response from provider API. 61 | 62 | Returns: 63 | List of ToolCall objects extracted from response. 64 | """ 65 | pass 66 | 67 | @abstractmethod 68 | def format_tool_result_message(self, tool_result: ToolResult, tool_call: ToolCall) -> Message: 69 | """Format tool execution result as a message. 70 | 71 | Args: 72 | tool_result: Result from tool execution. 73 | tool_call: Original tool call that produced this result. 74 | 75 | Returns: 76 | Message object formatted for the provider. 
77 | """ 78 | pass 79 | 80 | def format_tool_calls_message(self, tool_calls: List[ToolCall]) -> Optional[Message]: 81 | """Format tool calls as an assistant message. 82 | 83 | Some providers need tool calls formatted as assistant messages. 84 | 85 | Args: 86 | tool_calls: List of tool calls to include in message. 87 | 88 | Returns: 89 | Assistant message with tool calls, or None if not needed. 90 | """ 91 | # Default implementation for OpenAI-style 92 | if not tool_calls: 93 | return None 94 | 95 | tool_calls_data = [] 96 | for tc in tool_calls: 97 | tool_call_dict = { 98 | "id": tc.id, 99 | "type": "function", 100 | "function": { 101 | "name": tc.name, 102 | "arguments": tc.raw_arguments or self._serialize_arguments(tc.arguments), 103 | }, 104 | } 105 | tool_calls_data.append(tool_call_dict) 106 | 107 | return Message( 108 | role=Role.ASSISTANT, 109 | content="", # Tool calls don't have content 110 | tool_calls=tool_calls_data, 111 | ) 112 | 113 | def supports_parallel_tools(self) -> bool: 114 | """Check if provider supports parallel tool calls. 115 | 116 | Returns: 117 | True if provider can call multiple tools in one response. 118 | """ 119 | return False 120 | 121 | def supports_required_tools(self) -> bool: 122 | """Check if provider supports 'required' tool choice. 123 | 124 | Returns: 125 | True if provider supports forcing tool use. 126 | """ 127 | return False 128 | 129 | def get_max_tools_per_call(self) -> Optional[int]: 130 | """Get maximum number of tools that can be defined per call. 131 | 132 | Returns: 133 | Maximum number or None if unlimited. 134 | """ 135 | return None 136 | 137 | def _serialize_arguments(self, arguments: Dict[str, Any]) -> str: 138 | """Serialize arguments dictionary to JSON string. 139 | 140 | Args: 141 | arguments: Tool arguments dictionary. 142 | 143 | Returns: 144 | JSON string representation. 145 | """ 146 | import json 147 | 148 | return json.dumps(arguments, default=str) 149 | 150 | def _parse_arguments(self, arguments_str: str) -> Dict[str, Any]: 151 | """Parse JSON string to arguments dictionary. 152 | 153 | Args: 154 | arguments_str: JSON string of arguments. 155 | 156 | Returns: 157 | Parsed arguments dictionary. 
158 | """ 159 | import json 160 | from typing import cast 161 | 162 | try: 163 | return cast(Dict[str, Any], json.loads(arguments_str)) 164 | except json.JSONDecodeError: 165 | # Return as-is if not valid JSON 166 | return {"raw": arguments_str} 167 | -------------------------------------------------------------------------------- /justllms/tools/models.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | from dataclasses import dataclass, field 3 | from enum import Enum 4 | from typing import Any, Callable, Dict, List, Optional 5 | 6 | from pydantic import BaseModel, ConfigDict, Field 7 | 8 | 9 | class ParameterInfo(BaseModel): 10 | """Information about a tool parameter.""" 11 | 12 | name: str 13 | type: str # JSON Schema type 14 | description: Optional[str] = None 15 | required: bool = True 16 | default: Any = None 17 | enum: Optional[List[Any]] = None 18 | items: Optional[Dict[str, Any]] = None # For array types 19 | properties: Optional[Dict[str, Any]] = None # For object types 20 | 21 | 22 | class Tool(BaseModel): 23 | """Core tool representation.""" 24 | 25 | model_config = ConfigDict(arbitrary_types_allowed=True) 26 | 27 | name: str 28 | namespace: Optional[str] = None 29 | description: str 30 | callable: Callable 31 | parameters: Dict[str, ParameterInfo] = Field(default_factory=dict) 32 | parameter_descriptions: Dict[str, str] = Field(default_factory=dict) 33 | return_type: Optional[Any] = None # Can be type or typing generic 34 | metadata: Dict[str, Any] = Field(default_factory=dict) 35 | is_native: bool = False # For provider-specific native tools 36 | 37 | @property 38 | def full_name(self) -> str: 39 | """Get fully qualified name including namespace.""" 40 | if self.namespace: 41 | return f"{self.namespace}.{self.name}" 42 | return self.name 43 | 44 | def to_json_schema(self) -> Dict[str, Any]: 45 | """Convert tool to JSON Schema format for providers.""" 46 | required_params = [name for name, param in self.parameters.items() if param.required] 47 | 48 | properties = {} 49 | for param_name, param_info in self.parameters.items(): 50 | prop: Dict[str, Any] = {"type": param_info.type} 51 | 52 | # Add description from parameter_descriptions or param_info 53 | desc = self.parameter_descriptions.get(param_name) or param_info.description 54 | if desc: 55 | prop["description"] = desc 56 | 57 | if param_info.enum: 58 | prop["enum"] = param_info.enum 59 | if param_info.items: 60 | prop["items"] = param_info.items 61 | if param_info.properties: 62 | prop["properties"] = param_info.properties 63 | 64 | properties[param_name] = prop 65 | 66 | schema = { 67 | "type": "object", 68 | "properties": properties, 69 | } 70 | 71 | if required_params: 72 | schema["required"] = required_params 73 | 74 | return schema 75 | 76 | 77 | class ToolCall(BaseModel): 78 | """Represents a tool invocation request from the LLM.""" 79 | 80 | id: str = Field(default_factory=lambda: f"call_{uuid.uuid4().hex[:8]}") 81 | name: str 82 | namespace: Optional[str] = None 83 | arguments: Dict[str, Any] = Field(default_factory=dict) 84 | raw_arguments: Optional[str] = None # Original JSON string from LLM 85 | 86 | @property 87 | def full_name(self) -> str: 88 | """Get fully qualified name including namespace.""" 89 | if self.namespace: 90 | return f"{self.namespace}.{self.name}" 91 | return self.name 92 | 93 | 94 | class ToolResultStatus(str, Enum): 95 | """Status of tool execution.""" 96 | 97 | SUCCESS = "success" 98 | ERROR = "error" 99 | TIMEOUT = "timeout" 100 | 101 | 
102 | class ToolResult(BaseModel): 103 | """Tool execution result.""" 104 | 105 | tool_call_id: str 106 | result: Any = None 107 | error: Optional[str] = None 108 | execution_time_ms: float = 0.0 109 | status: ToolResultStatus = ToolResultStatus.SUCCESS 110 | cost: Optional[float] = None 111 | """Estimated cost of tool execution in USD (e.g., API call costs).""" 112 | 113 | @property 114 | def is_success(self) -> bool: 115 | """Check if execution was successful.""" 116 | return self.status == ToolResultStatus.SUCCESS 117 | 118 | def to_message_content(self) -> str: 119 | """Convert result to string for message content.""" 120 | if self.is_success: 121 | if isinstance(self.result, str): 122 | return self.result 123 | elif self.result is None: 124 | return "Tool executed successfully with no output." 125 | else: 126 | import json 127 | 128 | try: 129 | return json.dumps(self.result, default=str) 130 | except (TypeError, ValueError): 131 | return str(self.result) 132 | else: 133 | return f"Error: {self.error}" 134 | 135 | 136 | @dataclass 137 | class ToolExecutionEntry: 138 | """Single entry in the tool execution history.""" 139 | 140 | iteration: int 141 | tool_call: ToolCall 142 | tool_result: ToolResult 143 | messages: List[Dict[str, Any]] = field(default_factory=list) 144 | 145 | def to_dict(self) -> Dict[str, Any]: 146 | """Convert to dictionary for serialization.""" 147 | return { 148 | "iteration": self.iteration, 149 | "tool_call": { 150 | "id": self.tool_call.id, 151 | "name": self.tool_call.name, 152 | "namespace": self.tool_call.namespace, 153 | "arguments": self.tool_call.arguments, 154 | "status": self.tool_result.status.value, 155 | "error": self.tool_result.error, 156 | }, 157 | "execution_time_ms": self.tool_result.execution_time_ms, 158 | "result": self.tool_result.result if self.tool_result.is_success else None, 159 | "cost": self.tool_result.cost, 160 | } 161 | -------------------------------------------------------------------------------- /justllms/utils/token_counter.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Union 2 | 3 | try: 4 | import tiktoken 5 | 6 | HAS_TIKTOKEN = True 7 | except ImportError: 8 | HAS_TIKTOKEN = False 9 | 10 | 11 | class TokenCounter: 12 | """Count tokens for different models.""" 13 | 14 | # Model to encoding mapping 15 | MODEL_ENCODINGS = { 16 | "gpt-4": "cl100k_base", 17 | "gpt-4o": "o200k_base", 18 | "gpt-4o-mini": "o200k_base", 19 | "gpt-4-turbo": "cl100k_base", 20 | "gpt-3.5-turbo": "cl100k_base", 21 | "text-embedding-ada-002": "cl100k_base", 22 | "claude": "cl100k_base", # Approximation 23 | "gemini": "cl100k_base", # Approximation 24 | } 25 | 26 | # Rough token estimates when tiktoken is not available 27 | CHARS_PER_TOKEN = { 28 | "default": 4, 29 | "chinese": 2, 30 | "japanese": 2, 31 | "korean": 2, 32 | } 33 | 34 | def __init__(self) -> None: 35 | self._encodings: Dict[str, Any] = {} 36 | 37 | def _get_encoding(self, model: str) -> Optional[Any]: 38 | """Get the encoding for a model.""" 39 | if not HAS_TIKTOKEN: 40 | return None 41 | 42 | # Find the encoding name for the model 43 | encoding_name = None 44 | 45 | # Check exact match first 46 | if model in self.MODEL_ENCODINGS: 47 | encoding_name = self.MODEL_ENCODINGS[model] 48 | else: 49 | # Check prefixes 50 | for model_prefix, enc_name in self.MODEL_ENCODINGS.items(): 51 | if model.startswith(model_prefix): 52 | encoding_name = enc_name 53 | break 54 | 55 | if not encoding_name: 56 | encoding_name 
= "cl100k_base" # Default 57 | 58 | # Cache encodings 59 | if encoding_name not in self._encodings: 60 | try: 61 | self._encodings[encoding_name] = tiktoken.get_encoding(encoding_name) 62 | except Exception: 63 | return None 64 | 65 | return self._encodings[encoding_name] 66 | 67 | def count_tokens( 68 | self, 69 | text: str, 70 | model: Optional[str] = None, 71 | ) -> int: 72 | """Count tokens in text.""" 73 | if HAS_TIKTOKEN and model: 74 | encoding = self._get_encoding(model) 75 | if encoding: 76 | try: 77 | return len(encoding.encode(text)) 78 | except Exception: 79 | pass 80 | 81 | # Fallback to character-based estimation 82 | return self._estimate_tokens(text) 83 | 84 | def _estimate_tokens(self, text: str) -> int: 85 | """Estimate token count based on character count.""" 86 | # Simple heuristic: ~4 characters per token for English 87 | # Adjust for other languages if detected 88 | 89 | # Check for CJK characters 90 | cjk_count = sum( 91 | 1 92 | for char in text 93 | if "\u4e00" <= char <= "\u9fff" # Chinese 94 | or "\u3040" <= char <= "\u309f" # Hiragana 95 | or "\u30a0" <= char <= "\u30ff" # Katakana 96 | or "\uac00" <= char <= "\ud7af" # Korean 97 | ) 98 | 99 | chars_per_token = 2 if cjk_count > len(text) * 0.3 else 4 100 | 101 | return max(1, len(text) // chars_per_token) 102 | 103 | def count_messages_tokens( 104 | self, 105 | messages: List[Dict[str, Any]], 106 | model: Optional[str] = None, 107 | ) -> Dict[str, int]: 108 | """Count tokens in a list of messages.""" 109 | total_tokens = 0 110 | per_message_tokens = 4 # Overhead per message 111 | 112 | for message in messages: 113 | # Count role tokens 114 | role = message.get("role", "") 115 | total_tokens += self.count_tokens(role, model) 116 | 117 | # Count content tokens 118 | content = message.get("content", "") 119 | if isinstance(content, str): 120 | total_tokens += self.count_tokens(content, model) 121 | elif isinstance(content, list): 122 | # Handle multimodal content 123 | for item in content: 124 | if isinstance(item, dict) and item.get("type") == "text": 125 | total_tokens += self.count_tokens(item.get("text", ""), model) 126 | elif isinstance(item, dict) and item.get("type") == "image": 127 | # Rough estimate for images 128 | total_tokens += 85 # Base64 encoded image token estimate 129 | 130 | # Add per-message overhead 131 | total_tokens += per_message_tokens 132 | 133 | # Handle other fields 134 | if message.get("name"): 135 | total_tokens += self.count_tokens(message["name"], model) 136 | 137 | if message.get("function_call"): 138 | total_tokens += self.count_tokens(str(message["function_call"]), model) 139 | 140 | # Add base prompt tokens 141 | total_tokens += 3 # Every reply is primed with <|start|>assistant<|message|> 142 | 143 | return { 144 | "total": total_tokens, 145 | "messages": len(messages), 146 | } 147 | 148 | 149 | # Global instance 150 | _token_counter = TokenCounter() 151 | 152 | 153 | def count_tokens( 154 | text: Union[str, List[Dict[str, Any]]], 155 | model: Optional[str] = None, 156 | ) -> int: 157 | """Count tokens in text or messages.""" 158 | if isinstance(text, str): 159 | return _token_counter.count_tokens(text, model) 160 | elif isinstance(text, list): 161 | return _token_counter.count_messages_tokens(text, model)["total"] 162 | else: 163 | return 0 164 | -------------------------------------------------------------------------------- /justllms/providers/openai.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional 2 | 3 
| from justllms.core.base import BaseResponse 4 | from justllms.core.models import ModelInfo 5 | from justllms.core.openai_base import BaseOpenAIChatProvider 6 | from justllms.tools.adapters.base import BaseToolAdapter 7 | 8 | 9 | class OpenAIResponse(BaseResponse): 10 | """OpenAI-specific response implementation.""" 11 | 12 | pass 13 | 14 | 15 | class OpenAIProvider(BaseOpenAIChatProvider): 16 | """Simplified OpenAI provider implementation.""" 17 | 18 | supports_tools = True 19 | """OpenAI supports function calling.""" 20 | 21 | MODELS = { 22 | "gpt-5": ModelInfo( 23 | name="gpt-5", 24 | provider="openai", 25 | max_tokens=128000, 26 | max_context_length=272000, 27 | supports_functions=True, 28 | supports_vision=True, 29 | cost_per_1k_prompt_tokens=0.00125, 30 | cost_per_1k_completion_tokens=0.01, 31 | tags=["flagship", "reasoning", "multimodal", "long-context", "tool-chaining"], 32 | ), 33 | "gpt-5-mini": ModelInfo( 34 | name="gpt-5-mini", 35 | provider="openai", 36 | max_tokens=128000, 37 | max_context_length=272000, 38 | supports_functions=True, 39 | supports_vision=True, 40 | cost_per_1k_prompt_tokens=0.0003, 41 | cost_per_1k_completion_tokens=0.0012, 42 | tags=["efficient", "multimodal", "long-context"], 43 | ), 44 | "gpt-4.1": ModelInfo( 45 | name="gpt-4.1", 46 | provider="openai", 47 | max_tokens=128000, 48 | max_context_length=1000000, 49 | supports_functions=True, 50 | supports_vision=True, 51 | cost_per_1k_prompt_tokens=0.004, 52 | cost_per_1k_completion_tokens=0.012, 53 | tags=["reasoning", "multimodal", "long-context", "cost-efficient"], 54 | ), 55 | "gpt-4.1-nano": ModelInfo( 56 | name="gpt-4.1-nano", 57 | provider="openai", 58 | max_tokens=32000, 59 | max_context_length=128000, 60 | supports_functions=True, 61 | supports_vision=False, 62 | cost_per_1k_prompt_tokens=0.00008, 63 | cost_per_1k_completion_tokens=0.0003, 64 | tags=["fastest", "cheapest", "efficient"], 65 | ), 66 | "gpt-4o": ModelInfo( 67 | name="gpt-4o", 68 | provider="openai", 69 | max_tokens=16384, 70 | max_context_length=128000, 71 | supports_functions=True, 72 | supports_vision=True, 73 | cost_per_1k_prompt_tokens=0.005, 74 | cost_per_1k_completion_tokens=0.015, 75 | tags=["multimodal", "general-purpose"], 76 | ), 77 | "gpt-4o-mini": ModelInfo( 78 | name="gpt-4o-mini", 79 | provider="openai", 80 | max_tokens=16384, 81 | max_context_length=128000, 82 | supports_functions=True, 83 | supports_vision=True, 84 | cost_per_1k_prompt_tokens=0.00015, 85 | cost_per_1k_completion_tokens=0.0006, 86 | tags=["multimodal", "efficient", "affordable"], 87 | ), 88 | "o1": ModelInfo( 89 | name="o1", 90 | provider="openai", 91 | max_tokens=100000, 92 | max_context_length=200000, 93 | supports_functions=True, 94 | supports_vision=False, 95 | cost_per_1k_prompt_tokens=0.015, 96 | cost_per_1k_completion_tokens=0.06, 97 | tags=["reasoning", "complex-tasks", "long-context"], 98 | ), 99 | "o4-mini": ModelInfo( 100 | name="o4-mini", 101 | provider="openai", 102 | max_tokens=100000, 103 | max_context_length=200000, 104 | supports_functions=True, 105 | supports_vision=False, 106 | cost_per_1k_prompt_tokens=0.003, 107 | cost_per_1k_completion_tokens=0.012, 108 | tags=["reasoning", "complex-tasks", "affordable"], 109 | ), 110 | "gpt-oss-120b": ModelInfo( 111 | name="gpt-oss-120b", 112 | provider="openai", 113 | max_tokens=32000, 114 | max_context_length=128000, 115 | supports_functions=True, 116 | supports_vision=False, 117 | cost_per_1k_prompt_tokens=0.0, 118 | cost_per_1k_completion_tokens=0.0, 119 | tags=["open-source", "code", "problem-solving", 
"tool-calling"], 120 | ), 121 | } 122 | 123 | @property 124 | def name(self) -> str: 125 | return "openai" 126 | 127 | def get_available_models(self) -> Dict[str, ModelInfo]: 128 | return self.MODELS.copy() 129 | 130 | def _get_api_endpoint(self) -> str: 131 | """Get OpenAI chat completions endpoint.""" 132 | base_url = self.config.api_base or "https://api.openai.com" 133 | base_url = base_url.rstrip("/") 134 | if base_url.endswith("/v1"): 135 | base_url = base_url[:-3] 136 | return f"{base_url}/v1/chat/completions" 137 | 138 | def _get_request_headers(self) -> Dict[str, str]: 139 | """Generate HTTP headers for OpenAI API requests.""" 140 | headers = { 141 | "Authorization": f"Bearer {self.config.api_key}", 142 | "Content-Type": "application/json", 143 | } 144 | 145 | if self.config.organization: 146 | headers["OpenAI-Organization"] = self.config.organization 147 | 148 | headers.update(self.config.headers) 149 | return headers 150 | 151 | def get_tool_adapter(self) -> Optional[BaseToolAdapter]: 152 | """Return the OpenAI tool adapter.""" 153 | from justllms.tools.adapters.openai import OpenAIToolAdapter 154 | 155 | return OpenAIToolAdapter() 156 | -------------------------------------------------------------------------------- /justllms/cli.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import click 4 | 5 | 6 | @click.group() 7 | @click.version_option() 8 | def main() -> None: 9 | pass 10 | 11 | 12 | @main.command() 13 | def sxs() -> None: 14 | """Launch side-by-side model comparison.""" 15 | from justllms.sxs.cli import run_interactive_sxs 16 | 17 | try: 18 | run_interactive_sxs() 19 | except KeyboardInterrupt: 20 | click.echo("\nInterrupted by user", err=True) 21 | sys.exit(0) 22 | except Exception as e: 23 | click.echo(f"Error: {e}", err=True) 24 | sys.exit(1) 25 | 26 | 27 | @main.group() 28 | def tools() -> None: 29 | """Tool discovery and management commands.""" 30 | pass 31 | 32 | 33 | @tools.command() 34 | @click.option("--provider", "-p", help="Filter tools by provider (e.g., 'google', 'all')") 35 | @click.option("--native", is_flag=True, help="Show only native provider tools") 36 | def list(provider: str, native: bool) -> None: 37 | """List available tools. 
38 | 39 | Examples: 40 | justllms tools list 41 | justllms tools list --provider google 42 | justllms tools list --native 43 | """ 44 | from justllms.tools.registry import GlobalToolRegistry 45 | 46 | try: 47 | # Show native tools if requested 48 | if native or provider: 49 | if provider == "google" or (native and not provider): 50 | click.echo("Native Tools (Google):") 51 | click.echo("-" * 60) 52 | 53 | from justllms.tools.native.google_tools import GOOGLE_NATIVE_TOOLS 54 | 55 | for _tool_name, tool_class in GOOGLE_NATIVE_TOOLS.items(): 56 | tool_instance = tool_class() 57 | click.echo(f"\n {tool_instance.name}") 58 | click.echo(f" Description: {tool_instance.description}") 59 | click.echo(f" Namespace: {tool_instance.namespace}") 60 | click.echo(" Provider: google") 61 | 62 | if not native and provider not in ["google", "all"]: 63 | click.echo(f"No native tools available for provider: {provider}", err=True) 64 | 65 | # Show registered user tools 66 | if not native: 67 | registry = GlobalToolRegistry() 68 | registered_tools = registry.list_tools() 69 | 70 | if registered_tools: 71 | click.echo("\n\nRegistered User Tools:") 72 | click.echo("-" * 60) 73 | for tool_name in registered_tools: 74 | tool = registry.get_tool(tool_name) 75 | if tool: 76 | click.echo(f"\n {tool.name}") 77 | click.echo(f" Description: {tool.description}") 78 | if tool.namespace: 79 | click.echo(f" Namespace: {tool.namespace}") 80 | else: 81 | if not provider and not native: 82 | click.echo("\nNo user tools registered.") 83 | click.echo("Use @tool decorator or Client.register_tools() to add tools.") 84 | 85 | except Exception as e: 86 | click.echo(f"Error listing tools: {e}", err=True) 87 | sys.exit(1) 88 | 89 | 90 | @tools.command() 91 | @click.argument("tool_name") 92 | @click.option("--provider", "-p", help="Provider for native tools (e.g., 'google')") 93 | def describe(tool_name: str, provider: str) -> None: 94 | """Show detailed information about a specific tool. 
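    Native tools are looked up first when --provider is given; otherwise the
    global registry is searched.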
95 | 96 | Examples: 97 | justllms tools describe my_tool 98 | justllms tools describe google_search --provider google 99 | """ 100 | from justllms.tools.registry import GlobalToolRegistry 101 | 102 | try: 103 | # Check native tools first if provider specified 104 | if provider == "google": 105 | from justllms.tools.native.google_tools import get_google_native_tool 106 | 107 | try: 108 | native_tool = get_google_native_tool(tool_name) 109 | click.echo(f"Tool: {native_tool.name}") 110 | click.echo("Type: Native Tool") 111 | click.echo("Provider: google") 112 | click.echo(f"Namespace: {native_tool.namespace}") 113 | click.echo(f"Description: {native_tool.description}") 114 | click.echo("\nConfiguration:") 115 | # Config is stored in metadata 116 | config_keys = [k for k in native_tool.metadata if k != "provider"] 117 | if config_keys: 118 | for key in config_keys: 119 | click.echo(f" {key}: {native_tool.metadata[key]}") 120 | else: 121 | click.echo(" No configuration") 122 | return 123 | except ValueError: 124 | click.echo(f"Native tool '{tool_name}' not found for provider 'google'", err=True) 125 | sys.exit(1) 126 | 127 | # Check user tools 128 | registry = GlobalToolRegistry() 129 | user_tool = registry.get_tool(tool_name) 130 | 131 | if not user_tool: 132 | click.echo(f"Tool '{tool_name}' not found", err=True) 133 | click.echo("\nUse 'justllms tools list' to see available tools.") 134 | sys.exit(1) 135 | 136 | # Display tool details 137 | click.echo(f"Tool: {user_tool.name}") 138 | click.echo("Type: User Tool") 139 | if user_tool.namespace: 140 | click.echo(f"Namespace: {user_tool.namespace}") 141 | click.echo(f"Description: {user_tool.description}") 142 | 143 | # Show parameters 144 | if user_tool.parameters: 145 | click.echo("\nParameters:") 146 | for param_name, param_info in user_tool.parameters.items(): 147 | required = " (required)" if param_info.required else "" 148 | default = ( 149 | f" [default: {param_info.default}]" if param_info.default is not None else "" 150 | ) 151 | desc = user_tool.parameter_descriptions.get(param_name, "") 152 | click.echo(f" {param_name}: {param_info.type}{required}{default}") 153 | if desc: 154 | click.echo(f" {desc}") 155 | 156 | # Show return type 157 | if user_tool.return_type: 158 | click.echo(f"\nReturns: {user_tool.return_type}") 159 | 160 | # Show metadata 161 | if user_tool.metadata: 162 | click.echo("\nMetadata:") 163 | for key, value in user_tool.metadata.items(): 164 | click.echo(f" {key}: {value}") 165 | 166 | except Exception as e: 167 | click.echo(f"Error describing tool: {e}", err=True) 168 | sys.exit(1) 169 | 170 | 171 | if __name__ == "__main__": 172 | main() 173 | -------------------------------------------------------------------------------- /justllms/tools/decorators.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from typing import Callable, Dict, Optional, Union 3 | 4 | from justllms.tools.models import Tool 5 | from justllms.tools.utils import ( 6 | extract_docstring_descriptions, 7 | extract_function_schema, 8 | ) 9 | 10 | 11 | def tool( 12 | func: Optional[Callable] = None, 13 | *, 14 | name: Optional[str] = None, 15 | namespace: Optional[str] = None, 16 | description: Optional[str] = None, 17 | parameter_descriptions: Optional[Dict[str, str]] = None, 18 | register: bool = False, 19 | ) -> Union[Callable, Tool]: 20 | """Decorator to convert a function into a Tool. 21 | 22 | Can be used with or without parentheses: 23 | @tool 24 | def my_func(): ... 
25 | 26 | @tool(name="custom_name", namespace="math") 27 | def my_func(): ... 28 | 29 | Args: 30 | func: The function to convert (when used without parentheses). 31 | name: Custom name for the tool (defaults to function name). 32 | namespace: Optional namespace for the tool. 33 | description: Custom description (defaults to function docstring). 34 | parameter_descriptions: Additional parameter descriptions. 35 | register: Whether to register globally (default: False). 36 | 37 | Returns: 38 | Tool instance when used as decorator, or decorator function. 39 | 40 | Examples: 41 | >>> @tool 42 | ... def add(a: int, b: int) -> int: 43 | ... '''Add two numbers.''' 44 | ... return a + b 45 | 46 | >>> @tool(namespace="math", description="Multiply numbers") 47 | ... def multiply(x: float, y: float) -> float: 48 | ... return x * y 49 | 50 | >>> @tool( 51 | ... name="search_documents", 52 | ... parameter_descriptions={ 53 | ... "query": "The search query", 54 | ... "limit": "Maximum number of results" 55 | ... } 56 | ... ) 57 | ... def search(query: str, limit: int = 10) -> list: 58 | ... return [f"Result for {query}"] 59 | """ 60 | 61 | def decorator(f: Callable) -> Tool: 62 | """Inner decorator that creates the Tool instance.""" 63 | # Extract metadata 64 | tool_name = name or f.__name__ 65 | tool_description = description or inspect.getdoc(f) or f"Tool: {tool_name}" 66 | 67 | # Extract parameters 68 | parameters = extract_function_schema(f) 69 | 70 | # Extract parameter descriptions from docstring 71 | docstring_descriptions = extract_docstring_descriptions(f) 72 | 73 | # Merge parameter descriptions 74 | merged_descriptions = {} 75 | if docstring_descriptions: 76 | merged_descriptions.update(docstring_descriptions) 77 | if parameter_descriptions: 78 | merged_descriptions.update(parameter_descriptions) 79 | 80 | # Update parameter descriptions in ParameterInfo objects 81 | for param_name, param_info in parameters.items(): 82 | if param_name in merged_descriptions: 83 | param_info.description = merged_descriptions[param_name] 84 | 85 | # Get return type 86 | sig = inspect.signature(f) 87 | return_type = ( 88 | sig.return_annotation if sig.return_annotation != inspect.Signature.empty else None 89 | ) 90 | 91 | # Create Tool instance 92 | tool_instance = Tool( 93 | name=tool_name, 94 | namespace=namespace, 95 | description=tool_description, 96 | callable=f, 97 | parameters=parameters, 98 | parameter_descriptions=merged_descriptions, 99 | return_type=return_type, 100 | ) 101 | 102 | # Register globally if requested 103 | if register: 104 | from justllms.tools.registry import GlobalToolRegistry 105 | 106 | registry = GlobalToolRegistry() 107 | registry.register(tool_instance) 108 | 109 | # Add the tool as an attribute of the function 110 | f.tool = tool_instance # type: ignore 111 | 112 | return tool_instance 113 | 114 | # Handle usage with or without parentheses 115 | if func is not None: 116 | # Used without parentheses: @tool 117 | return decorator(func) 118 | else: 119 | # Used with parentheses: @tool(...) 120 | return decorator 121 | 122 | 123 | def tool_from_callable( 124 | func: Callable, 125 | name: Optional[str] = None, 126 | namespace: Optional[str] = None, 127 | description: Optional[str] = None, 128 | parameter_descriptions: Optional[Dict[str, str]] = None, 129 | ) -> Tool: 130 | """Convert an existing callable into a Tool. 131 | 132 | This is useful for converting existing functions that you can't 133 | or don't want to decorate directly. 
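    A typical case is wrapping a function imported from a third-party package
    whose source you cannot edit.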
134 | 135 | Args: 136 | func: The callable to convert. 137 | name: Custom name for the tool (defaults to function name). 138 | namespace: Optional namespace for the tool. 139 | description: Custom description (defaults to function docstring). 140 | parameter_descriptions: Parameter descriptions. 141 | 142 | Returns: 143 | Tool instance. 144 | 145 | Examples: 146 | >>> def existing_func(x: int, y: int) -> int: 147 | ... return x * y 148 | ... 149 | >>> tool_instance = tool_from_callable( 150 | ... existing_func, 151 | ... name="multiplier", 152 | ... description="Multiplies two numbers" 153 | ... ) 154 | """ 155 | # Extract metadata 156 | tool_name = name or func.__name__ 157 | tool_description = description or inspect.getdoc(func) or f"Tool: {tool_name}" 158 | 159 | # Extract parameters 160 | parameters = extract_function_schema(func) 161 | 162 | # Extract parameter descriptions from docstring 163 | docstring_descriptions = extract_docstring_descriptions(func) 164 | 165 | # Merge parameter descriptions 166 | merged_descriptions = {} 167 | if docstring_descriptions: 168 | merged_descriptions.update(docstring_descriptions) 169 | if parameter_descriptions: 170 | merged_descriptions.update(parameter_descriptions) 171 | 172 | # Update parameter descriptions in ParameterInfo objects 173 | for param_name, param_info in parameters.items(): 174 | if param_name in merged_descriptions: 175 | param_info.description = merged_descriptions[param_name] 176 | 177 | # Get return type 178 | sig = inspect.signature(func) 179 | return_type = ( 180 | sig.return_annotation if sig.return_annotation != inspect.Signature.empty else None 181 | ) 182 | 183 | # Create Tool instance 184 | return Tool( 185 | name=tool_name, 186 | namespace=namespace, 187 | description=tool_description, 188 | callable=func, 189 | parameters=parameters, 190 | parameter_descriptions=merged_descriptions, 191 | return_type=return_type, 192 | ) 193 | -------------------------------------------------------------------------------- /justllms/tools/registry.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional 2 | 3 | from justllms.tools.models import Tool 4 | 5 | 6 | class ToolRegistry: 7 | """Manages a collection of tools with namespace support. 8 | 9 | The registry provides a way to organize and access tools, 10 | with optional namespace isolation. 11 | 12 | Attributes: 13 | namespace: Optional namespace for this registry. 14 | _tools: Dictionary mapping tool names to Tool instances. 15 | 16 | Examples: 17 | >>> registry = ToolRegistry(namespace="math") 18 | >>> registry.register(add_tool) 19 | >>> registry.register(multiply_tool) 20 | >>> print(registry.list_tools()) 21 | ['add', 'multiply'] 22 | 23 | >>> # Get specific tool 24 | >>> tool = registry.get_tool("add") 25 | 26 | >>> # Merge registries 27 | >>> other_registry = ToolRegistry(namespace="text") 28 | >>> combined = registry.merge(other_registry) 29 | """ 30 | 31 | def __init__(self, namespace: Optional[str] = None): 32 | """Initialize a new tool registry. 33 | 34 | Args: 35 | namespace: Optional namespace for tools in this registry. 36 | """ 37 | self.namespace = namespace 38 | self._tools: Dict[str, Tool] = {} 39 | 40 | def register(self, tool: Tool) -> None: 41 | """Register a tool in the registry. 42 | 43 | If the tool doesn't have a namespace and the registry does, 44 | the registry's namespace will be applied to the tool. 45 | 46 | Args: 47 | tool: The Tool instance to register. 
48 | 49 | Raises: 50 | ValueError: If a tool with the same name is already registered. 51 | """ 52 | # Apply registry namespace if tool doesn't have one 53 | if self.namespace and not tool.namespace: 54 | tool.namespace = self.namespace 55 | 56 | if tool.name in self._tools: 57 | existing = self._tools[tool.name] 58 | raise ValueError( 59 | f"Tool '{tool.name}' is already registered in this registry. " 60 | f"Existing tool: {existing.description}" 61 | ) 62 | 63 | self._tools[tool.name] = tool 64 | 65 | def unregister(self, name: str) -> None: 66 | """Remove a tool from the registry. 67 | 68 | Args: 69 | name: Name of the tool to remove. 70 | 71 | Raises: 72 | KeyError: If the tool doesn't exist. 73 | """ 74 | if name not in self._tools: 75 | raise KeyError(f"Tool '{name}' not found in registry") 76 | del self._tools[name] 77 | 78 | def get_tool(self, name: str) -> Optional[Tool]: 79 | """Get a tool by name. 80 | 81 | Args: 82 | name: Name of the tool to retrieve. 83 | 84 | Returns: 85 | Tool instance if found, None otherwise. 86 | """ 87 | return self._tools.get(name) 88 | 89 | def list_tools(self) -> List[str]: 90 | """List all tool names in the registry. 91 | 92 | Returns: 93 | List of tool names. 94 | """ 95 | return list(self._tools.keys()) 96 | 97 | def get_all_tools(self) -> List[Tool]: 98 | """Get all Tool instances in the registry. 99 | 100 | Returns: 101 | List of all Tool instances. 102 | """ 103 | return list(self._tools.values()) 104 | 105 | def merge(self, other: "ToolRegistry", check_conflicts: bool = True) -> "ToolRegistry": 106 | """Merge another registry into a new registry. 107 | 108 | Creates a new registry containing tools from both registries. 109 | Does not modify either original registry. 110 | 111 | Args: 112 | other: Another ToolRegistry to merge with. 113 | check_conflicts: Whether to check for name conflicts. 114 | 115 | Returns: 116 | New ToolRegistry containing tools from both. 117 | 118 | Raises: 119 | ValueError: If check_conflicts is True and there are name conflicts. 120 | """ 121 | # Create new registry with no specific namespace 122 | merged = ToolRegistry() 123 | 124 | # Add tools from this registry 125 | for tool in self._tools.values(): 126 | merged._tools[tool.name] = tool 127 | 128 | # Add tools from other registry 129 | for tool in other._tools.values(): 130 | if check_conflicts and tool.name in merged._tools: 131 | existing = merged._tools[tool.name] 132 | if existing.full_name != tool.full_name: 133 | raise ValueError( 134 | f"Tool name conflict: '{tool.name}' exists in both registries. " 135 | f"Existing: {existing.full_name}, New: {tool.full_name}" 136 | ) 137 | merged._tools[tool.name] = tool 138 | 139 | return merged 140 | 141 | def clear(self) -> None: 142 | """Remove all tools from the registry.""" 143 | self._tools.clear() 144 | 145 | def __len__(self) -> int: 146 | """Return the number of tools in the registry.""" 147 | return len(self._tools) 148 | 149 | def __contains__(self, name: str) -> bool: 150 | """Check if a tool exists in the registry. 151 | 152 | Args: 153 | name: Name of the tool to check. 154 | 155 | Returns: 156 | True if the tool exists, False otherwise. 
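        Examples:
            >>> "add" in registry
            True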
157 | """ 158 | return name in self._tools 159 | 160 | def __repr__(self) -> str: 161 | """String representation of the registry.""" 162 | namespace_str = f", namespace='{self.namespace}'" if self.namespace else "" 163 | return f"ToolRegistry(tools={len(self._tools)}{namespace_str})" 164 | 165 | 166 | class GlobalToolRegistry(ToolRegistry): 167 | """Singleton global tool registry. 168 | 169 | This registry is shared across the application and can be used 170 | to register tools that should be available globally. 171 | 172 | Examples: 173 | >>> from justllms.tools.registry import GlobalToolRegistry 174 | >>> registry = GlobalToolRegistry() 175 | >>> registry.register(my_tool) 176 | >>> # Access from anywhere 177 | >>> registry2 = GlobalToolRegistry() 178 | >>> assert registry is registry2 # Same instance 179 | """ 180 | 181 | _instance: Optional["GlobalToolRegistry"] = None 182 | 183 | def __new__(cls) -> "GlobalToolRegistry": 184 | """Create or return the singleton instance.""" 185 | if cls._instance is None: 186 | cls._instance = super().__new__(cls) 187 | cls._instance.__initialized = False 188 | return cls._instance 189 | 190 | def __init__(self) -> None: 191 | """Initialize the global registry.""" 192 | # Only initialize once 193 | if not getattr(self, "_GlobalToolRegistry__initialized", False): 194 | super().__init__(namespace=None) 195 | self.__initialized = True 196 | 197 | @classmethod 198 | def reset(cls) -> None: 199 | """Reset the global registry (mainly for testing).""" 200 | if cls._instance: 201 | cls._instance.clear() 202 | cls._instance = None 203 | -------------------------------------------------------------------------------- /justllms/tools/adapters/anthropic.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Any, Dict, List, Optional, Union 3 | 4 | from justllms.core.models import Message, Role 5 | from justllms.tools.adapters.base import BaseToolAdapter 6 | from justllms.tools.models import Tool, ToolCall, ToolResult 7 | 8 | 9 | class AnthropicToolAdapter(BaseToolAdapter): 10 | """Adapter for Anthropic Claude's tool format. 11 | 12 | Claude uses a different format than OpenAI: 13 | - Tools have name, description, and input_schema 14 | - Tool choice is {"type": "auto" | "any" | "tool", "name": "..."} 15 | - Tool results are sent as user messages with tool_use_id 16 | """ 17 | 18 | def format_tools_for_api(self, tools: List[Tool]) -> List[Dict[str, Any]]: 19 | """Convert Tool objects to Anthropic's format. 20 | 21 | Args: 22 | tools: List of Tool instances. 23 | 24 | Returns: 25 | List of tool definitions in Anthropic format. 26 | 27 | Example output: 28 | [ 29 | { 30 | "name": "get_weather", 31 | "description": "Get weather for location", 32 | "input_schema": { 33 | "type": "object", 34 | "properties": { 35 | "location": { 36 | "type": "string", 37 | "description": "City name" 38 | } 39 | }, 40 | "required": ["location"] 41 | } 42 | } 43 | ] 44 | """ 45 | formatted_tools = [] 46 | 47 | for tool in tools: 48 | # Skip native tools (Anthropic doesn't have native tools) 49 | if tool.is_native: 50 | continue 51 | 52 | tool_def = { 53 | "name": tool.name, 54 | "description": tool.description, 55 | "input_schema": tool.to_json_schema(), 56 | } 57 | 58 | formatted_tools.append(tool_def) 59 | 60 | return formatted_tools 61 | 62 | def format_tool_choice( 63 | self, tool_choice: Optional[Union[str, Dict[str, Any]]] 64 | ) -> Optional[Dict[str, Any]]: 65 | """Format tool_choice for Anthropic API. 
66 | 67 | Args: 68 | tool_choice: Tool selection strategy: 69 | - "auto": Let model decide (maps to {"type": "auto"}) 70 | - "none": Don't use tools (not directly supported, omit tools) 71 | - "required": Must use a tool (maps to {"type": "any"}) 72 | - {"name": "tool_name"}: Use specific tool 73 | - {"type": "auto|any|tool", "name": "..."}: Full format 74 | - None: Use default (auto) 75 | 76 | Returns: 77 | Formatted tool_choice for Anthropic API. 78 | """ 79 | if tool_choice is None: 80 | return {"type": "auto"} 81 | 82 | # Handle string choices 83 | if isinstance(tool_choice, str): 84 | choice_map = { 85 | "auto": {"type": "auto"}, 86 | "none": None, # Anthropic doesn't have explicit "none", handled by not passing tools 87 | "required": {"type": "any"}, # Map to "any" - must use some tool 88 | } 89 | if tool_choice in choice_map: 90 | return choice_map[tool_choice] 91 | # Assume it's a tool name 92 | return {"type": "tool", "name": tool_choice} 93 | 94 | # Handle dict choices 95 | if isinstance(tool_choice, dict): 96 | # Check if already in Anthropic format 97 | if "type" in tool_choice: 98 | return tool_choice 99 | 100 | # Convert from simple {"name": "..."} format 101 | if "name" in tool_choice: 102 | return {"type": "tool", "name": tool_choice["name"]} 103 | 104 | # Default to auto 105 | return {"type": "auto"} 106 | 107 | def extract_tool_calls(self, response: Dict[str, Any]) -> List[ToolCall]: 108 | """Extract tool calls from Anthropic response. 109 | 110 | Args: 111 | response: Raw response from Anthropic API. 112 | 113 | Returns: 114 | List of ToolCall objects. 115 | """ 116 | tool_calls = [] 117 | 118 | # In Anthropic's format, tool use is in the content array 119 | content = response.get("content", []) 120 | 121 | for item in content: 122 | if isinstance(item, dict) and item.get("type") == "tool_use": 123 | # Extract tool call information 124 | tool_call = ToolCall( 125 | id=item.get("id", ""), 126 | name=item.get("name", ""), 127 | arguments=item.get("input", {}), 128 | raw_arguments=json.dumps(item.get("input", {})), 129 | ) 130 | tool_calls.append(tool_call) 131 | 132 | return tool_calls 133 | 134 | def format_tool_result_message(self, tool_result: ToolResult, tool_call: ToolCall) -> Message: 135 | """Format tool result as a message for Anthropic. 136 | 137 | Anthropic expects tool results as user messages with tool_result content. 138 | 139 | Args: 140 | tool_result: Result from tool execution. 141 | tool_call: Original tool call. 142 | 143 | Returns: 144 | Message with tool result in Anthropic format. 145 | """ 146 | # Anthropic uses a content array with tool_result items 147 | content = [ 148 | { 149 | "type": "tool_result", 150 | "tool_use_id": tool_call.id, 151 | "content": tool_result.to_message_content(), 152 | } 153 | ] 154 | 155 | return Message(role=Role.USER, content=content) 156 | 157 | def format_tool_calls_message(self, tool_calls: List[ToolCall]) -> Optional[Message]: 158 | """Format tool calls as an assistant message. 159 | 160 | Anthropic includes tool calls in the assistant's content array. 161 | 162 | Args: 163 | tool_calls: List of tool calls. 164 | 165 | Returns: 166 | Assistant message with tool calls in content. 
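        Example output content (sketch):
            [{"type": "text", "text": "..."},
             {"type": "tool_use", "id": "call_ab12", "name": "get_weather",
              "input": {"location": "SF"}}]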
167 | """ 168 | if not tool_calls: 169 | return None 170 | 171 | content: List[Dict[str, Any]] = [] 172 | 173 | # Add any text content if needed 174 | # Anthropic requires at least one text block before tool use 175 | content.append({"type": "text", "text": "I'll help you with that."}) 176 | 177 | # Add tool use blocks 178 | for tc in tool_calls: 179 | tool_use: Dict[str, Any] = { 180 | "type": "tool_use", 181 | "id": tc.id, 182 | "name": tc.name, 183 | "input": tc.arguments, 184 | } 185 | content.append(tool_use) 186 | 187 | return Message(role=Role.ASSISTANT, content=content) 188 | 189 | def supports_parallel_tools(self) -> bool: 190 | """Claude 3 supports calling multiple tools in one response.""" 191 | return True 192 | 193 | def supports_required_tools(self) -> bool: 194 | """Claude supports 'any' which is similar to required.""" 195 | return True 196 | 197 | def get_max_tools_per_call(self) -> Optional[int]: 198 | """Anthropic doesn't document a specific limit.""" 199 | return None 200 | -------------------------------------------------------------------------------- /justllms/providers/anthropic.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional 2 | 3 | from justllms.core.base import BaseProvider, BaseResponse 4 | from justllms.core.models import Choice, Message, ModelInfo, Role, Usage 5 | from justllms.tools.adapters.base import BaseToolAdapter 6 | 7 | 8 | class AnthropicResponse(BaseResponse): 9 | """Anthropic-specific response implementation.""" 10 | 11 | pass 12 | 13 | 14 | class AnthropicProvider(BaseProvider): 15 | """Anthropic provider implementation.""" 16 | 17 | supports_tools = True 18 | """Anthropic Claude supports tool use.""" 19 | 20 | MODELS = { 21 | "claude-opus-4.1": ModelInfo( 22 | name="claude-opus-4.1", 23 | provider="anthropic", 24 | max_tokens=32000, 25 | max_context_length=200000, 26 | supports_functions=True, 27 | supports_vision=True, 28 | cost_per_1k_prompt_tokens=15.0, 29 | cost_per_1k_completion_tokens=75.0, 30 | tags=["flagship", "most-capable", "multimodal", "extended-thinking"], 31 | ), 32 | "claude-sonnet-4": ModelInfo( 33 | name="claude-sonnet-4", 34 | provider="anthropic", 35 | max_tokens=64000, 36 | max_context_length=200000, 37 | supports_functions=True, 38 | supports_vision=True, 39 | cost_per_1k_prompt_tokens=3.0, 40 | cost_per_1k_completion_tokens=15.0, 41 | tags=["high-performance", "multimodal", "extended-thinking"], 42 | ), 43 | "claude-haiku-3.5": ModelInfo( 44 | name="claude-haiku-3.5", 45 | provider="anthropic", 46 | max_tokens=8192, 47 | max_context_length=200000, 48 | supports_functions=True, 49 | supports_vision=True, 50 | cost_per_1k_prompt_tokens=0.8, 51 | cost_per_1k_completion_tokens=4.0, 52 | tags=["fastest", "efficient", "multimodal"], 53 | ), 54 | "claude-3-5-sonnet-20241022": ModelInfo( 55 | name="claude-3-5-sonnet-20241022", 56 | provider="anthropic", 57 | max_tokens=8192, 58 | max_context_length=200000, 59 | supports_functions=True, 60 | supports_vision=True, 61 | cost_per_1k_prompt_tokens=0.003, 62 | cost_per_1k_completion_tokens=0.015, 63 | tags=["legacy", "reasoning", "multimodal"], 64 | ), 65 | "claude-3-5-haiku-20241022": ModelInfo( 66 | name="claude-3-5-haiku-20241022", 67 | provider="anthropic", 68 | max_tokens=8192, 69 | max_context_length=200000, 70 | supports_functions=True, 71 | supports_vision=False, 72 | cost_per_1k_prompt_tokens=0.001, 73 | cost_per_1k_completion_tokens=0.005, 74 | tags=["legacy", "fast", "efficient"], 75 | ), 76 | 
"claude-3-opus-20240229": ModelInfo( 77 | name="claude-3-opus-20240229", 78 | provider="anthropic", 79 | max_tokens=4096, 80 | max_context_length=200000, 81 | supports_functions=True, 82 | supports_vision=True, 83 | cost_per_1k_prompt_tokens=0.015, 84 | cost_per_1k_completion_tokens=0.075, 85 | tags=["legacy", "powerful", "reasoning"], 86 | ), 87 | } 88 | 89 | @property 90 | def name(self) -> str: 91 | return "anthropic" 92 | 93 | def get_available_models(self) -> Dict[str, ModelInfo]: 94 | return self.MODELS.copy() 95 | 96 | def _get_headers(self) -> Dict[str, str]: 97 | """Get request headers.""" 98 | headers = { 99 | "x-api-key": self.config.api_key or "", 100 | "anthropic-version": self.config.api_version or "2023-06-01", 101 | "content-type": "application/json", 102 | } 103 | 104 | headers.update(self.config.headers) 105 | return headers 106 | 107 | def _format_messages( 108 | self, messages: List[Message] 109 | ) -> tuple[Optional[str], List[Dict[str, Any]]]: 110 | """Format messages for Anthropic API.""" 111 | system_message = None 112 | formatted_messages = [] 113 | 114 | for msg in messages: 115 | if msg.role == Role.SYSTEM: 116 | system_message = msg.content if isinstance(msg.content, str) else str(msg.content) 117 | else: 118 | formatted_msg = { 119 | "role": "user" if msg.role == Role.USER else "assistant", 120 | "content": msg.content, 121 | } 122 | formatted_messages.append(formatted_msg) 123 | 124 | return system_message, formatted_messages 125 | 126 | def _parse_response(self, response_data: Dict[str, Any], model: str) -> AnthropicResponse: 127 | """Parse Anthropic API response.""" 128 | content = response_data.get("content", []) 129 | 130 | text_content = "" 131 | for item in content: 132 | if item.get("type") == "text": 133 | text_content = item.get("text", "") 134 | break 135 | 136 | message = Message( 137 | role=Role.ASSISTANT, 138 | content=text_content, 139 | ) 140 | 141 | choice = Choice( 142 | index=0, 143 | message=message, 144 | finish_reason=response_data.get("stop_reason"), 145 | ) 146 | 147 | usage_data = response_data.get("usage", {}) 148 | usage = Usage( 149 | prompt_tokens=usage_data.get("input_tokens", 0), 150 | completion_tokens=usage_data.get("output_tokens", 0), 151 | total_tokens=usage_data.get("input_tokens", 0) + usage_data.get("output_tokens", 0), 152 | ) 153 | 154 | return self._create_base_response( # type: ignore[return-value] 155 | AnthropicResponse, 156 | response_data, 157 | [choice], 158 | usage, 159 | model, 160 | ) 161 | 162 | def complete( 163 | self, 164 | messages: List[Message], 165 | model: str, 166 | timeout: Optional[float] = None, 167 | **kwargs: Any, 168 | ) -> BaseResponse: 169 | """Synchronous completion. 170 | 171 | Args: 172 | messages: List of messages for the completion. 173 | model: Model identifier to use. 174 | timeout: Optional timeout in seconds. If None, no timeout is enforced. 175 | **kwargs: Additional provider-specific parameters. 
176 | """ 177 | url = f"{self.config.api_base or 'https://api.anthropic.com'}/v1/messages" 178 | 179 | system_message, formatted_messages = self._format_messages(messages) 180 | 181 | payload = { 182 | "model": model, 183 | "messages": formatted_messages, 184 | "max_tokens": kwargs.get("max_tokens", 4096), 185 | } 186 | 187 | if system_message: 188 | payload["system"] = system_message 189 | 190 | # Map common parameters 191 | if "temperature" in kwargs: 192 | payload["temperature"] = kwargs["temperature"] 193 | if "top_p" in kwargs: 194 | payload["top_p"] = kwargs["top_p"] 195 | if "stop" in kwargs: 196 | payload["stop_sequences"] = ( 197 | kwargs["stop"] if isinstance(kwargs["stop"], list) else [kwargs["stop"]] 198 | ) 199 | 200 | response_data = self._make_http_request( 201 | url=url, 202 | payload=payload, 203 | headers=self._get_headers(), 204 | timeout=timeout, 205 | ) 206 | 207 | return self._parse_response(response_data, model) 208 | 209 | def get_tool_adapter(self) -> Optional[BaseToolAdapter]: 210 | """Return the Anthropic tool adapter.""" 211 | from justllms.tools.adapters.anthropic import AnthropicToolAdapter 212 | 213 | return AnthropicToolAdapter() 214 | -------------------------------------------------------------------------------- /justllms/tools/native/google_tools.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional 2 | 3 | from justllms.tools.models import Tool 4 | 5 | 6 | class GoogleNativeTool(Tool): 7 | """Base class for Google's native tools. 8 | 9 | Native tools are executed server-side by Google's API and don't require 10 | local execution. They're enabled via provider configuration. 11 | """ 12 | 13 | def __init__( 14 | self, 15 | name: str, 16 | description: str, 17 | namespace: str = "google", 18 | config: Optional[Dict[str, Any]] = None, 19 | ) -> None: 20 | """Initialize a Google native tool. 21 | 22 | Args: 23 | name: Tool identifier. 24 | description: Human-readable description. 25 | namespace: Tool namespace (default: "google"). 26 | config: Provider-specific configuration. 27 | """ 28 | 29 | # Native tools don't have a callable - they're handled by the provider 30 | def _placeholder() -> None: 31 | raise NotImplementedError( 32 | f"Native tool '{name}' is executed server-side by Google. " 33 | "It should not be called directly." 34 | ) 35 | 36 | # Store provider and config in metadata 37 | metadata = config.copy() if config else {} 38 | metadata["provider"] = "google" 39 | 40 | super().__init__( 41 | name=name, 42 | namespace=namespace, 43 | description=description, 44 | callable=_placeholder, 45 | parameters={}, 46 | parameter_descriptions={}, 47 | return_type=None, 48 | metadata=metadata, 49 | is_native=True, 50 | ) 51 | 52 | 53 | class GoogleSearch(GoogleNativeTool): 54 | """Google Search tool for Gemini models. 55 | 56 | Enables real-time web search capability. Results are retrieved and 57 | processed by Google's servers before being returned in the LLM response. 
58 | 59 | Configuration options: 60 | - dynamic_retrieval_config: Configure retrieval parameters 61 | - mode: "MODE_DYNAMIC" (default) or "MODE_UNSPECIFIED" 62 | - dynamic_threshold: float (0.0-1.0), relevance threshold 63 | 64 | Example: 65 | ```python 66 | # Enable Google Search in config 67 | config = { 68 | "providers": { 69 | "google": { 70 | "api_key": "...", 71 | "native_tools": { 72 | "google_search": { 73 | "enabled": True, 74 | "dynamic_retrieval_config": { 75 | "mode": "MODE_DYNAMIC", 76 | "dynamic_threshold": 0.7 77 | } 78 | } 79 | } 80 | } 81 | } 82 | } 83 | 84 | client = Client(config) 85 | response = client.completion.create( 86 | messages=[{"role": "user", "content": "What's the weather in SF?"}], 87 | provider="google", 88 | model="gemini-2.0-flash-exp", 89 | ) 90 | ``` 91 | 92 | Note: 93 | - Only available in Gemini 1.5 Pro and newer models 94 | - Requires appropriate API permissions 95 | - Search results count toward token usage 96 | """ 97 | 98 | def __init__(self, config: Optional[Dict[str, Any]] = None) -> None: 99 | """Initialize Google Search tool. 100 | 101 | Args: 102 | config: Optional configuration dict with: 103 | - dynamic_retrieval_config: Retrieval settings 104 | - enabled: Whether tool is enabled (default: True) 105 | """ 106 | super().__init__( 107 | name="google_search", 108 | description=( 109 | "Search Google for real-time information. Returns relevant web results " 110 | "that are automatically incorporated into the response." 111 | ), 112 | namespace="google", 113 | config=config, 114 | ) 115 | 116 | def to_api_format(self) -> Dict[str, Any]: 117 | """Convert to Gemini API format. 118 | 119 | Returns: 120 | Dict in format expected by Gemini API's googleSearch tool. 121 | """ 122 | api_format: Dict[str, Any] = {"googleSearch": {}} 123 | 124 | # Add dynamic retrieval config if provided 125 | if "dynamic_retrieval_config" in self.metadata: 126 | api_format["googleSearchRetrieval"] = { 127 | "dynamicRetrievalConfig": self.metadata["dynamic_retrieval_config"] 128 | } 129 | 130 | return api_format 131 | 132 | 133 | class GoogleCodeExecution(GoogleNativeTool): 134 | """Code execution tool for Gemini models. 135 | 136 | Enables Python code execution in a sandboxed server-side environment. 137 | The model can generate and execute code to solve problems, perform 138 | calculations, or process data. 139 | 140 | Configuration options: 141 | - timeout: Execution timeout in seconds (default: 30) 142 | - max_output_size: Maximum output size in bytes 143 | 144 | Example: 145 | ```python 146 | # Enable code execution in config 147 | config = { 148 | "providers": { 149 | "google": { 150 | "api_key": "...", 151 | "native_tools": { 152 | "code_execution": { 153 | "enabled": True, 154 | "timeout": 30 155 | } 156 | } 157 | } 158 | } 159 | } 160 | 161 | client = Client(config) 162 | response = client.completion.create( 163 | messages=[{"role": "user", "content": "Calculate fibonacci(20)"}], 164 | provider="google", 165 | model="gemini-2.0-flash-exp", 166 | ) 167 | ``` 168 | 169 | Note: 170 | - Only available in Gemini 1.5 Pro and newer models 171 | - Execution happens in isolated sandbox environment 172 | - Code output is included in response 173 | - Limited to Python standard library 174 | """ 175 | 176 | def __init__(self, config: Optional[Dict[str, Any]] = None) -> None: 177 | """Initialize code execution tool. 
178 | 179 | Args: 180 | config: Optional configuration dict with: 181 | - timeout: Execution timeout in seconds 182 | - max_output_size: Maximum output size 183 | - enabled: Whether tool is enabled (default: True) 184 | """ 185 | super().__init__( 186 | name="code_execution", 187 | description=( 188 | "Execute Python code in a sandboxed environment. " 189 | "Use this to perform calculations, data processing, or algorithmic tasks." 190 | ), 191 | namespace="google", 192 | config=config, 193 | ) 194 | 195 | def to_api_format(self) -> Dict[str, Any]: 196 | """Convert to Gemini API format. 197 | 198 | Returns: 199 | Dict in format expected by Gemini API's codeExecution tool. 200 | """ 201 | return {"codeExecution": {}} 202 | 203 | 204 | # Registry of available Google native tools 205 | GOOGLE_NATIVE_TOOLS = { 206 | "google_search": GoogleSearch, 207 | "code_execution": GoogleCodeExecution, 208 | } 209 | 210 | 211 | def get_google_native_tool(name: str, config: Optional[Dict[str, Any]] = None) -> GoogleNativeTool: 212 | """Get a Google native tool by name. 213 | 214 | Args: 215 | name: Tool name ("google_search" or "code_execution"). 216 | config: Optional tool-specific configuration. 217 | 218 | Returns: 219 | GoogleNativeTool instance. 220 | 221 | Raises: 222 | ValueError: If tool name is not recognized. 223 | """ 224 | if name not in GOOGLE_NATIVE_TOOLS: 225 | raise ValueError( 226 | f"Unknown Google native tool: {name}. " 227 | f"Available tools: {list(GOOGLE_NATIVE_TOOLS.keys())}" 228 | ) 229 | 230 | tool_class = GOOGLE_NATIVE_TOOLS[name] 231 | return tool_class(config=config) 232 | -------------------------------------------------------------------------------- /justllms/tools/native/manager.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional 2 | 3 | from justllms.tools.models import Tool 4 | from justllms.tools.native.google_tools import GOOGLE_NATIVE_TOOLS, get_google_native_tool 5 | 6 | 7 | class NativeToolManager: 8 | """Manages native tools for providers. 9 | 10 | Responsibilities: 11 | - Discover enabled native tools from provider config 12 | - Instantiate native tool objects with their configs 13 | - Merge native tools with user-defined tools 14 | - Handle namespace conflicts 15 | """ 16 | 17 | def __init__(self, provider: str, config: Optional[Dict[str, Any]] = None): 18 | """Initialize the native tool manager. 19 | 20 | Args: 21 | provider: Provider name (e.g., "google", "openai"). 22 | config: Provider's native_tools configuration dict. 23 | """ 24 | self.provider = provider 25 | self.config = config or {} 26 | self._native_tools: Dict[str, Tool] = {} 27 | self._load_native_tools() 28 | 29 | def _load_native_tools(self) -> None: 30 | """Load and instantiate enabled native tools based on config.""" 31 | if self.provider == "google": 32 | self._load_google_tools() 33 | # Add more providers as they support native tools 34 | # elif self.provider == "openai": 35 | # self._load_openai_tools() 36 | 37 | def _load_google_tools(self) -> None: 38 | """Load Google native tools from config. 
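        Entries without "enabled": true are skipped, as are names not present
        in GOOGLE_NATIVE_TOOLS; tools that fail to initialize are silently
        dropped.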
39 | 40 | Config format: 41 | { 42 | "google_search": { 43 | "enabled": True, 44 | "dynamic_retrieval_config": {...} 45 | }, 46 | "code_execution": { 47 | "enabled": True, 48 | "timeout": 30 49 | } 50 | } 51 | """ 52 | for tool_name, tool_config in self.config.items(): 53 | # Skip if not enabled 54 | if not tool_config.get("enabled", False): 55 | continue 56 | 57 | # Check if it's a valid Google native tool 58 | if tool_name not in GOOGLE_NATIVE_TOOLS: 59 | continue 60 | 61 | try: 62 | # Instantiate the tool with its config 63 | tool_instance = get_google_native_tool(tool_name, tool_config) 64 | self._native_tools[tool_name] = tool_instance 65 | except Exception: 66 | # Silently skip tools that fail to initialize 67 | pass 68 | 69 | def get_native_tools(self) -> List[Tool]: 70 | """Get all loaded native tools. 71 | 72 | Returns: 73 | List of native Tool instances. 74 | """ 75 | return list(self._native_tools.values()) 76 | 77 | def get_native_tool(self, name: str) -> Optional[Tool]: 78 | """Get a specific native tool by name. 79 | 80 | Args: 81 | name: Tool name. 82 | 83 | Returns: 84 | Tool instance if found, None otherwise. 85 | """ 86 | return self._native_tools.get(name) 87 | 88 | def has_native_tool(self, name: str) -> bool: 89 | """Check if a native tool is loaded. 90 | 91 | Args: 92 | name: Tool name. 93 | 94 | Returns: 95 | True if tool is loaded, False otherwise. 96 | """ 97 | return name in self._native_tools 98 | 99 | def merge_with_user_tools( 100 | self, user_tools: List[Tool], prefer_native: bool = True 101 | ) -> List[Tool]: 102 | """Merge native tools with user-defined tools. 103 | 104 | Args: 105 | user_tools: User-defined tools. 106 | prefer_native: If True, native tools override user tools with same name. 107 | If False, user tools take precedence. 108 | 109 | Returns: 110 | Combined list of tools. 111 | """ 112 | # Create a dict to track tools by qualified name (namespace:name) 113 | merged: Dict[str, Tool] = {} 114 | 115 | # Add native tools first if prefer_native, else user tools first 116 | first_tools = self._native_tools.values() if prefer_native else user_tools 117 | second_tools = user_tools if prefer_native else self._native_tools.values() 118 | 119 | # Add first set of tools 120 | for tool in first_tools: 121 | key = f"{tool.namespace}:{tool.name}" if tool.namespace else tool.name 122 | merged[key] = tool 123 | 124 | # Add second set, respecting the preference order 125 | for tool in second_tools: 126 | key = f"{tool.namespace}:{tool.name}" if tool.namespace else tool.name 127 | if key not in merged: 128 | merged[key] = tool 129 | 130 | return list(merged.values()) 131 | 132 | def get_api_format_for_google(self) -> List[Dict[str, Any]]: 133 | """Get Google-specific API format for native tools. 134 | 135 | Google native tools use a different format than regular function declarations. 136 | They're added to the "tools" array with special keys like "googleSearch" or 137 | "codeExecution". 138 | 139 | Returns: 140 | List of native tool dicts in Google API format. 
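
        Example output (illustrative; assumes google_search and
        code_execution are both enabled):

            [{"googleSearch": {}}, {"codeExecution": {}}]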
141 | """ 142 | api_tools = [] 143 | 144 | for tool in self._native_tools.values(): 145 | if hasattr(tool, "to_api_format"): 146 | api_tools.append(tool.to_api_format()) 147 | 148 | return api_tools 149 | 150 | def __len__(self) -> int: 151 | """Return number of loaded native tools.""" 152 | return len(self._native_tools) 153 | 154 | def __bool__(self) -> bool: 155 | """Return True if any native tools are loaded.""" 156 | return len(self._native_tools) > 0 157 | 158 | 159 | class GoogleNativeToolManager(NativeToolManager): 160 | """Specialized manager for Google native tools. 161 | 162 | Provides Google-specific functionality for native tool management. 163 | """ 164 | 165 | def __init__(self, config: Optional[Dict[str, Any]] = None): 166 | """Initialize Google native tool manager. 167 | 168 | Args: 169 | config: Google native_tools configuration. 170 | """ 171 | super().__init__(provider="google", config=config) 172 | 173 | def format_for_gemini_api( 174 | self, user_tools: Optional[List[Tool]] = None 175 | ) -> List[Dict[str, Any]]: 176 | """Format tools for Gemini API request. 177 | 178 | Gemini requires native tools and function declarations in separate formats 179 | within the same "tools" array. 180 | 181 | Args: 182 | user_tools: Optional user-defined tools to include. 183 | 184 | Returns: 185 | List of formatted tool definitions for Gemini API. 186 | """ 187 | tools_array: List[Dict[str, Any]] = [] 188 | 189 | # Add native tools in their special format 190 | native_tools_api = self.get_api_format_for_google() 191 | if native_tools_api: 192 | # Native tools go directly in the tools array 193 | tools_array.extend(native_tools_api) 194 | 195 | # Add user-defined tools as function declarations 196 | if user_tools: 197 | from justllms.tools.adapters.google import GoogleToolAdapter 198 | 199 | adapter = GoogleToolAdapter() 200 | user_tools_formatted = adapter.format_tools_for_api(user_tools) 201 | 202 | # User tools are wrapped in functionDeclarations 203 | if user_tools_formatted: 204 | tools_array.extend(user_tools_formatted) 205 | 206 | return tools_array 207 | 208 | 209 | def create_native_tool_manager( 210 | provider: str, config: Optional[Dict[str, Any]] = None 211 | ) -> Optional[NativeToolManager]: 212 | """Factory function to create the appropriate native tool manager. 213 | 214 | Args: 215 | provider: Provider name (e.g., "google", "openai"). 216 | config: Provider's native_tools configuration. 217 | 218 | Returns: 219 | NativeToolManager instance if provider supports native tools, None otherwise. 220 | """ 221 | if provider == "google": 222 | return GoogleNativeToolManager(config=config) 223 | 224 | # Add more providers as they support native tools 225 | # elif provider == "openai": 226 | # return OpenAINativeToolManager(config=config) 227 | 228 | return None 229 | -------------------------------------------------------------------------------- /justllms/tools/adapters/openai.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Any, Dict, List, Optional, Union 3 | 4 | from justllms.core.models import Message, Role 5 | from justllms.tools.adapters.base import BaseToolAdapter 6 | from justllms.tools.models import Tool, ToolCall, ToolResult 7 | 8 | 9 | class OpenAIToolAdapter(BaseToolAdapter): 10 | """Adapter for OpenAI's function calling format. 
11 | 12 | OpenAI uses a specific format for function definitions and tool calls: 13 | - Functions are wrapped in {"type": "function", "function": {...}} 14 | - Tool choice can be "auto", "none", "required", or specific function 15 | - Tool calls are returned in the response choices 16 | 17 | This adapter is also used by Azure OpenAI since they share the same API. 18 | """ 19 | 20 | def format_tools_for_api(self, tools: List[Tool]) -> List[Dict[str, Any]]: 21 | """Convert Tool objects to OpenAI's function format. 22 | 23 | Args: 24 | tools: List of Tool instances. 25 | 26 | Returns: 27 | List of function definitions in OpenAI format. 28 | 29 | Example output: 30 | [ 31 | { 32 | "type": "function", 33 | "function": { 34 | "name": "get_weather", 35 | "description": "Get weather for location", 36 | "parameters": { 37 | "type": "object", 38 | "properties": { 39 | "location": {"type": "string"}, 40 | "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]} 41 | }, 42 | "required": ["location"] 43 | } 44 | } 45 | } 46 | ] 47 | """ 48 | formatted_tools = [] 49 | 50 | for tool in tools: 51 | # Skip native tools (OpenAI doesn't have native tools) 52 | if tool.is_native: 53 | continue 54 | 55 | function_def = { 56 | "type": "function", 57 | "function": { 58 | "name": tool.name, 59 | "description": tool.description, 60 | "parameters": tool.to_json_schema(), 61 | }, 62 | } 63 | 64 | formatted_tools.append(function_def) 65 | 66 | return formatted_tools 67 | 68 | def format_tool_choice( 69 | self, tool_choice: Optional[Union[str, Dict[str, Any]]] 70 | ) -> Optional[Union[str, Dict[str, Any]]]: 71 | """Format tool_choice for OpenAI API. 72 | 73 | Args: 74 | tool_choice: Tool selection strategy: 75 | - "auto": Model decides whether to use tools 76 | - "none": Model will not call functions 77 | - "required": Model must call at least one function 78 | - {"name": "function_name"}: Force specific function 79 | - {"type": "function", "function": {"name": "..."}}: Full format 80 | - None: Use default (auto) 81 | 82 | Returns: 83 | Formatted tool_choice for OpenAI API. 84 | """ 85 | if tool_choice is None: 86 | return "auto" 87 | 88 | # Handle string choices 89 | if isinstance(tool_choice, str): 90 | if tool_choice in ("auto", "none", "required"): 91 | return tool_choice 92 | # If it's a function name as string, convert to dict 93 | return {"type": "function", "function": {"name": tool_choice}} 94 | 95 | # Handle dict choices 96 | if isinstance(tool_choice, dict): 97 | # Check if it's already in full format 98 | if "type" in tool_choice and tool_choice["type"] == "function": 99 | return tool_choice 100 | 101 | # Check if it's simplified format {"name": "function_name"} 102 | if "name" in tool_choice: 103 | return {"type": "function", "function": {"name": tool_choice["name"]}} 104 | 105 | # Default to auto if format not recognized 106 | return "auto" 107 | 108 | def extract_tool_calls(self, response: Dict[str, Any]) -> List[ToolCall]: 109 | """Extract tool calls from OpenAI response. 110 | 111 | Args: 112 | response: Raw response from OpenAI API. 113 | 114 | Returns: 115 | List of ToolCall objects. 
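
        Example (sketch of the response shape this method parses; the id,
        name, and arguments are hypothetical):

            {"choices": [{"message": {"tool_calls": [{
                "id": "call_abc123",
                "function": {"name": "get_weather",
                             "arguments": "{\"location\": \"SF\"}"}}]}}]}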
116 | """ 117 | tool_calls = [] 118 | 119 | # Check choices for tool calls 120 | choices = response.get("choices", []) 121 | for choice in choices: 122 | message = choice.get("message", {}) 123 | 124 | # Check for tool_calls in message 125 | message_tool_calls = message.get("tool_calls", []) 126 | for tc in message_tool_calls: 127 | function = tc.get("function", {}) 128 | 129 | # Parse arguments 130 | arguments_str = function.get("arguments", "{}") 131 | try: 132 | arguments = json.loads(arguments_str) 133 | except json.JSONDecodeError: 134 | arguments = {} 135 | 136 | tool_call = ToolCall( 137 | id=tc.get("id", ""), 138 | name=function.get("name", ""), 139 | arguments=arguments, 140 | raw_arguments=arguments_str, 141 | ) 142 | tool_calls.append(tool_call) 143 | 144 | # Also check for legacy function_call format 145 | function_call = message.get("function_call") 146 | if function_call: 147 | arguments_str = function_call.get("arguments", "{}") 148 | try: 149 | arguments = json.loads(arguments_str) 150 | except json.JSONDecodeError: 151 | arguments = {} 152 | 153 | tool_call = ToolCall( 154 | id=f"call_{function_call.get('name', 'unknown')}", 155 | name=function_call.get("name", ""), 156 | arguments=arguments, 157 | raw_arguments=arguments_str, 158 | ) 159 | tool_calls.append(tool_call) 160 | 161 | return tool_calls 162 | 163 | def format_tool_result_message(self, tool_result: ToolResult, tool_call: ToolCall) -> Message: 164 | """Format tool result as a message for OpenAI. 165 | 166 | OpenAI expects tool results as messages with role="tool". 167 | 168 | Args: 169 | tool_result: Result from tool execution. 170 | tool_call: Original tool call. 171 | 172 | Returns: 173 | Message with role="tool" containing the result. 174 | """ 175 | content = tool_result.to_message_content() 176 | 177 | return Message( 178 | role=Role.TOOL, 179 | content=content, 180 | name=tool_call.name, 181 | tool_call_id=tool_call.id, # OpenAI requires matching the tool call ID 182 | ) 183 | 184 | def format_tool_calls_message(self, tool_calls: List[ToolCall]) -> Optional[Message]: 185 | """Format tool calls as an assistant message. 186 | 187 | OpenAI requires tool calls to be in an assistant message. 188 | 189 | Args: 190 | tool_calls: List of tool calls. 191 | 192 | Returns: 193 | Assistant message with tool calls. 
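
        Example output (illustrative; the call id and function name are
        hypothetical):

            Message(
                role=Role.ASSISTANT,
                content="",
                tool_calls=[{
                    "id": "call_abc123",
                    "type": "function",
                    "function": {"name": "get_weather",
                                 "arguments": "{\"location\": \"SF\"}"},
                }],
            )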
194 | """ 195 | if not tool_calls: 196 | return None 197 | 198 | tool_calls_data = [] 199 | for tc in tool_calls: 200 | tool_call_dict = { 201 | "id": tc.id, 202 | "type": "function", 203 | "function": { 204 | "name": tc.name, 205 | "arguments": tc.raw_arguments or json.dumps(tc.arguments), 206 | }, 207 | } 208 | tool_calls_data.append(tool_call_dict) 209 | 210 | return Message( 211 | role=Role.ASSISTANT, 212 | content="", # OpenAI tool calls have empty content 213 | tool_calls=tool_calls_data, 214 | ) 215 | 216 | def supports_parallel_tools(self) -> bool: 217 | """OpenAI supports calling multiple tools in parallel.""" 218 | return True 219 | 220 | def supports_required_tools(self) -> bool: 221 | """OpenAI supports the 'required' tool_choice option.""" 222 | return True 223 | 224 | def get_max_tools_per_call(self) -> Optional[int]: 225 | """OpenAI has a limit of 128 functions per request.""" 226 | return 128 227 | -------------------------------------------------------------------------------- /justllms/tools/utils.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from typing import TYPE_CHECKING, Any, Callable, Dict, List, Union, get_args, get_origin 3 | 4 | from justllms.tools.models import ParameterInfo 5 | 6 | if TYPE_CHECKING: 7 | from justllms.tools.models import Tool 8 | 9 | 10 | def python_type_to_json_schema(python_type: type) -> Dict[str, Any]: 11 | """Convert Python type hints to JSON Schema types. 12 | 13 | Args: 14 | python_type: Python type or type hint. 15 | 16 | Returns: 17 | JSON Schema type definition. 18 | """ 19 | # Handle None type 20 | if python_type is type(None) or python_type is None: 21 | return {"type": "null"} 22 | 23 | # Handle basic types 24 | if python_type is str: 25 | return {"type": "string"} 26 | elif python_type is int: 27 | return {"type": "integer"} 28 | elif python_type is float: 29 | return {"type": "number"} 30 | elif python_type is bool: 31 | return {"type": "boolean"} 32 | elif python_type is dict or python_type is Dict: 33 | return {"type": "object"} 34 | 35 | # Handle typing module types 36 | origin = get_origin(python_type) 37 | args = get_args(python_type) 38 | 39 | # Handle Optional[T] (Union[T, None]) 40 | if origin is Union: 41 | non_none_args = [arg for arg in args if arg is not type(None)] 42 | if len(non_none_args) == 1: 43 | # Optional[T] 44 | schema = python_type_to_json_schema(non_none_args[0]) 45 | # Don't mark as required in the parent schema 46 | return schema 47 | else: 48 | # Union of multiple types 49 | return {"oneOf": [python_type_to_json_schema(arg) for arg in args]} 50 | 51 | # Handle List[T] 52 | elif origin in (list, List): 53 | items_type = args[0] if args else Any 54 | return { 55 | "type": "array", 56 | "items": python_type_to_json_schema(items_type), 57 | } 58 | 59 | # Handle Dict[K, V] 60 | elif origin in (dict, Dict): 61 | return {"type": "object"} 62 | 63 | # Default to string for unknown types 64 | return {"type": "string"} 65 | 66 | 67 | def extract_function_schema(func: Callable) -> Dict[str, ParameterInfo]: 68 | """Extract parameter information from a function signature. 69 | 70 | Args: 71 | func: The function to introspect. 72 | 73 | Returns: 74 | Dictionary mapping parameter names to ParameterInfo objects. 
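
    Example (illustrative):

        def greet(name: str, times: int = 1) -> str:
            ...

        params = extract_function_schema(greet)
        # params["name"].required is True
        # params["times"].required is False and params["times"].default == 1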
75 | """ 76 | sig = inspect.signature(func) 77 | parameters = {} 78 | 79 | for param_name, param in sig.parameters.items(): 80 | # Skip self/cls parameters 81 | if param_name in ("self", "cls"): 82 | continue 83 | 84 | # Determine if parameter is required 85 | required = param.default is inspect.Parameter.empty 86 | 87 | # Get type hint 88 | param_type = param.annotation if param.annotation != inspect.Parameter.empty else Any 89 | 90 | # Convert to JSON schema type 91 | schema_info = python_type_to_json_schema(param_type) 92 | 93 | # Handle Optional types - they're not required 94 | if get_origin(param_type) is Union: 95 | args = get_args(param_type) 96 | if type(None) in args: 97 | required = False 98 | 99 | param_info = ParameterInfo( 100 | name=param_name, 101 | type=schema_info.get("type", "string"), 102 | required=required, 103 | default=None if param.default is inspect.Parameter.empty else param.default, 104 | ) 105 | 106 | # Add additional schema properties 107 | if "items" in schema_info: 108 | param_info.items = schema_info["items"] 109 | if "properties" in schema_info: 110 | param_info.properties = schema_info["properties"] 111 | 112 | parameters[param_name] = param_info 113 | 114 | return parameters 115 | 116 | 117 | def extract_docstring_descriptions(func: Callable) -> Dict[str, str]: 118 | """Extract parameter descriptions from function docstring. 119 | 120 | Supports Google-style docstrings: 121 | Args: 122 | param_name: Description of parameter. 123 | 124 | Args: 125 | func: Function to extract docstring from. 126 | 127 | Returns: 128 | Dictionary mapping parameter names to descriptions. 129 | """ 130 | descriptions: Dict[str, str] = {} 131 | docstring = inspect.getdoc(func) 132 | 133 | if not docstring: 134 | return descriptions 135 | 136 | lines = docstring.split("\n") 137 | in_args_section = False 138 | current_param = None 139 | current_desc_lines = [] 140 | 141 | for line in lines: 142 | stripped = line.strip() 143 | 144 | # Check if we're entering Args section 145 | if stripped in ("Args:", "Arguments:", "Parameters:"): 146 | in_args_section = True 147 | continue 148 | 149 | # Check if we're leaving Args section 150 | if ( 151 | in_args_section 152 | and stripped 153 | and not line.startswith((" ", "\t")) 154 | and stripped.endswith(":") 155 | ): 156 | in_args_section = False 157 | 158 | if in_args_section: 159 | # Check if this is a parameter definition (has : after the param name) 160 | if ":" in stripped and line.startswith((" " * 4, "\t")): 161 | # Save previous parameter if exists 162 | if current_param and current_desc_lines: 163 | descriptions[current_param] = " ".join(current_desc_lines).strip() 164 | 165 | # Parse new parameter 166 | param_part, desc_part = stripped.split(":", 1) 167 | current_param = param_part.strip() 168 | 169 | # Handle type hints in docstring (param_name (type): description) 170 | if "(" in current_param and ")" in current_param: 171 | current_param = current_param.split("(")[0].strip() 172 | 173 | current_desc_lines = [desc_part.strip()] if desc_part.strip() else [] 174 | 175 | elif current_param and stripped: 176 | # Continuation of previous parameter description 177 | current_desc_lines.append(stripped) 178 | 179 | # Save last parameter 180 | if current_param and current_desc_lines: 181 | descriptions[current_param] = " ".join(current_desc_lines).strip() 182 | 183 | return descriptions 184 | 185 | 186 | def validate_tool_arguments(tool: "Tool", arguments: Dict[str, Any]) -> Dict[str, Any]: 187 | """Validate and coerce arguments for a 
tool. 188 | 189 | Args: 190 | tool: The Tool instance. 191 | arguments: Arguments to validate. 192 | 193 | Returns: 194 | Validated and coerced arguments. 195 | 196 | Raises: 197 | ValueError: If required arguments are missing or invalid. 198 | """ 199 | 200 | validated = {} 201 | 202 | # Check required parameters 203 | for param_name, param_info in tool.parameters.items(): 204 | if param_info.required and param_name not in arguments: 205 | raise ValueError(f"Missing required parameter: {param_name}") 206 | 207 | if param_name in arguments: 208 | value = arguments[param_name] 209 | 210 | # Basic type coercion 211 | if param_info.type == "integer" and not isinstance(value, int): 212 | try: 213 | value = int(value) 214 | except (TypeError, ValueError): 215 | raise ValueError(f"Parameter {param_name} must be an integer") from None 216 | 217 | elif param_info.type == "number" and not isinstance(value, (int, float)): 218 | try: 219 | value = float(value) 220 | except (TypeError, ValueError): 221 | raise ValueError(f"Parameter {param_name} must be a number") from None 222 | 223 | elif param_info.type == "boolean" and not isinstance(value, bool): 224 | if isinstance(value, str): 225 | value = value.lower() in ("true", "1", "yes") 226 | else: 227 | value = bool(value) 228 | 229 | elif param_info.type == "array" and not isinstance(value, list): 230 | raise ValueError(f"Parameter {param_name} must be an array") 231 | 232 | elif param_info.type == "object" and not isinstance(value, dict): 233 | raise ValueError(f"Parameter {param_name} must be an object") 234 | 235 | validated[param_name] = value 236 | elif param_info.default is not None: 237 | # Use default value 238 | validated[param_name] = param_info.default 239 | 240 | return validated 241 | -------------------------------------------------------------------------------- /justllms/providers/grok.py: -------------------------------------------------------------------------------- 1 | import time 2 | from typing import Any, Dict, List, Optional 3 | 4 | import httpx 5 | from tenacity import retry, stop_after_attempt, wait_exponential 6 | 7 | from justllms.core.base import DEFAULT_TIMEOUT, BaseProvider, BaseResponse 8 | from justllms.core.models import Choice, Message, ModelInfo, Usage 9 | from justllms.exceptions import ProviderError 10 | 11 | 12 | class GrokResponse(BaseResponse): 13 | """Grok-specific response implementation.""" 14 | 15 | pass 16 | 17 | 18 | class GrokProvider(BaseProvider): 19 | """Grok provider implementation.""" 20 | 21 | MODELS = { 22 | "grok-4": ModelInfo( 23 | name="grok-4", 24 | provider="grok", 25 | max_tokens=32768, 26 | max_context_length=130000, 27 | supports_functions=True, 28 | supports_vision=True, 29 | cost_per_1k_prompt_tokens=6.0, 30 | cost_per_1k_completion_tokens=30.0, 31 | tags=["flagship", "most-intelligent", "multimodal", "coding", "latest"], 32 | ), 33 | "grok-4-heavy": ModelInfo( 34 | name="grok-4-heavy", 35 | provider="grok", 36 | max_tokens=32768, 37 | max_context_length=130000, 38 | supports_functions=True, 39 | supports_vision=True, 40 | cost_per_1k_prompt_tokens=8.0, 41 | cost_per_1k_completion_tokens=40.0, 42 | tags=["heavy", "premium", "exclusive", "multimodal"], 43 | ), 44 | "grok-3": ModelInfo( 45 | name="grok-3", 46 | provider="grok", 47 | max_tokens=32768, 48 | max_context_length=131072, 49 | supports_functions=True, 50 | supports_vision=False, 51 | cost_per_1k_prompt_tokens=3.0, 52 | cost_per_1k_completion_tokens=15.0, 53 | tags=["advanced", "reasoning", "long-context"], 54 | ), 55 | 
"grok-3-speedy": ModelInfo( 56 | name="grok-3-speedy", 57 | provider="grok", 58 | max_tokens=32768, 59 | max_context_length=131072, 60 | supports_functions=True, 61 | supports_vision=False, 62 | cost_per_1k_prompt_tokens=5.0, 63 | cost_per_1k_completion_tokens=25.0, 64 | tags=["speedy", "premium", "fast"], 65 | ), 66 | "grok-3-mini": ModelInfo( 67 | name="grok-3-mini", 68 | provider="grok", 69 | max_tokens=16384, 70 | max_context_length=131072, 71 | supports_functions=True, 72 | supports_vision=False, 73 | cost_per_1k_prompt_tokens=0.3, 74 | cost_per_1k_completion_tokens=0.5, 75 | tags=["mini", "affordable", "efficient"], 76 | ), 77 | "grok-3-mini-speedy": ModelInfo( 78 | name="grok-3-mini-speedy", 79 | provider="grok", 80 | max_tokens=16384, 81 | max_context_length=131072, 82 | supports_functions=True, 83 | supports_vision=False, 84 | cost_per_1k_prompt_tokens=0.6, 85 | cost_per_1k_completion_tokens=4.0, 86 | tags=["mini", "speedy", "fast", "affordable"], 87 | ), 88 | } 89 | 90 | @property 91 | def name(self) -> str: 92 | return "grok" 93 | 94 | def get_available_models(self) -> Dict[str, ModelInfo]: 95 | return self.MODELS.copy() 96 | 97 | def _get_api_endpoint(self) -> str: 98 | """Get the API endpoint.""" 99 | base_url = self.config.api_base or "https://api.x.ai" 100 | return f"{base_url}/v1/chat/completions" 101 | 102 | def _format_messages(self, messages: List[Message]) -> List[Dict[str, Any]]: 103 | """Format messages for Grok API (OpenAI-compatible format).""" 104 | formatted_messages = [] 105 | 106 | for msg in messages: 107 | formatted_msg: Dict[str, Any] = {"role": msg.role.value, "content": msg.content} 108 | 109 | # Handle multimodal content if supported 110 | if isinstance(msg.content, list): 111 | content_list: List[Dict[str, Any]] = [] 112 | for item in msg.content: 113 | if isinstance(item, dict): 114 | if item.get("type") == "text": 115 | content_list.append({"type": "text", "text": item.get("text", "")}) 116 | elif item.get("type") == "image": 117 | content_list.append( 118 | {"type": "image_url", "image_url": item.get("image", {})} 119 | ) 120 | formatted_msg["content"] = content_list 121 | 122 | formatted_messages.append(formatted_msg) 123 | 124 | return formatted_messages 125 | 126 | def _get_headers(self) -> Dict[str, str]: 127 | """Get request headers.""" 128 | return { 129 | "Authorization": f"Bearer {self.config.api_key}", 130 | "Content-Type": "application/json", 131 | } 132 | 133 | def _parse_response(self, response_data: Dict[str, Any], model: str) -> GrokResponse: 134 | """Parse Grok API response.""" 135 | choices_data = response_data.get("choices", []) 136 | 137 | if not choices_data: 138 | raise ProviderError("No choices in Grok response") 139 | 140 | # Parse choices 141 | choices = [] 142 | for choice_data in choices_data: 143 | message_data = choice_data.get("message", {}) 144 | message = Message( 145 | role=message_data.get("role", "assistant"), 146 | content=message_data.get("content", ""), 147 | ) 148 | choice = Choice( 149 | index=choice_data.get("index", 0), 150 | message=message, 151 | finish_reason=choice_data.get("finish_reason", "stop"), 152 | ) 153 | choices.append(choice) 154 | 155 | # Parse usage 156 | usage_data = response_data.get("usage", {}) 157 | usage = Usage( 158 | prompt_tokens=usage_data.get("prompt_tokens", 0), 159 | completion_tokens=usage_data.get("completion_tokens", 0), 160 | total_tokens=usage_data.get("total_tokens", 0), 161 | ) 162 | 163 | # Extract only the keys we want to avoid conflicts 164 | raw_response = { 165 | k: v 166 
|             for k, v in response_data.items()
167 |             if k not in ["id", "model", "choices", "usage", "created"]
168 |         }
169 | 
170 |         return GrokResponse(
171 |             id=response_data.get("id", f"grok-{int(time.time())}"),
172 |             model=model,
173 |             choices=choices,
174 |             usage=usage,
175 |             created=response_data.get("created", int(time.time())),
176 |             **raw_response,
177 |         )
178 | 
179 |     @retry(
180 |         stop=stop_after_attempt(3),
181 |         wait=wait_exponential(multiplier=1, min=4, max=10),
182 |     )
183 |     def complete(
184 |         self,
185 |         messages: List[Message],
186 |         model: str,
187 |         timeout: Optional[float] = None,
188 |         **kwargs: Any,
189 |     ) -> BaseResponse:
190 |         """Synchronous completion.
191 | 
192 |         Args:
193 |             messages: List of messages for the completion.
194 |             model: Model identifier to use.
195 |             timeout: Optional timeout in seconds. If None, falls back to DEFAULT_TIMEOUT.
196 |             **kwargs: Additional provider-specific parameters.
197 |         """
198 |         url = self._get_api_endpoint()
199 | 
200 |         # Format request
201 |         request_data = {
202 |             "model": model,
203 |             "messages": self._format_messages(messages),
204 |             **{
205 |                 k: v
206 |                 for k, v in kwargs.items()
207 |                 if k
208 |                 in [
209 |                     "temperature",
210 |                     "max_tokens",
211 |                     "top_p",
212 |                     "frequency_penalty",
213 |                     "presence_penalty",
214 |                     "stop",
215 |                 ]
216 |                 and v is not None
217 |             },
218 |         }
219 | 
220 |         timeout_config = timeout if timeout is not None else DEFAULT_TIMEOUT
221 | 
222 |         with httpx.Client(timeout=timeout_config) as client:
223 |             response = client.post(
224 |                 url,
225 |                 json=request_data,
226 |                 headers=self._get_headers(),
227 |             )
228 | 
229 |             if response.status_code != 200:
230 |                 raise ProviderError(f"Grok API error: {response.status_code} - {response.text}")
231 | 
232 |             return self._parse_response(response.json(), model)
233 | 
--------------------------------------------------------------------------------
/justllms/routing/router.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List, Optional, Tuple, Union
2 | 
3 | from justllms.core.base import BaseProvider
4 | from justllms.core.models import Message
5 | from justllms.exceptions import ProviderError
6 | 
7 | 
8 | class Router:
9 |     """Simple provider and model selector with fallback support.
10 | 
11 |     Handles model selection logic:
12 |     1. If model explicitly specified (e.g., "provider/model"), use it
13 |     2. Else if fallback configured, use fallback
14 |     3. Else use first available provider/model
15 |     """
16 | 
17 |     def __init__(
18 |         self,
19 |         config: Optional[Union[Dict[str, Any], Any]] = None,
20 |         fallback_provider: Optional[str] = None,
21 |         fallback_model: Optional[str] = None,
22 |     ):
23 |         """Initialize the router.
24 | 
25 |         Args:
26 |             config: Optional config dict or RoutingConfig object.
27 |             fallback_provider: Optional fallback provider name.
28 |             fallback_model: Optional fallback model name.
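
        Example (illustrative; the provider and model names are placeholders):

            router = Router(
                fallback_provider="openai",
                fallback_model="gpt-4o",
            )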
29 | """ 30 | # Handle both dict and RoutingConfig object 31 | if config is not None and hasattr(config, "model_dump"): 32 | # It's a Pydantic model, convert to dict 33 | self.config = config.model_dump() 34 | else: 35 | self.config = config or {} 36 | 37 | # Get fallback values from config if not provided 38 | self.fallback_provider = fallback_provider or self.config.get("fallback_provider") 39 | self.fallback_model = fallback_model or self.config.get("fallback_model") 40 | 41 | def route( # noqa: C901 42 | self, 43 | messages: List[Message], 44 | model: Optional[str] = None, 45 | providers: Optional[Dict[str, BaseProvider]] = None, 46 | constraints: Optional[Dict[str, Any]] = None, 47 | **kwargs: Any, 48 | ) -> Tuple[str, str]: 49 | """Select provider and model using fallback logic. 50 | 51 | Args: 52 | messages: The messages to process (unused in selection). 53 | model: Optional specific model requested. 54 | providers: Available providers. 55 | constraints: Additional constraints (unused in selection). 56 | **kwargs: Additional parameters. 57 | 58 | Returns: 59 | Tuple of (provider_name, model_name) 60 | 61 | Raises: 62 | ValueError: If no providers or suitable models available. 63 | """ 64 | if not providers: 65 | raise ValueError("No providers available") 66 | 67 | # If specific model requested, try to find it 68 | if model: 69 | # Check if it's in format "provider/model" 70 | if "/" in model: 71 | provider_name, model_name = model.split("/", 1) 72 | if provider_name not in providers: 73 | raise ValueError(f"Provider '{provider_name}' not found") 74 | 75 | provider = providers[provider_name] 76 | if not provider.validate_model(model_name): 77 | raise ValueError( 78 | f"Model '{model_name}' not found in provider '{provider_name}'" 79 | ) 80 | 81 | return provider_name, model_name 82 | 83 | # Check all providers for the model 84 | for provider_name, provider in providers.items(): 85 | if provider.validate_model(model): 86 | return provider_name, model 87 | 88 | raise ValueError(f"Model '{model}' not found in any available provider") 89 | 90 | # No specific model requested - use fallback or first available 91 | # First, try configured fallback if provided 92 | if self.fallback_provider and self.fallback_model and self.fallback_provider in providers: 93 | provider = providers[self.fallback_provider] 94 | available_models = provider.get_available_models() 95 | if self.fallback_model in available_models: 96 | return self.fallback_provider, self.fallback_model 97 | 98 | # Fall back to first available provider and model 99 | for provider_name, provider in providers.items(): 100 | models = provider.get_available_models() 101 | if models: 102 | model_name = list(models.keys())[0] 103 | return provider_name, model_name 104 | 105 | raise ValueError("No models available in any provider") 106 | 107 | def route_streaming( 108 | self, 109 | messages: List[Message], 110 | providers: Dict[str, BaseProvider], 111 | model: Optional[str] = None, 112 | constraints: Optional[Dict[str, Any]] = None, 113 | **kwargs: Any, 114 | ) -> Tuple[str, str]: 115 | """Select provider and model for streaming requests. 116 | 117 | Filters providers to only those supporting streaming before selection. 118 | 119 | Args: 120 | messages: The messages to process. 121 | providers: Available providers. 122 | model: Optional specific model requested. 123 | constraints: Additional constraints for routing. 124 | **kwargs: Additional parameters. 125 | 126 | Returns: 127 | Tuple of (provider_name, model_name). 
128 | 
129 |         Raises:
130 |             ProviderError: If no streaming-capable providers are available,
131 |                 or the selected model does not support streaming.
132 |         """
133 |         # Filter to streaming-capable providers
134 |         streaming_providers = {
135 |             name: provider for name, provider in providers.items() if provider.supports_streaming()
136 |         }
137 | 
138 |         if not streaming_providers:
139 |             raise ProviderError(
140 |                 "No streaming-capable providers configured. "
141 |                 "Enable openai, azure_openai, google, or ollama providers, or set stream=False."
142 |             )
143 | 
144 |         # If specific model requested, validate streaming support
145 |         if model:
146 |             # Use normal route logic first
147 |             provider_name, model_name = self.route(
148 |                 messages,
149 |                 model=model,
150 |                 providers=streaming_providers,
151 |                 constraints=constraints,
152 |                 **kwargs,
153 |             )
154 | 
155 |             # Validate model supports streaming
156 |             provider = streaming_providers[provider_name]
157 |             if not provider.supports_streaming_for_model(model_name):
158 |                 raise ProviderError(
159 |                     f"Model '{model_name}' does not support streaming on provider '{provider_name}'. "
160 |                     f"Use stream=False or choose a different model."
161 |                 )
162 | 
163 |             return provider_name, model_name
164 | 
165 |         # Route among streaming providers
166 |         provider_name, model_name = self.route(
167 |             messages, providers=streaming_providers, constraints=constraints, **kwargs
168 |         )
169 | 
170 |         # Double-check model supports streaming
171 |         provider = streaming_providers[provider_name]
172 |         if not provider.supports_streaming_for_model(model_name):
173 |             # Try to find another model from this provider that supports streaming
174 |             for available_model in provider.get_available_models():
175 |                 if provider.supports_streaming_for_model(available_model):
176 |                     return provider_name, available_model
177 | 
178 |             raise ProviderError(
179 |                 f"Model '{model_name}' does not support streaming on provider '{provider_name}'. "
180 |                 f"Use stream=False or choose a different model."
181 |             )
182 | 
183 |         return provider_name, model_name
184 | 
185 |     def route_with_tools(
186 |         self,
187 |         messages: List[Message],
188 |         providers: Dict[str, BaseProvider],
189 |         model: Optional[str] = None,
190 |         constraints: Optional[Dict[str, Any]] = None,
191 |         **kwargs: Any,
192 |     ) -> Tuple[str, str]:
193 |         """Select provider and model for tool calling requests.
194 | 
195 |         Filters providers to only those supporting tools before selection.
196 | 
197 |         Args:
198 |             messages: The messages to process.
199 |             providers: Available providers.
200 |             model: Optional specific model requested.
201 |             constraints: Additional constraints for routing.
202 |             **kwargs: Additional parameters.
203 | 
204 |         Returns:
205 |             Tuple of (provider_name, model_name).
206 | 
207 |         Raises:
208 |             ProviderError: If no tool-capable providers are available.
209 |         """
210 |         # Filter to tool-capable providers
211 |         tool_providers = {
212 |             name: provider for name, provider in providers.items() if provider.supports_tools
213 |         }
214 | 
215 |         if not tool_providers:
216 |             raise ProviderError(
217 |                 "No tool-capable providers configured. "
218 |                 "Enable openai, anthropic, google, or azure_openai providers."
218 | ) 219 | 220 | # If specific model requested, route to it 221 | if model: 222 | provider_name, model_name = self.route( 223 | messages, 224 | model=model, 225 | providers=tool_providers, 226 | constraints=constraints, 227 | **kwargs, 228 | ) 229 | return provider_name, model_name 230 | 231 | # Route among tool-capable providers 232 | provider_name, model_name = self.route( 233 | messages, providers=tool_providers, constraints=constraints, **kwargs 234 | ) 235 | 236 | return provider_name, model_name 237 | -------------------------------------------------------------------------------- /justllms/tools/executor.py: -------------------------------------------------------------------------------- 1 | import json 2 | import threading 3 | import time 4 | from typing import Any, Dict, List, Optional 5 | 6 | from justllms.core.base import BaseResponse 7 | from justllms.tools.models import Tool, ToolCall, ToolExecutionEntry, ToolResult, ToolResultStatus 8 | from justllms.tools.utils import validate_tool_arguments 9 | 10 | 11 | class ToolExecutor: 12 | """Executes tools sequentially with error handling. 13 | 14 | This executor handles tool validation, argument parsing, execution, 15 | and error recovery. It does NOT support parallel execution - all 16 | tools run sequentially. 17 | 18 | Attributes: 19 | tools: Dictionary mapping tool names to Tool instances. 20 | timeout: Maximum execution time per tool in seconds. 21 | execute_in_parallel: Always False (no parallel execution). 22 | """ 23 | 24 | def __init__( 25 | self, 26 | tools: List[Tool], 27 | execute_in_parallel: bool = False, 28 | timeout: float = 30.0, 29 | ): 30 | """Initialize the tool executor. 31 | 32 | Args: 33 | tools: List of Tool instances available for execution. 34 | execute_in_parallel: Ignored (always False, no parallel support). 35 | timeout: Maximum execution time per tool in seconds. 36 | """ 37 | self.tools = {tool.name: tool for tool in tools} 38 | self.execute_in_parallel = False # Always False per requirements 39 | self.timeout = timeout 40 | self._execution_count = 0 41 | 42 | def execute_tool_call(self, tool_call: ToolCall) -> ToolResult: 43 | """Execute a single tool call with timeout and error handling. 44 | 45 | Args: 46 | tool_call: The tool call to execute. 47 | 48 | Returns: 49 | ToolResult with execution outcome. 
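
        Example (illustrative sketch; `add` is a hypothetical registered tool):

            call = ToolCall(id="call_1", name="add", arguments={"a": 2, "b": 3})
            result = executor.execute_tool_call(call)
            # result.status == ToolResultStatus.SUCCESS and result.result == 5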
50 | """ 51 | start_time = time.time() 52 | 53 | # Find the tool 54 | tool = self.tools.get(tool_call.name) 55 | if not tool: 56 | return ToolResult( 57 | tool_call_id=tool_call.id, 58 | result=None, 59 | error=f"Tool '{tool_call.name}' not found", 60 | execution_time_ms=(time.time() - start_time) * 1000, 61 | status=ToolResultStatus.ERROR, 62 | ) 63 | 64 | # Validate and prepare arguments 65 | try: 66 | validated_args = validate_tool_arguments(tool, tool_call.arguments) 67 | except ValueError as e: 68 | return ToolResult( 69 | tool_call_id=tool_call.id, 70 | result=None, 71 | error=f"Invalid arguments: {str(e)}", 72 | execution_time_ms=(time.time() - start_time) * 1000, 73 | status=ToolResultStatus.ERROR, 74 | ) 75 | 76 | # Execute with timeout 77 | result_container: Dict[str, Any] = {} 78 | 79 | def execute_tool() -> None: 80 | """Execute tool in separate thread for timeout control.""" 81 | try: 82 | result_container["result"] = tool.callable(**validated_args) 83 | result_container["success"] = True 84 | except Exception as e: 85 | result_container["error"] = str(e) 86 | result_container["success"] = False 87 | 88 | # Run with timeout 89 | thread = threading.Thread(target=execute_tool, daemon=True) 90 | thread.start() 91 | thread.join(timeout=self.timeout) 92 | 93 | execution_time_ms = (time.time() - start_time) * 1000 94 | 95 | # Check timeout 96 | if thread.is_alive(): 97 | return ToolResult( 98 | tool_call_id=tool_call.id, 99 | result=None, 100 | error=f"Tool execution timed out after {self.timeout}s", 101 | execution_time_ms=execution_time_ms, 102 | status=ToolResultStatus.TIMEOUT, 103 | ) 104 | 105 | # Check for errors 106 | if not result_container.get("success", False): 107 | return ToolResult( 108 | tool_call_id=tool_call.id, 109 | result=None, 110 | error=result_container.get("error", "Unknown error"), 111 | execution_time_ms=execution_time_ms, 112 | status=ToolResultStatus.ERROR, 113 | ) 114 | 115 | # Success 116 | return ToolResult( 117 | tool_call_id=tool_call.id, 118 | result=result_container.get("result"), 119 | error=None, 120 | execution_time_ms=execution_time_ms, 121 | status=ToolResultStatus.SUCCESS, 122 | ) 123 | 124 | def _extract_tool_calls(self, response: BaseResponse) -> List[ToolCall]: 125 | """Extract tool calls from a response. 126 | 127 | Args: 128 | response: Response from provider. 129 | 130 | Returns: 131 | List of ToolCall objects. 
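
        Handles both tool-call dict shapes (values below are hypothetical):

            # OpenAI-style: {"id": ..., "function": {"name": ..., "arguments": "<json str>"}}
            # Direct:       {"id": ..., "name": ..., "arguments": {...}}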
132 | """ 133 | tool_calls = [] 134 | 135 | # Check if response has choices with messages containing tool calls 136 | if response.choices: 137 | for choice in response.choices: 138 | message = choice.message 139 | 140 | # Check for tool_calls in message 141 | if message.tool_calls: 142 | for tc in message.tool_calls: 143 | # Handle different formats 144 | if isinstance(tc, dict): 145 | # Extract from dict format 146 | if "function" in tc: 147 | # OpenAI format 148 | func = tc["function"] 149 | arguments_str = func.get("arguments", "{}") 150 | try: 151 | arguments = json.loads(arguments_str) 152 | except json.JSONDecodeError: 153 | arguments = {} 154 | 155 | tool_call = ToolCall( 156 | id=tc.get("id", ""), 157 | name=func.get("name", ""), 158 | arguments=arguments, 159 | raw_arguments=arguments_str, 160 | ) 161 | tool_calls.append(tool_call) 162 | else: 163 | # Direct format 164 | tool_call = ToolCall( 165 | id=tc.get("id", ""), 166 | name=tc.get("name", ""), 167 | arguments=tc.get("arguments", {}), 168 | raw_arguments=tc.get("raw_arguments"), 169 | ) 170 | tool_calls.append(tool_call) 171 | 172 | return tool_calls 173 | 174 | def _create_assistant_message( 175 | self, response: BaseResponse, tool_calls: List[ToolCall] 176 | ) -> Dict[str, Any]: 177 | """Create assistant message with tool calls. 178 | 179 | Args: 180 | response: Original response from provider. 181 | tool_calls: List of tool calls to include. 182 | 183 | Returns: 184 | Message dict for conversation history. 185 | """ 186 | # Format tool calls for message 187 | tool_calls_data = [] 188 | for tc in tool_calls: 189 | tool_call_dict = { 190 | "id": tc.id, 191 | "type": "function", 192 | "function": { 193 | "name": tc.name, 194 | "arguments": tc.raw_arguments or json.dumps(tc.arguments), 195 | }, 196 | } 197 | tool_calls_data.append(tool_call_dict) 198 | 199 | return { 200 | "role": "assistant", 201 | "content": response.content or "", 202 | "tool_calls": tool_calls_data, 203 | } 204 | 205 | def format_tool_result_message(self, result: ToolResult) -> Dict[str, Any]: 206 | """Format tool result as a message. 207 | 208 | Args: 209 | result: Tool execution result. 210 | 211 | Returns: 212 | Message dict with tool result. 213 | """ 214 | return { 215 | "role": "tool", 216 | "content": result.to_message_content(), 217 | "tool_call_id": result.tool_call_id, 218 | } 219 | 220 | def execute_all(self, tool_calls: List[ToolCall]) -> List[ToolResult]: 221 | """Execute all tool calls sequentially. 222 | 223 | Args: 224 | tool_calls: List of tool calls to execute. 225 | 226 | Returns: 227 | List of ToolResult objects. 228 | """ 229 | results = [] 230 | for tool_call in tool_calls: 231 | result = self.execute_tool_call(tool_call) 232 | results.append(result) 233 | return results 234 | 235 | def create_execution_entry( 236 | self, 237 | iteration: int, 238 | tool_call: ToolCall, 239 | tool_result: ToolResult, 240 | messages: Optional[List[Dict[str, Any]]] = None, 241 | ) -> ToolExecutionEntry: 242 | """Create an execution history entry. 243 | 244 | Args: 245 | iteration: The iteration number. 246 | tool_call: The tool call that was executed. 247 | tool_result: The result of execution. 248 | messages: Optional messages generated. 249 | 250 | Returns: 251 | ToolExecutionEntry for history tracking. 
252 | """ 253 | return ToolExecutionEntry( 254 | iteration=iteration, 255 | tool_call=tool_call, 256 | tool_result=tool_result, 257 | messages=messages or [], 258 | ) 259 | 260 | @staticmethod 261 | def calculate_total_cost(execution_history: List[ToolExecutionEntry]) -> float: 262 | """Calculate total cost from execution history. 263 | 264 | Args: 265 | execution_history: List of tool execution entries. 266 | 267 | Returns: 268 | Total estimated cost in USD. 269 | """ 270 | total_cost = 0.0 271 | for entry in execution_history: 272 | if entry.tool_result.cost is not None: 273 | total_cost += entry.tool_result.cost 274 | return total_cost 275 | -------------------------------------------------------------------------------- /justllms/core/streaming.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | from datetime import datetime 3 | from typing import TYPE_CHECKING, Any, AsyncIterator, Callable, Dict, Iterator, List, Optional 4 | 5 | import httpx 6 | 7 | from justllms.core.models import Choice, Message, Role, Usage 8 | 9 | if TYPE_CHECKING: 10 | from justllms.core.base import BaseProvider 11 | from justllms.core.completion import CompletionResponse 12 | 13 | 14 | def parse_sse_stream( 15 | url: str, 16 | payload: Dict[str, Any], 17 | headers: Dict[str, str], 18 | parse_chunk_fn: Callable[[str], Optional["StreamChunk"]], 19 | timeout: Optional[float] = None, 20 | error_prefix: str = "Streaming request", 21 | ) -> Iterator["StreamChunk"]: 22 | """Parse Server-Sent Events (SSE) stream from an HTTP endpoint. 23 | 24 | This is a shared helper for streaming responses that follow the SSE protocol. 25 | Both OpenAI-compatible and Azure OpenAI providers use this format. 26 | 27 | Args: 28 | url: API endpoint URL. 29 | payload: Request payload (should have stream=True). 30 | headers: Request headers including authorization. 31 | parse_chunk_fn: Callback to parse SSE line into StreamChunk. 32 | timeout: Optional timeout in seconds. 33 | error_prefix: Prefix for error messages (e.g., "OpenAI streaming request"). 34 | 35 | Yields: 36 | StreamChunk objects parsed from the SSE stream. 37 | 38 | Raises: 39 | ProviderError: If the streaming request fails. 40 | """ 41 | from justllms.exceptions import ProviderError 42 | 43 | try: 44 | with httpx.Client(timeout=timeout) as client, client.stream( 45 | "POST", url, json=payload, headers=headers 46 | ) as response: 47 | response.raise_for_status() 48 | 49 | for line in response.iter_lines(): 50 | chunk = parse_chunk_fn(line) 51 | if chunk is not None: 52 | yield chunk 53 | elif line.strip() == "data: [DONE]": 54 | break 55 | except (httpx.HTTPError, httpx.RequestError) as e: 56 | raise ProviderError(f"{error_prefix} failed: {str(e)}") from e 57 | 58 | 59 | class StreamChunk: 60 | """Individual chunk from a streaming response.""" 61 | 62 | def __init__( 63 | self, 64 | content: Optional[str] = None, 65 | finish_reason: Optional[str] = None, 66 | usage: Optional[Usage] = None, 67 | raw: Any = None, 68 | ): 69 | """Initialize a stream chunk. 70 | 71 | Args: 72 | content: Text content in this chunk. 73 | finish_reason: Reason streaming stopped (if final chunk). 74 | usage: Token usage info (if available). 75 | raw: Raw provider response for debugging. 
76 | """ 77 | self.content = content 78 | self.finish_reason = finish_reason 79 | self.usage = usage 80 | self.raw = raw 81 | 82 | 83 | class StreamResponse: 84 | """Accumulates streaming chunks and builds final CompletionResponse.""" 85 | 86 | def __init__(self, provider: "BaseProvider", model: str, messages: List[Message]): 87 | """Initialize stream accumulator. 88 | 89 | Args: 90 | provider: Provider instance for cost estimation. 91 | model: Model name for token counting. 92 | messages: Original messages for token counting. 93 | """ 94 | self.provider = provider 95 | self.model = model 96 | self.messages = messages 97 | self.id = f"chatcmpl-{uuid.uuid4().hex[:29]}" 98 | self.created = int(datetime.now().timestamp()) 99 | self._content_chunks: List[str] = [] 100 | self._finish_reason: Optional[str] = None 101 | self._usage: Optional[Usage] = None 102 | self.completed = False 103 | 104 | def accumulate(self, chunk: StreamChunk) -> None: 105 | """Accumulate content from a stream chunk. 106 | 107 | Args: 108 | chunk: Stream chunk to accumulate. 109 | """ 110 | if chunk.content: 111 | self._content_chunks.append(chunk.content) 112 | if chunk.finish_reason: 113 | self._finish_reason = chunk.finish_reason 114 | if chunk.usage: 115 | self._usage = chunk.usage 116 | 117 | def mark_complete(self) -> None: 118 | """Mark stream as fully consumed.""" 119 | self.completed = True 120 | 121 | def to_completion_response(self) -> "CompletionResponse": 122 | """Build CompletionResponse from accumulated chunks. 123 | 124 | Returns: 125 | CompletionResponse with proper structure. 126 | 127 | Raises: 128 | RuntimeError: If stream not fully consumed. 129 | """ 130 | if not self.completed: 131 | raise RuntimeError( 132 | "Stream not fully consumed. Iterate through all chunks first " 133 | "or call drain() to consume remaining chunks." 134 | ) 135 | 136 | from justllms.core.completion import CompletionResponse 137 | 138 | # Build Message 139 | message = Message(role=Role.ASSISTANT, content="".join(self._content_chunks)) 140 | 141 | # Build Choice 142 | choice = Choice(index=0, message=message, finish_reason=self._finish_reason or "stop") 143 | 144 | # Build or estimate Usage 145 | if not self._usage: 146 | prompt_tokens = self.provider.count_message_tokens(self.messages, self.model) 147 | completion_tokens = self.provider.count_tokens(message.content, self.model) 148 | self._usage = Usage( 149 | prompt_tokens=prompt_tokens, 150 | completion_tokens=completion_tokens, 151 | total_tokens=prompt_tokens + completion_tokens, 152 | ) 153 | 154 | # Set cost on Usage object 155 | self._usage.estimated_cost = self.provider.estimate_cost(self._usage, self.model) 156 | 157 | return CompletionResponse( 158 | id=self.id, 159 | model=self.model, 160 | choices=[choice], 161 | usage=self._usage, 162 | created=self.created, 163 | provider=self.provider.name, 164 | ) 165 | 166 | 167 | class SyncStreamResponse: 168 | """Synchronous streaming response.""" 169 | 170 | def __init__( 171 | self, 172 | provider: "BaseProvider", 173 | model: str, 174 | messages: List[Message], 175 | raw_stream: Iterator[StreamChunk], 176 | ): 177 | """Initialize sync stream response. 178 | 179 | Args: 180 | provider: Provider instance. 181 | model: Model name. 182 | messages: Original messages. 183 | raw_stream: Iterator of StreamChunks. 
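
        Example (illustrative consumption pattern):

            stream = SyncStreamResponse(provider, model, messages, chunks)
            for chunk in stream:
                print(chunk.content or "", end="")
            final = stream.get_final_response()  # CompletionResponse with usage/cost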
184 | """ 185 | self.accumulator = StreamResponse(provider, model, messages) 186 | self.raw_stream = raw_stream 187 | self._iterator_started = False 188 | 189 | def __iter__(self) -> Iterator[StreamChunk]: 190 | """Iterate over stream chunks. 191 | 192 | Yields: 193 | StreamChunk objects. 194 | 195 | Raises: 196 | RuntimeError: If iteration already started from a different iterator. 197 | """ 198 | if self._iterator_started: 199 | raise RuntimeError( 200 | "Stream iteration already started. Cannot create multiple iterators. " 201 | "Use a single for loop or call get_final_response() to consume remaining chunks." 202 | ) 203 | self._iterator_started = True 204 | 205 | for chunk in self.raw_stream: 206 | self.accumulator.accumulate(chunk) 207 | yield chunk 208 | 209 | self.accumulator.mark_complete() 210 | 211 | def drain(self) -> None: 212 | """Consume remaining chunks without yielding. 213 | 214 | Safe to call at any time - will consume from current position. 215 | """ 216 | if self.accumulator.completed: 217 | return # Already finished 218 | 219 | if not self._iterator_started: 220 | # Haven't started yet - consume entire stream 221 | for _ in self: 222 | pass 223 | else: 224 | # Already started - continue from current position 225 | for chunk in self.raw_stream: 226 | self.accumulator.accumulate(chunk) 227 | self.accumulator.mark_complete() 228 | 229 | def get_final_response(self) -> "CompletionResponse": 230 | """Get final CompletionResponse. 231 | 232 | Automatically drains stream if not yet fully consumed. 233 | 234 | Returns: 235 | CompletionResponse with cost and usage. 236 | """ 237 | if not self.accumulator.completed: 238 | self.drain() 239 | return self.accumulator.to_completion_response() 240 | 241 | 242 | class AsyncStreamResponse: 243 | """Asynchronous streaming response.""" 244 | 245 | def __init__( 246 | self, 247 | provider: "BaseProvider", 248 | model: str, 249 | messages: List[Message], 250 | async_stream: AsyncIterator[StreamChunk], 251 | ): 252 | """Initialize async stream response. 253 | 254 | Args: 255 | provider: Provider instance. 256 | model: Model name. 257 | messages: Original messages. 258 | async_stream: Async iterator of StreamChunks. 259 | """ 260 | self.accumulator = StreamResponse(provider, model, messages) 261 | self.async_stream = async_stream 262 | self._iterator_started = False 263 | 264 | async def __aiter__(self) -> AsyncIterator[StreamChunk]: 265 | """Async iterate over stream chunks. 266 | 267 | Yields: 268 | StreamChunk objects. 269 | 270 | Raises: 271 | RuntimeError: If iteration already started from a different iterator. 272 | """ 273 | if self._iterator_started: 274 | raise RuntimeError( 275 | "Stream iteration already started. Cannot create multiple iterators. " 276 | "Use a single async for loop or call get_final_response() to consume remaining chunks." 277 | ) 278 | self._iterator_started = True 279 | 280 | async for chunk in self.async_stream: 281 | self.accumulator.accumulate(chunk) 282 | yield chunk 283 | 284 | self.accumulator.mark_complete() 285 | 286 | async def drain(self) -> None: 287 | """Consume remaining chunks without yielding. 288 | 289 | Safe to call at any time - will consume from current position. 
290 | """ 291 | if self.accumulator.completed: 292 | return # Already finished 293 | 294 | if not self._iterator_started: 295 | # Haven't started yet - consume entire stream 296 | async for _ in self: 297 | pass 298 | else: 299 | # Already started - continue from current position 300 | async for chunk in self.async_stream: 301 | self.accumulator.accumulate(chunk) 302 | self.accumulator.mark_complete() 303 | 304 | async def get_final_response(self) -> "CompletionResponse": 305 | """Get final CompletionResponse. 306 | 307 | Automatically drains stream if not yet fully consumed. 308 | 309 | Returns: 310 | CompletionResponse with cost and usage. 311 | """ 312 | if not self.accumulator.completed: 313 | await self.drain() 314 | return self.accumulator.to_completion_response() 315 | -------------------------------------------------------------------------------- /justllms/tools/adapters/google.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Any, Dict, List, Optional, Union 3 | 4 | from justllms.core.models import Message, Role 5 | from justllms.tools.adapters.base import BaseToolAdapter 6 | from justllms.tools.models import Tool, ToolCall, ToolResult 7 | 8 | 9 | class GoogleToolAdapter(BaseToolAdapter): 10 | """Adapter for Google Gemini's function calling format. 11 | 12 | Gemini uses yet another format: 13 | - Functions are defined in a tools array with functionDeclarations 14 | - Tool choice uses toolConfig with functionCallingConfig 15 | - Function calls come back in functionCall objects 16 | - Gemini also supports native tools like google_search 17 | """ 18 | 19 | def format_tools_for_api( 20 | self, tools: List[Tool], include_native: bool = True 21 | ) -> List[Dict[str, Any]]: 22 | """Convert Tool objects to Gemini's function format. 23 | 24 | Args: 25 | tools: List of Tool instances (user-defined and/or native). 26 | include_native: Whether to include native Google tools in API format. 27 | 28 | Returns: 29 | List containing tool configuration for Gemini. 
30 | 31 | Example output: 32 | [ 33 | # Native tools in their special format 34 | {"googleSearch": {}}, 35 | {"codeExecution": {}}, 36 | # User tools in function declarations 37 | { 38 | "functionDeclarations": [ 39 | { 40 | "name": "get_weather", 41 | "description": "Get weather for location", 42 | "parameters": { 43 | "type": "object", 44 | "properties": { 45 | "location": { 46 | "type": "string", 47 | "description": "City name" 48 | } 49 | }, 50 | "required": ["location"] 51 | } 52 | } 53 | ] 54 | } 55 | ] 56 | """ 57 | api_tools = [] 58 | function_declarations = [] 59 | 60 | for tool in tools: 61 | # Handle native Google tools - they use a different format 62 | if tool.is_native and tool.namespace == "google" and include_native: 63 | if hasattr(tool, "to_api_format"): 64 | # Native tools have their own API format method 65 | api_tools.append(tool.to_api_format()) 66 | continue 67 | 68 | # Regular user-defined tools 69 | function_def = { 70 | "name": tool.name, 71 | "description": tool.description, 72 | "parameters": tool.to_json_schema(), 73 | } 74 | 75 | function_declarations.append(function_def) 76 | 77 | # Add user tools as function declarations 78 | if function_declarations: 79 | api_tools.append({"functionDeclarations": function_declarations}) 80 | 81 | return api_tools 82 | 83 | def format_tools_with_native( 84 | self, user_tools: List[Tool], native_tools: List[Any] 85 | ) -> List[Dict[str, Any]]: 86 | """Format both user-defined and native tools for Gemini API. 87 | 88 | This method handles the new importable API where users pass native tools 89 | directly: tools=[GoogleSearch(), multiply] 90 | 91 | Args: 92 | user_tools: List of user-defined Tool objects. 93 | native_tools: List of native tool objects (e.g., GoogleSearch, GoogleCodeExecution). 94 | 95 | Returns: 96 | List of tool configurations in Gemini format: 97 | [ 98 | {"google_search": {}}, 99 | {"code_execution": {}}, 100 | {"function_declarations": [...]} 101 | ] 102 | 103 | Note: When mixing native tools with user tools, we use "function_declarations" 104 | (snake_case) as per Gemini's live-tools documentation. 105 | """ 106 | api_tools = [] 107 | 108 | # Add native tools first (they have their own to_api_format method) 109 | for native_tool in native_tools: 110 | if hasattr(native_tool, "to_api_format"): 111 | api_tools.append(native_tool.to_api_format()) 112 | 113 | # Add user-defined tools as function declarations 114 | # When mixing with native tools, use snake_case key as per Gemini live-tools docs 115 | if user_tools: 116 | function_declarations = [] 117 | for tool in user_tools: 118 | function_def = { 119 | "name": tool.name, 120 | "description": tool.description, 121 | "parameters": tool.to_json_schema(), 122 | } 123 | function_declarations.append(function_def) 124 | 125 | # Use snake_case when there are native tools, camelCase otherwise 126 | key = "function_declarations" if native_tools else "functionDeclarations" 127 | api_tools.append({key: function_declarations}) 128 | 129 | return api_tools 130 | 131 | def format_tool_choice( 132 | self, tool_choice: Optional[Union[str, Dict[str, Any]]] 133 | ) -> Optional[Dict[str, Any]]: 134 | """Format tool_choice as toolConfig for Gemini API. 
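
        Mapping sketch (mirrors the logic implemented below):

            None / "auto" -> {"functionCallingConfig": {"mode": "AUTO"}}
            "none"        -> {"functionCallingConfig": {"mode": "NONE"}}
            "required"    -> {"functionCallingConfig": {"mode": "ANY"}}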
138 | 
139 |         Args:
140 |             tool_choice: Tool selection strategy:
141 |                 - "auto": Let model decide
142 |                 - "none": Don't use functions
143 |                 - "required": Must call a function (ANY mode)
144 |                 - {"name": "function_name"}: Force the named function
145 |                 - None: Use default (auto)
146 | 
147 |         Returns:
148 |             toolConfig dict for Gemini API.
149 |         """
150 |         if tool_choice is None or tool_choice == "auto":
151 |             # AUTO mode - model decides
152 |             return {"functionCallingConfig": {"mode": "AUTO"}}
153 | 
154 |         if tool_choice == "none":
155 |             # NONE mode - no function calls
156 |             return {"functionCallingConfig": {"mode": "NONE"}}
157 | 
158 |         if tool_choice == "required":
159 |             # ANY mode - must call at least one function
160 |             return {"functionCallingConfig": {"mode": "ANY"}}
161 | 
162 |         if isinstance(tool_choice, dict) and "name" in tool_choice:
163 |             # Force a specific function: ANY mode restricted to the named one
164 |             config: Dict[str, Any] = {
165 |                 "mode": "ANY",
166 |                 "allowedFunctionNames": [tool_choice["name"]],
167 |             }
168 |             return {"functionCallingConfig": config}
169 | 
170 |         # Unrecognized value - default to AUTO
171 |         return {"functionCallingConfig": {"mode": "AUTO"}}
172 | 
173 |     def extract_tool_calls(self, response: Dict[str, Any]) -> List[ToolCall]:
174 |         """Extract tool calls from Gemini response.
175 | 
176 |         Args:
177 |             response: Raw response from Gemini API.
178 | 
179 |         Returns:
180 |             List of ToolCall objects.
181 |         """
182 |         tool_calls = []
183 | 
184 |         # Navigate Gemini's response structure
185 |         candidates = response.get("candidates", [])
186 |         for candidate in candidates:
187 |             content = candidate.get("content", {})
188 |             parts = content.get("parts", [])
189 | 
190 |             for part in parts:
191 |                 # Check for function calls
192 |                 if "functionCall" in part:
193 |                     function_call = part["functionCall"]
194 | 
195 |                     # Extract arguments
196 |                     args = function_call.get("args", {})
197 | 
198 |                     # Suffix the id with the call index so parallel calls
199 |                     # to the same function get unique ids
200 |                     tool_call = ToolCall(
201 |                         id=f"call_{len(tool_calls)}_{function_call.get('name', 'unknown')}",
202 |                         name=function_call.get("name", ""),
203 |                         arguments=args,
204 |                         raw_arguments=json.dumps(args),
205 |                     )
206 |                     tool_calls.append(tool_call)
207 | 
208 |         return tool_calls
209 | 
210 |     def format_tool_result_message(self, tool_result: ToolResult, tool_call: ToolCall) -> Message:
211 |         """Format tool result as a message for Gemini.
212 | 
213 |         Gemini expects function responses as functionResponse parts.
214 | 
215 |         Args:
216 |             tool_result: Result from tool execution.
217 |             tool_call: Original tool call.
218 | 
219 |         Returns:
220 |             Message with function response for Gemini.
221 |         """
222 |         # Gemini expects function responses as parts
223 |         # Pass as list directly so GoogleProvider._format_messages can use it
224 |         parts = [
225 |             {
226 |                 "functionResponse": {
227 |                     "name": tool_call.name,
228 |                     "response": {"result": tool_result.to_message_content()},
229 |                 }
230 |             }
231 |         ]
232 | 
233 |         # Gemini uses USER role for function responses
234 |         # Pass parts list directly, not as JSON string
235 |         return Message(
236 |             role=Role.USER,
237 |             content=parts,  # Pass as list for GoogleProvider to handle
238 |         )
239 | 
240 |     def format_tool_calls_message(self, tool_calls: List[ToolCall]) -> Optional[Message]:
241 |         """Format tool calls as an assistant message for Gemini.
242 | 
243 |         Args:
244 |             tool_calls: List of tool calls.
245 | 
246 |         Returns:
247 |             Assistant message with function calls.
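248 | 
249 |         Example (illustrative):
250 |             A single ToolCall(name="get_weather", arguments={"location": "Paris"}, ...)
251 |             produces a Message whose content is
252 |             [{"functionCall": {"name": "get_weather", "args": {"location": "Paris"}}}].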
236 | """ 237 | if not tool_calls: 238 | return None 239 | 240 | parts = [] 241 | for tc in tool_calls: 242 | part = {"functionCall": {"name": tc.name, "args": tc.arguments}} 243 | parts.append(part) 244 | 245 | # Pass parts list directly for GoogleProvider to handle 246 | return Message( 247 | role=Role.ASSISTANT, 248 | content=parts, # Pass as list, not JSON string 249 | ) 250 | 251 | def merge_native_tools( 252 | self, user_tools: List[Tool], native_config: Optional[Dict[str, Any]] 253 | ) -> List[Tool]: 254 | """Merge user-defined tools with Google's native tools. 255 | 256 | Args: 257 | user_tools: User-defined Tool instances. 258 | native_config: Configuration for native tools from provider config. 259 | 260 | Returns: 261 | Combined list of tools (native tools + user tools). 262 | 263 | Example: 264 | native_config = { 265 | "google_search": {"enabled": True}, 266 | "code_execution": {"enabled": True} 267 | } 268 | merged = adapter.merge_native_tools(user_tools, native_config) 269 | """ 270 | if not native_config: 271 | return user_tools 272 | 273 | # Use NativeToolManager to load and merge tools 274 | from justllms.tools.native.manager import GoogleNativeToolManager 275 | 276 | manager = GoogleNativeToolManager(config=native_config) 277 | 278 | # Merge with prefer_native=True so native tools take precedence 279 | merged_tools = manager.merge_with_user_tools(user_tools, prefer_native=True) 280 | 281 | return merged_tools 282 | 283 | def supports_parallel_tools(self) -> bool: 284 | """Gemini supports calling multiple functions in one response.""" 285 | return True 286 | 287 | def supports_required_tools(self) -> bool: 288 | """Gemini supports ANY mode which is like required.""" 289 | return True 290 | 291 | def get_max_tools_per_call(self) -> Optional[int]: 292 | """Gemini supports up to 64 function declarations.""" 293 | return 64 294 | -------------------------------------------------------------------------------- /justllms/core/openai_base.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | from typing import Any, Dict, Iterator, List, Optional 4 | 5 | from justllms.core.base import BaseProvider, BaseResponse 6 | from justllms.core.models import Message 7 | from justllms.core.streaming import StreamChunk, SyncStreamResponse 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | class BaseOpenAIChatProvider(BaseProvider): 13 | """Base class for providers using OpenAI-compatible chat API. 14 | 15 | Provides common functionality for providers that follow the OpenAI chat 16 | completions API format, including standardized message formatting, 17 | request construction, and response parsing. 18 | 19 | Subclasses need to implement: 20 | - name: Provider identifier 21 | - get_available_models(): Provider's model catalog 22 | - _get_api_endpoint(): API endpoint URL 23 | - _get_request_headers(): Authentication headers 24 | - _customize_payload(): Provider-specific request modifications (optional) 25 | """ 26 | 27 | def _format_messages_openai(self, messages: List[Message]) -> List[Dict[str, Any]]: 28 | """Format messages for OpenAI-compatible APIs. 29 | 30 | Converts Message objects to the standard OpenAI chat format with 31 | support for function calls, tool calls, and multimodal content. 32 | 33 | Args: 34 | messages: List of Message objects to format. 35 | 36 | Returns: 37 | List[Dict[str, Any]]: OpenAI-compatible message format. 
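38 | 
39 |         Example (illustrative):
40 |             A single user Message roughly becomes
41 |             [{"role": "user", "content": "Hello"}].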
42 |         """
43 |         return self._format_messages_base(messages)
44 | 
45 |     def _get_api_endpoint(self) -> str:
46 |         """Get the chat completions endpoint URL.
47 | 
48 |         Must be implemented by subclasses to provide the correct API endpoint.
49 | 
50 |         Returns:
51 |             str: Full URL for the chat completions endpoint.
52 | 
53 |         Raises:
54 |             NotImplementedError: If not implemented by subclass.
55 |         """
56 |         raise NotImplementedError("Subclasses must implement _get_api_endpoint")
57 | 
58 |     def _get_request_headers(self) -> Dict[str, str]:
59 |         """Get authentication and request headers.
60 | 
61 |         Must be implemented by subclasses to provide provider-specific
62 |         authentication headers.
63 | 
64 |         Returns:
65 |             Dict[str, str]: Headers for API requests.
66 | 
67 |         Raises:
68 |             NotImplementedError: If not implemented by subclass.
69 |         """
70 |         raise NotImplementedError("Subclasses must implement _get_request_headers")
71 | 
72 |     def _customize_payload(self, payload: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
73 |         """Customize request payload for provider-specific requirements.
74 | 
75 |         Override this method to modify the request payload before sending.
76 |         Default implementation returns payload unchanged.
77 | 
78 |         Args:
79 |             payload: Base OpenAI-compatible request payload.
80 |             **kwargs: Additional parameters from the complete() call.
81 | 
82 |         Returns:
83 |             Dict[str, Any]: Modified payload for the API request.
84 |         """
85 |         return payload
86 | 
87 |     def _build_base_payload(
88 |         self, messages: List[Message], model: str, **kwargs: Any
89 |     ) -> Dict[str, Any]:
90 |         """Build base OpenAI-compatible payload with parameter filtering.
91 | 
92 |         Constructs the base payload with model and formatted messages, then filters
93 |         and adds supported parameters while logging ignored/unknown parameters.
94 | 
95 |         Args:
96 |             messages: Conversation messages to process.
97 |             model: Model identifier for the request.
98 |             **kwargs: Additional parameters to filter and add to payload.
99 | 
100 |         Returns:
101 |             Dict[str, Any]: Base payload ready for API request.
102 |         """
103 |         # Build base payload
104 |         payload = {
105 |             "model": model,
106 |             "messages": self._format_messages_openai(messages),
107 |         }
108 | 
109 |         # Supported OpenAI parameters
110 |         supported_params = {
111 |             "temperature",
112 |             "top_p",
113 |             "max_tokens",
114 |             "stop",
115 |             "n",
116 |             "presence_penalty",
117 |             "frequency_penalty",
118 |             "tools",
119 |             "tool_choice",
120 |             "response_format",
121 |             "seed",
122 |             "user",
123 |         }
124 | 
125 |         # Parameters to ignore (provider-specific or handled separately)
126 |         ignored_params = {"top_k", "generation_config", "timeout"}
127 | 
128 |         # Filter and add parameters
129 |         for key, value in kwargs.items():
130 |             if value is not None:
131 |                 if key in ignored_params:
132 |                     logger.debug(
133 |                         f"Parameter '{key}' is not supported by OpenAI-compatible APIs. Ignoring."
134 |                     )
135 |                 elif key in supported_params:
136 |                     payload[key] = value
137 |                 else:
138 |                     logger.debug(f"Unknown parameter '{key}' ignored. Not in OpenAI API spec.")
139 | 
140 |         return payload
141 | 
142 |     def _parse_sse_line(self, line: str) -> Optional[StreamChunk]:
143 |         """Parse a single SSE line into a StreamChunk.
144 | 
145 |         Args:
146 |             line: Raw SSE line to parse.
147 | 
148 |         Returns:
149 |             StreamChunk if line contains valid data, None otherwise.
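150 | 
151 |         Example (illustrative):
152 |             'data: {"choices": [{"delta": {"content": "Hi"}}]}' parses to a
153 |             StreamChunk with content "Hi"; blank lines, non-data lines, and
154 |             the "data: [DONE]" sentinel all return None.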
144 |         """
145 |         line = line.strip()
146 |         if not line:
147 |             return None
148 | 
149 |         if not line.startswith("data: "):
150 |             return None
151 | 
152 |         data = line[6:]  # Remove "data: " prefix
153 | 
154 |         if data == "[DONE]":
155 |             return None  # [DONE] sentinel - end of stream, nothing to parse
156 | 
157 |         try:
158 |             chunk_data = json.loads(data)
159 |             choices = chunk_data.get("choices", [])
160 | 
161 |             if choices:
162 |                 delta = choices[0].get("delta", {})
163 |                 content = delta.get("content")
164 |                 finish_reason = choices[0].get("finish_reason")
165 | 
166 |                 if content or finish_reason:
167 |                     return StreamChunk(
168 |                         content=content,
169 |                         finish_reason=finish_reason,
170 |                         raw=chunk_data,
171 |                     )
172 |         except json.JSONDecodeError:
173 |             logger.warning(f"Failed to parse SSE chunk: {data}")
174 | 
175 |         return None
176 | 
177 |     def _stream_sse_response(
178 |         self,
179 |         url: str,
180 |         payload: Dict[str, Any],
181 |         headers: Dict[str, str],
182 |         timeout: Optional[float] = None,
183 |     ) -> Iterator[StreamChunk]:
184 |         """Stream SSE response from OpenAI-compatible endpoint.
185 | 
186 |         Args:
187 |             url: API endpoint URL.
188 |             payload: Request payload (should have stream=True).
189 |             headers: Request headers.
190 |             timeout: Optional timeout in seconds.
191 | 
192 |         Returns:
193 |             Iterator of StreamChunk objects from the response.
194 | 
195 |         Raises:
196 |             ProviderError: If the streaming request fails.
197 |         """
198 |         from justllms.core.streaming import parse_sse_stream
199 | 
200 |         return parse_sse_stream(
201 |             url=url,
202 |             payload=payload,
203 |             headers=headers,
204 |             parse_chunk_fn=self._parse_sse_line,
205 |             timeout=timeout,
206 |             error_prefix="Streaming request",
207 |         )
208 | 
209 |     def _parse_openai_response(
210 |         self, response_data: Dict[str, Any], model: str, response_class: type
211 |     ) -> BaseResponse:
212 |         """Parse OpenAI-compatible API response.
213 | 
214 |         Handles standard OpenAI response format with choices and usage data.
215 |         Uses common parsing utilities from BaseProvider.
216 | 
217 |         Args:
218 |             response_data: Raw JSON response from API.
219 |             model: Model identifier used for the request.
220 |             response_class: Response class to instantiate.
221 | 
222 |         Returns:
223 |             BaseResponse: Parsed response object with choices and usage.
224 |         """
225 |         choices = []
226 |         for choice_data in response_data.get("choices", []):
227 |             message_data = choice_data.get("message", {})
228 |             choice = self._create_standard_choice(
229 |                 {**message_data, "finish_reason": choice_data.get("finish_reason")},
230 |                 choice_data.get("index", 0),
231 |             )
232 |             choices.append(choice)
233 | 
234 |         usage = self._create_standard_usage(response_data.get("usage", {}))
235 | 
236 |         return self._create_base_response(
237 |             response_class,
238 |             response_data,
239 |             choices,
240 |             usage,
241 |             model,
242 |         )
243 | 
244 |     def complete(
245 |         self,
246 |         messages: List[Message],
247 |         model: str,
248 |         timeout: Any = None,
249 |         **kwargs: Any,
250 |     ) -> BaseResponse:
251 |         """Execute OpenAI-compatible chat completion request.
252 | 
253 |         Constructs standardized request payload and handles the API call
254 |         using common patterns for OpenAI-compatible providers.
255 | 
256 |         Args:
257 |             messages: Conversation messages to process.
258 |             model: Model identifier for the request.
259 |             timeout: Optional timeout in seconds. If None, no timeout is enforced.
260 |             **kwargs: Additional parameters (temperature, max_tokens, etc.).
261 | 
262 |         Returns:
263 |             BaseResponse: Completed response from the provider.
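264 | 
265 |         Example (illustrative; assumes a concrete subclass with configured
266 |         credentials, and a placeholder model id):
267 |             response = provider.complete(
268 |                 messages=[Message(role="user", content="Hello")],
269 |                 model="some-model",
270 |                 temperature=0.2,
271 |             )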
272 | 
273 |         Raises:
274 |             ProviderError: If the API request fails.
275 |             NotImplementedError: If required methods are not implemented.
276 |         """
277 |         url = self._get_api_endpoint()
278 | 
279 |         # Build base payload with parameter filtering
280 |         payload = self._build_base_payload(messages, model, **kwargs)
281 | 
282 |         # Allow provider-specific customization
283 |         payload = self._customize_payload(payload, **kwargs)
284 | 
285 |         # Execute request using common HTTP handling
286 |         response_data = self._make_http_request(
287 |             url=url,
288 |             payload=payload,
289 |             headers=self._get_request_headers(),
290 |             timeout=timeout,
291 |         )
292 |         # Resolve the provider's companion response class by naming
293 |         # convention (FooProvider -> FooResponse); fall back to BaseResponse
294 |         try:
295 |             import importlib
296 | 
297 |             module = importlib.import_module(self.__module__)
298 |             response_class_name = self.__class__.__name__.replace("Provider", "Response")
299 |             response_class = getattr(module, response_class_name, BaseResponse)
300 |         except (ImportError, AttributeError):
301 |             response_class = BaseResponse
302 | 
303 |         return self._parse_openai_response(response_data, model, response_class)
304 | 
305 |     def stream(
306 |         self,
307 |         messages: List[Message],
308 |         model: str,
309 |         timeout: Optional[float] = None,
310 |         **kwargs: Any,
311 |     ) -> SyncStreamResponse:
312 |         """Stream completion using Server-Sent Events.
313 | 
314 |         Args:
315 |             messages: Conversation messages to process.
316 |             model: Model identifier for the request.
317 |             timeout: Optional timeout in seconds.
318 |             **kwargs: Additional parameters.
319 | 
320 |         Returns:
321 |             SyncStreamResponse: Streaming response iterator.
322 |         """
323 |         url = self._get_api_endpoint()
324 | 
325 |         # Build base payload with parameter filtering
326 |         payload = self._build_base_payload(messages, model, **kwargs)
327 | 
328 |         # Enable streaming
329 |         payload["stream"] = True
330 | 
331 |         # Allow provider-specific customization
332 |         payload = self._customize_payload(payload, **kwargs)
333 | 
334 |         # Use shared SSE streaming helper
335 |         stream_iter = self._stream_sse_response(
336 |             url=url,
337 |             payload=payload,
338 |             headers=self._get_request_headers(),
339 |             timeout=timeout,
340 |         )
341 | 
342 |         return SyncStreamResponse(
343 |             provider=self, model=model, messages=messages, raw_stream=stream_iter
344 |         )
345 | 
346 |     def supports_streaming(self) -> bool:
347 |         """OpenAI-compatible providers support streaming."""
348 |         return True
349 | 
350 |     def supports_streaming_for_model(self, model: str) -> bool:
351 |         """Check if the model supports streaming (all available models do)."""
352 |         return model in self.get_available_models()
353 | 
--------------------------------------------------------------------------------
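
Appendix (editorial, not part of the repository): a minimal sketch of what a
concrete OpenAI-compatible provider built on BaseOpenAIChatProvider could look
like, based on the hooks documented in that class. All names here
(ExampleProvider, the endpoint URL, the model id, the constructor) are
placeholder assumptions; the real subclasses under justllms/providers/ define
their own model catalogs and configuration.

from typing import Dict, List

from justllms.core.openai_base import BaseOpenAIChatProvider


class ExampleProvider(BaseOpenAIChatProvider):
    """Hypothetical provider; illustrates the required hooks only."""

    name = "example"  # Provider identifier

    def __init__(self, api_key: str) -> None:
        self.api_key = api_key

    def get_available_models(self) -> List[str]:
        # A real provider would expose its actual model catalog here
        return ["example-chat-1"]

    def _get_api_endpoint(self) -> str:
        # Placeholder OpenAI-style chat completions endpoint
        return "https://api.example.invalid/v1/chat/completions"

    def _get_request_headers(self) -> Dict[str, str]:
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }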