├── tests ├── __init__.py ├── requirements.txt ├── conftest.py └── test_all_models.py ├── llm_benchmark.png ├── llms ├── results │ ├── __init__.py │ └── result.py ├── __init__.py └── providers │ ├── __init__.py │ ├── ai21.py │ ├── base_provider.py │ ├── bedrock_anthropic.py │ ├── aleph.py │ ├── huggingface.py │ ├── cohere.py │ ├── together.py │ ├── reka.py │ ├── deepseek.py │ ├── groq.py │ ├── ollama.py │ ├── mistral.py │ ├── openrouter.py │ ├── anthropic.py │ ├── google_genai.py │ └── openai.py ├── requirements.txt ├── LICENSE ├── pytest.ini ├── setup.py ├── .gitignore └── README.md /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Tests package -------------------------------------------------------------------------------- /llm_benchmark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kagisearch/pyllms/main/llm_benchmark.png -------------------------------------------------------------------------------- /llms/results/__init__.py: -------------------------------------------------------------------------------- 1 | from .result import Result, Results, StreamResult, AsyncStreamResult 2 | -------------------------------------------------------------------------------- /tests/requirements.txt: -------------------------------------------------------------------------------- 1 | # Test dependencies for PyLLMs 2 | pytest>=7.0.0 3 | pytest-asyncio>=0.21.0 4 | pytest-timeout>=2.1.0 5 | pytest-mock>=3.10.0 6 | pytest-cov>=4.0.0 7 | pytest-xdist>=3.0.0 # For parallel test execution 8 | python-dotenv>=1.0.0 # For automatic .env loading -------------------------------------------------------------------------------- /llms/__init__.py: -------------------------------------------------------------------------------- 1 | from .llms import LLMS 2 | 3 | def init(*args, **kwargs): 4 | if len(args) > 1 and not kwargs.get('model'): 5 | raise ValueError( 6 | "Please provide a list of models, like this: model=['j2-grande-instruct', 'claude-v1', 'gpt-3.5-turbo']" 7 | ) 8 | return LLMS(*args, **kwargs) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | openai===1.93.0 2 | tiktoken===0.9.0 3 | anthropic===0.55.0 4 | anthropic_bedrock===0.8.0 5 | ai21===4.0.3 6 | cohere===5.15.0 7 | aleph-alpha-client===10.4.0 8 | huggingface_hub===0.33.1 9 | prettytable===3.16.0 10 | aiohttp===3.12.13 11 | google-cloud-aiplatform~=1.100.0 12 | einops===0.8.1 13 | accelerate===1.8.1 14 | protobuf~=5.28.0 15 | grpcio===1.73.1 16 | google-genai>=1.23.0 17 | ollama===0.5.1 18 | reka-api===3.2.0 19 | together===1.5.17 20 | mistralai===1.8.2 21 | groq===0.29.0 22 | deepseek===1.0.0 23 | -------------------------------------------------------------------------------- /llms/providers/__init__.py: -------------------------------------------------------------------------------- 1 | from .ai21 import AI21Provider 2 | from .aleph import AlephAlphaProvider 3 | from .anthropic import AnthropicProvider 4 | from .bedrock_anthropic import BedrockAnthropicProvider 5 | from .cohere import CohereProvider 6 | from .google_genai import GoogleGenAIProvider, GoogleVertexAIProvider 7 | from .huggingface import HuggingfaceHubProvider 8 | from .openai import OpenAIProvider 9 | from .mistral import MistralProvider 10 | from .ollama import OllamaProvider 11 | from .deepseek import DeepSeekProvider 
12 | from .groq import GroqProvider 13 | from .reka import RekaProvider 14 | from .together import TogetherProvider 15 | from .openrouter import OpenRouterProvider 16 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Kagi Search 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | # Pytest configuration for PyLLMs (pytest.ini files use the [pytest] section; [tool:pytest] is only valid in setup.cfg) 3 | 4 | # Test discovery 5 | testpaths = tests 6 | python_files = test_*.py 7 | python_classes = Test* 8 | python_functions = test_* 9 | 10 | # Output options 11 | addopts = 12 | --verbose 13 | --tb=short 14 | --strict-markers 15 | --strict-config 16 | --disable-warnings 17 | -ra 18 | 19 | # Async support 20 | asyncio_mode = auto 21 | 22 | # Timeout settings 23 | timeout = 300 24 | timeout_method = thread 25 | 26 | # Markers 27 | markers = 28 | slow: marks tests as slow (may take time due to API calls) 29 | requires_api_key: marks tests that require API keys 30 | unit: fast unit tests 31 | integration: integration tests that may require external services 32 | 33 | # Minimum version 34 | minversion = 7.0 35 | 36 | # Test output 37 | console_output_style = progress 38 | 39 | # Fail on first failure for debugging (comment out for full test runs) 40 | # addopts = --maxfail=1 41 | 42 | # Coverage options (uncomment to enable coverage reporting) 43 | # addopts = --cov=llms --cov-report=html --cov-report=term-missing 44 | 45 | # Parallel execution (uncomment to enable, requires pytest-xdist) 46 | # addopts = -n auto -------------------------------------------------------------------------------- /llms/providers/ai21.py: -------------------------------------------------------------------------------- 1 | # llms/providers/ai21.py 2 | 3 | import ai21 4 | 5 | from ..results.result import Result 6 | from .base_provider import BaseProvider 7 | 8 | 9 | class AI21Provider(BaseProvider): 10 | # per million tokens 11 | MODEL_INFO = { 12 | "j2-grande-instruct": {"prompt": 10.0, "completion": 10.0, "token_limit": 8192}, 13 | "j2-jumbo-instruct": {"prompt": 15.0, "completion": 15.0, "token_limit": 8192}, 14 | } 15 | 16 | def __init__(self, api_key, model=None): 17 | ai21.api_key = api_key 18 | if model is
None: 19 | model = list(self.MODEL_INFO.keys())[0] 20 | self.model = model 21 | 22 | def _prepare_model_inputs( 23 | self, 24 | prompt: str, 25 | temperature: float = 0, 26 | max_tokens: int = 300, 27 | **kwargs, 28 | ): 29 | maxTokens = kwargs.pop("maxTokens", max_tokens) 30 | model_inputs = { 31 | "prompt": prompt, 32 | "temperature": temperature, 33 | "maxTokens": maxTokens, 34 | **kwargs, 35 | } 36 | return model_inputs 37 | 38 | def complete( 39 | self, 40 | prompt: str, 41 | temperature: float = 0, 42 | max_tokens: int = 300, 43 | **kwargs, 44 | ) -> Result: 45 | model_inputs = self._prepare_model_inputs( 46 | prompt=prompt, 47 | temperature=temperature, 48 | max_tokens=max_tokens, 49 | **kwargs, 50 | ) 51 | with self.track_latency(): 52 | response = ai21.Completion.execute(model=self.model, **model_inputs) 53 | 54 | completion = response.completions[0].data.text.strip() 55 | tokens_prompt = len(response.prompt.tokens) 56 | tokens_completion = len(response.completions[0].data.tokens) 57 | 58 | meta = { 59 | "tokens_prompt": tokens_prompt, 60 | "tokens_completion": tokens_completion, 61 | "latency": self.latency, 62 | } 63 | 64 | return Result( 65 | text=completion, 66 | model_inputs=model_inputs, 67 | provider=self, 68 | meta=meta, 69 | ) 70 | -------------------------------------------------------------------------------- /llms/providers/base_provider.py: -------------------------------------------------------------------------------- 1 | import time 2 | from contextlib import contextmanager 3 | from typing import Dict 4 | 5 | 6 | class BaseProvider: 7 | """Base class for all providers. 8 | Methods will raise NotImplementedError if they are not overwritten. 9 | """ 10 | 11 | MODEL_INFO = {} 12 | 13 | def __init__(self, model=None, api_key=None, **kwargs): 14 | self.latency = None 15 | # to be overwritten by subclasses 16 | self.api_key = api_key 17 | 18 | def __repr__(self) -> str: 19 | return f"{self.__class__.__name__}('{self.model}')" 20 | 21 | def __str__(self): 22 | return f"{self.__class__.__name__}('{self.model}')" 23 | 24 | def _prepare_model_inputs( 25 | self, 26 | **kwargs 27 | ) -> Dict: 28 | raise NotImplementedError() 29 | 30 | @contextmanager 31 | def track_latency(self): 32 | start = time.perf_counter() 33 | try: 34 | yield 35 | finally: 36 | self.latency = round(time.perf_counter() - start, 2) 37 | 38 | def compute_cost(self, prompt_tokens: int, completion_tokens: int) -> float: 39 | cost_per_token = self.MODEL_INFO[self.model] 40 | cost = ( 41 | (prompt_tokens * cost_per_token["prompt"]) 42 | + (completion_tokens * cost_per_token["completion"]) 43 | ) / 1_000_000 44 | cost = round(cost, 5) 45 | return cost 46 | 47 | def count_tokens(self, content): 48 | raise NotImplementedError( 49 | f"Count tokens is currently not supported with {self.__class__.__name__}" 50 | ) 51 | 52 | def complete(self, *args, **kwargs): 53 | raise NotImplementedError 54 | 55 | async def acomplete(self): 56 | raise NotImplementedError( 57 | f"Async complete is not yet supported with {self.__class__.__name__}" 58 | ) 59 | 60 | def complete_stream(self): 61 | raise NotImplementedError( 62 | f"Streaming is not yet supported with {self.__class__.__name__}" 63 | ) 64 | 65 | async def acomplete_stream(self): 66 | raise NotImplementedError( 67 | f"Async streaming is not yet supported with {self.__class__.__name__}" 68 | ) 69 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | 
from setuptools import setup, find_packages 2 | import pathlib 3 | 4 | # Get the long description from the README file 5 | here = pathlib.Path(__file__).parent.resolve() 6 | long_description = (here / "README.md").read_text(encoding="utf-8") 7 | 8 | _project_homepage = "https://github.com/kagisearch/pyllms" 9 | 10 | setup( 11 | name="pyllms", 12 | version="0.7.6", 13 | description="Minimal Python library to connect to LLMs (OpenAI, Anthropic, Google, Mistral, OpenRouter, Reka, Groq, Together, Ollama, AI21, Cohere, Aleph-Alpha, HuggingfaceHub), with a built-in model performance benchmark.", 14 | long_description=long_description, 15 | long_description_content_type="text/markdown", 16 | author="Vladimir Prelovac", 17 | author_email="vlad@kagi.com", 18 | packages=find_packages(), 19 | install_requires=[ 20 | "openai>=1", 21 | "tiktoken", 22 | "anthropic>=0.18", 23 | "ai21", 24 | "cohere", 25 | "aleph-alpha-client", 26 | "huggingface_hub", 27 | "google-cloud-aiplatform", 28 | "prettytable", 29 | "protobuf>=3.20.3", 30 | "grpcio>=1.54.2", 31 | "google-generativeai", 32 | "mistralai", 33 | "ollama", 34 | "reka-api", 35 | "together", 36 | ], 37 | extras_require={ 38 | "local": ["einops", "accelerate"] 39 | }, 40 | classifiers=[ 41 | "Development Status :: 3 - Alpha", 42 | "Intended Audience :: Developers", 43 | "License :: OSI Approved :: MIT License", 44 | "Programming Language :: Python :: 3", 45 | "Programming Language :: Python :: 3.7", 46 | "Programming Language :: Python :: 3.8", 47 | "Programming Language :: Python :: 3.9", 48 | "Programming Language :: Python :: 3.10", 49 | "Programming Language :: Python :: 3.11", 50 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 51 | "Topic :: Scientific/Engineering :: Human Machine Interfaces", 52 | "Topic :: Text Processing", 53 | ], 54 | python_requires=">=3.7", 55 | keywords="llm, llms, large language model, AI, NLP, natural language processing, gpt, chatgpt, openai, anthropic, ai21, cohere, aleph alpha, huggingface hub, vertex ai, palm, palm2, deepseek", 56 | project_urls={ 57 | "Documentation": _project_homepage, 58 | "Source Code": _project_homepage, 59 | "Issue Tracker": _project_homepage+"/issues", 60 | }, 61 | ) 62 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | *~ 132 | .aider* 133 | -------------------------------------------------------------------------------- /llms/providers/bedrock_anthropic.py: -------------------------------------------------------------------------------- 1 | # llms/providers/bedrock_anthropic.py 2 | 3 | import os 4 | from typing import Union, Dict 5 | import tiktoken 6 | 7 | from anthropic import AnthropicBedrock, AsyncAnthropicBedrock 8 | 9 | from .anthropic import AnthropicProvider 10 | 11 | 12 | class BedrockAnthropicProvider(AnthropicProvider): 13 | MODEL_INFO = { 14 | "anthropic.claude-3-haiku-20240307-v1:0": {"prompt": 0.25, "completion": 1.25, "token_limit": 200_000, "output_limit": 4_096}, 15 | "anthropic.claude-3-sonnet-20240229-v1:0": {"prompt": 3.00, "completion": 15, "token_limit": 200_000, "output_limit": 4_096}, 16 | "anthropic.claude-3-5-sonnet-20240620-v1:0": {"prompt": 3.00, "completion": 15, "token_limit": 200_000, "output_limit": 4_096}, 17 | } 18 | 19 | def __init__( 20 | self, 21 | model: Union[str, None] = None, 22 | aws_access_key: Union[str, None] = None, 23 | aws_secret_key: Union[str, None] = None, 24 | aws_region: Union[str, None] = None, 25 | client_kwargs: Union[dict, None] = None, 26 | async_client_kwargs: Union[dict, None] = None, 27 | ): 28 | if model is None: 29 | model = list(self.MODEL_INFO.keys())[0] 30 | self.model = model 31 | 32 | if aws_access_key is None: 33 | aws_access_key = os.getenv("AWS_ACCESS_KEY_ID") 34 | if aws_secret_key is None: 35 | aws_secret_key = os.getenv("AWS_SECRET_ACCESS_KEY") 36 | 37 | if client_kwargs is None: 38 | client_kwargs = {} 39 | self.client = AnthropicBedrock( 40 | aws_access_key=aws_access_key, 41 | aws_secret_key=aws_secret_key, 42 | aws_region=aws_region, 
43 | **client_kwargs, 44 | ) 45 | 46 | if async_client_kwargs is None: 47 | async_client_kwargs = {} 48 | self.async_client = AsyncAnthropicBedrock( 49 | aws_access_key=aws_access_key, 50 | aws_secret_key=aws_secret_key, 51 | aws_region=aws_region, 52 | **async_client_kwargs, 53 | ) 54 | 55 | def count_tokens(self, content: str | Dict) -> int: 56 | """ 57 | Override count_tokens since AnthropicBedrock doesn't have this method. 58 | Use tiktoken as a fallback for token estimation. 59 | """ 60 | enc = tiktoken.encoding_for_model("gpt-3.5-turbo") # Use GPT-3.5 as approximation 61 | 62 | if isinstance(content, str): 63 | return len(enc.encode(content, disallowed_special=())) 64 | 65 | # Handle message format 66 | formatting_token_count = 4 67 | total = 0 68 | for message in content: 69 | total += len(enc.encode(message["content"], disallowed_special=())) + formatting_token_count 70 | return total 71 | 72 | @property 73 | def support_message_api(self): 74 | return "claude-3" in self.model -------------------------------------------------------------------------------- /llms/providers/aleph.py: -------------------------------------------------------------------------------- 1 | # llms/providers/aleph.py 2 | 3 | import os 4 | 5 | import tiktoken 6 | from aleph_alpha_client import AsyncClient, Client, CompletionRequest, Prompt 7 | 8 | from ..results.result import Result 9 | from .base_provider import BaseProvider 10 | 11 | 12 | class AlephAlphaProvider(BaseProvider): 13 | MODEL_INFO = { 14 | "luminous-base": {"prompt": 6.6, "completion": 7.6, "token_limit": 2048}, 15 | "luminous-extended": {"prompt": 9.9, "completion": 10.9, "token_limit": 2048}, 16 | "luminous-supreme": {"prompt": 38.5, "completion": 42.5, "token_limit": 2048}, 17 | "luminous-supreme-control": { 18 | "prompt": 48.5, 19 | "completion": 53.6, 20 | "token_limit": 2048, 21 | }, 22 | } 23 | 24 | def __init__(self, api_key=None, model=None): 25 | if api_key is None: 26 | api_key = os.getenv("ALEPHALPHA_API_KEY") 27 | self.client = Client(api_key) 28 | self.async_client = AsyncClient(api_key) 29 | 30 | if model is None: 31 | model = list(self.MODEL_INFO.keys())[0] 32 | self.model = model 33 | 34 | def count_tokens(self, content: str): 35 | enc = tiktoken.encoding_for_model("gpt-3.5-turbo") 36 | return len(enc.encode(content)) 37 | 38 | def _prepare_model_inputs( 39 | self, 40 | prompt: str, 41 | temperature: float = 0, 42 | max_tokens: int = 300, 43 | **kwargs, 44 | ) -> CompletionRequest: 45 | prompt = Prompt.from_text(prompt) 46 | maximum_tokens = kwargs.pop("maximum_tokens", max_tokens) 47 | 48 | model_inputs = CompletionRequest( 49 | prompt=prompt, 50 | temperature=temperature, 51 | maximum_tokens=maximum_tokens, 52 | **kwargs, 53 | ) 54 | return model_inputs 55 | 56 | def complete( 57 | self, 58 | prompt: str, 59 | temperature: float = 0, 60 | max_tokens: int = 300, 61 | **kwargs, 62 | ) -> Result: 63 | model_inputs = self._prepare_model_inputs( 64 | prompt=prompt, temperature=temperature, max_tokens=max_tokens, **kwargs 65 | ) 66 | with self.track_latency(): 67 | response = self.client.complete(request=model_inputs, model=self.model) 68 | 69 | completion = response.completions[0].completion.strip() 70 | 71 | return Result( 72 | text=completion, 73 | model_inputs=model_inputs, 74 | provider=self, 75 | meta={"latency": self.latency}, 76 | ) 77 | 78 | async def acomplete( 79 | self, 80 | prompt: str, 81 | temperature: float = 0, 82 | max_tokens: int = 300, 83 | **kwargs, 84 | ) -> Result: 85 | model_inputs = 
self._prepare_model_inputs( 86 | prompt=prompt, temperature=temperature, max_tokens=max_tokens, **kwargs 87 | ) 88 | with self.track_latency(): 89 | async with self.async_client as client: 90 | response = await client.complete(request=model_inputs, model=self.model) 91 | 92 | completion = response.completions[0].completion.strip() 93 | 94 | return Result( 95 | text=completion, 96 | model_inputs=model_inputs, 97 | provider=self, 98 | meta={"latency": self.latency}, 99 | ) 100 | -------------------------------------------------------------------------------- /llms/providers/huggingface.py: -------------------------------------------------------------------------------- 1 | # llms/providers/huggingface.py 2 | 3 | import os 4 | 5 | from huggingface_hub.inference_api import InferenceApi 6 | 7 | from ..results.result import Result 8 | from .base_provider import BaseProvider 9 | 10 | 11 | class HuggingfaceHubProvider(BaseProvider): 12 | MODEL_INFO = { 13 | "hf_pythia": { 14 | "full": "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5", 15 | "prompt": 0, 16 | "completion": 0, 17 | "token_limit": 2048, 18 | }, 19 | "hf_falcon40b": { 20 | "full": "tiiuae/falcon-40b-instruct", 21 | "prompt": 0, 22 | "completion": 0, 23 | "token_limit": 2048, 24 | "local": True 25 | }, 26 | "hf_falcon7b": { 27 | "full": "tiiuae/falcon-7b-instruct", 28 | "prompt": 0, 29 | "completion": 0, 30 | "token_limit": 2048, 31 | "local": True 32 | }, 33 | "hf_mptinstruct": { 34 | "full": "mosaicml/mpt-7b-instruct", 35 | "prompt": 0, 36 | "completion": 0, 37 | "token_limit": 2048, 38 | "local": True 39 | }, 40 | "hf_mptchat": { 41 | "full": "mosaicml/mpt-7b-chat", 42 | "prompt": 0, 43 | "completion": 0, 44 | "token_limit": 2048, 45 | "local": True 46 | }, 47 | "hf_llava": { 48 | "full": "liuhaotian/LLaVA-Lightning-MPT-7B-preview", 49 | "prompt": 0, 50 | "completion": 0, 51 | "token_limit": 2048, 52 | "local": True 53 | }, 54 | "hf_dolly": { 55 | "full": "databricks/dolly-v2-12b", 56 | "prompt": 0, 57 | "completion": 0, 58 | "token_limit": -1, 59 | "local": True 60 | }, 61 | "hf_vicuna": { 62 | "full": "CarperAI/stable-vicuna-13b-delta", 63 | "prompt": 0, 64 | "completion": 0, 65 | "token_limit": -1, 66 | }, 67 | } 68 | 69 | def __init__(self, api_key=None, model=None): 70 | if model is None: 71 | model = list(self.MODEL_INFO.keys())[0] 72 | 73 | self.model = model 74 | 75 | if api_key is None: 76 | api_key = os.getenv("HUGGINFACEHUB_API_KEY") 77 | 78 | self.client = InferenceApi( 79 | repo_id=self.MODEL_INFO[model]["full"], token=api_key 80 | ) 81 | 82 | def _prepare_model_inputs( 83 | self, 84 | prompt: str, 85 | temperature: float = 1.0, 86 | max_tokens: int = 300, 87 | **kwargs, 88 | ): 89 | if self.model == "hf_pythia": 90 | prompt = "<|prompter|" + prompt + "<|endoftext|><|assistant|>" 91 | max_new_tokens = kwargs.pop("max_length", max_tokens) 92 | params = { 93 | "temperature": temperature, 94 | "max_length": max_new_tokens, 95 | **kwargs, 96 | } 97 | return prompt, params 98 | 99 | def complete( 100 | self, 101 | prompt: str, 102 | temperature: float = 0.01, 103 | max_tokens: int = 300, 104 | **kwargs, 105 | ) -> Result: 106 | prompt, params = self._prepare_model_inputs( 107 | prompt=prompt, 108 | temperature=temperature, 109 | max_tokens=max_tokens, 110 | **kwargs, 111 | ) 112 | with self.track_latency(): 113 | response = self.client(inputs=prompt, params=params) 114 | 115 | completion = response[0]["generated_text"][len(prompt) :] 116 | meta = { 117 | "tokens_prompt": -1, 118 | "tokens_completion": -1, 119 | "latency": 
self.latency, 120 | } 121 | return Result( 122 | text=completion, 123 | model_inputs={"prompt": prompt, **params}, 124 | provider=self, 125 | meta=meta, 126 | ) 127 | -------------------------------------------------------------------------------- /llms/providers/cohere.py: -------------------------------------------------------------------------------- 1 | # llms/providers/cohere.py 2 | 3 | import os 4 | from typing import Dict, Generator 5 | 6 | import cohere 7 | 8 | from ..results.result import Result, StreamResult 9 | from .base_provider import BaseProvider 10 | 11 | 12 | class CohereProvider(BaseProvider): 13 | MODEL_INFO = { 14 | "command": {"prompt": 15.0, "completion": 15, "token_limit": 2048}, 15 | "command-nightly": { 16 | "prompt": 15.0, 17 | "completion": 15, 18 | "token_limit": 4096, 19 | }, 20 | } 21 | 22 | def __init__(self, api_key=None, model=None): 23 | if api_key is None: 24 | api_key = os.getenv("COHERE_API_KEY") 25 | self.client = cohere.Client(api_key) 26 | self.async_client = cohere.AsyncClient(api_key) 27 | 28 | if model is None: 29 | model = list(self.MODEL_INFO.keys())[0] 30 | self.model = model 31 | 32 | def count_tokens(self, content: str) -> int: 33 | tokens = self.client.tokenize(content) 34 | return len(tokens) 35 | 36 | def _prepare_model_inputs( 37 | self, 38 | prompt: str, 39 | temperature: float = 0, 40 | max_tokens: int = 300, 41 | stream: bool = False, 42 | **kwargs, 43 | ) -> Dict: 44 | model_inputs = { 45 | "prompt": prompt, 46 | "temperature": temperature, 47 | "max_tokens": max_tokens, 48 | "stream": stream, 49 | **kwargs, 50 | } 51 | return model_inputs 52 | 53 | def complete( 54 | self, 55 | prompt: str, 56 | temperature: float = 0, 57 | max_tokens: int = 300, 58 | **kwargs, 59 | ) -> Result: 60 | model_inputs = self._prepare_model_inputs( 61 | prompt=prompt, 62 | temperature=temperature, 63 | max_tokens=max_tokens, 64 | **kwargs, 65 | ) 66 | with self.track_latency(): 67 | response = self.client.generate( 68 | model=self.model, 69 | **model_inputs, 70 | ) 71 | 72 | completion = response.generations[0].text.strip() 73 | return Result( 74 | text=completion, 75 | model_inputs=model_inputs, 76 | provider=self, 77 | meta={"latency": self.latency}, 78 | ) 79 | 80 | async def acomplete( 81 | self, 82 | prompt: str, 83 | temperature: float = 0, 84 | max_tokens: int = 300, 85 | **kwargs, 86 | ) -> Result: 87 | model_inputs = self._prepare_model_inputs( 88 | prompt=prompt, 89 | temperature=temperature, 90 | max_tokens=max_tokens, 91 | **kwargs, 92 | ) 93 | with self.track_latency(): 94 | async with self.async_client() as client: 95 | response = await client.generate( 96 | model=self.model, 97 | **model_inputs, 98 | ) 99 | 100 | completion = response.generations[0].text.strip() 101 | 102 | return Result( 103 | text=completion, 104 | model_inputs=model_inputs, 105 | provider=self, 106 | meta={"latency": self.latency}, 107 | ) 108 | 109 | def complete_stream( 110 | self, 111 | prompt: str, 112 | temperature: float = 0, 113 | max_tokens: int = 300, 114 | **kwargs, 115 | ): 116 | model_inputs = self._prepare_model_inputs( 117 | prompt=prompt, 118 | temperature=temperature, 119 | max_tokens=max_tokens, 120 | stream=True, 121 | **kwargs, 122 | ) 123 | response = self.client.generate( 124 | model=self.model, 125 | **model_inputs, 126 | ) 127 | 128 | stream = self._process_stream(response) 129 | return StreamResult(stream=stream, model_inputs=model_inputs, provider=self) 130 | 131 | def _process_stream(self, response: Generator) -> Generator: 132 | first_text = 
next(response).text 133 | yield first_text.lstrip() 134 | 135 | for token in response: 136 | yield token.text 137 | -------------------------------------------------------------------------------- /llms/providers/together.py: -------------------------------------------------------------------------------- 1 | from typing import AsyncGenerator, Dict, List, Optional, Union 2 | import tiktoken 3 | 4 | from together import Together 5 | 6 | from ..results.result import AsyncStreamResult, Result, StreamResult 7 | from .base_provider import BaseProvider 8 | 9 | 10 | class TogetherProvider(BaseProvider): 11 | MODEL_INFO = { 12 | "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo": {"prompt": 5.0, "completion": 5.0, "token_limit": 4096}, 13 | } 14 | 15 | def __init__( 16 | self, 17 | api_key: Union[str, None] = None, 18 | model: Union[str, None] = None, 19 | **kwargs 20 | ): 21 | super().__init__(api_key=api_key, model=model) 22 | if model is None: 23 | model = "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo" 24 | self.model = model 25 | self.client = Together(api_key=api_key) 26 | 27 | def count_tokens(self, content: Union[str, List[dict]]) -> int: 28 | # Together uses the same tokenizer as OpenAI 29 | enc = tiktoken.encoding_for_model("gpt-3.5-turbo") 30 | if isinstance(content, list): 31 | return sum([len(enc.encode(str(message))) for message in content]) 32 | else: 33 | return len(enc.encode(content)) 34 | 35 | def _prepare_model_inputs( 36 | self, 37 | prompt: str, 38 | history: Optional[List[dict]] = None, 39 | system_message: Union[str, List[dict], None] = None, 40 | temperature: float = 0, 41 | max_tokens: int = 300, 42 | stream: bool = False, 43 | **kwargs, 44 | ) -> Dict: 45 | messages = [{"content": prompt, "role": "user"}] 46 | 47 | if history: 48 | messages = [*history, *messages] 49 | 50 | if isinstance(system_message, str): 51 | messages = [{"role": "system", "content": system_message}, *messages] 52 | elif isinstance(system_message, list): 53 | messages = [*system_message, *messages] 54 | 55 | model_inputs = { 56 | "messages": messages, 57 | "temperature": temperature, 58 | "max_tokens": max_tokens, 59 | "stream": stream, 60 | **kwargs, 61 | } 62 | return model_inputs 63 | 64 | def complete( 65 | self, 66 | prompt: str, 67 | history: Optional[List[dict]] = None, 68 | system_message: Optional[List[dict]] = None, 69 | temperature: float = 0, 70 | max_tokens: int = 300, 71 | **kwargs, 72 | ) -> Result: 73 | model_inputs = self._prepare_model_inputs( 74 | prompt=prompt, 75 | history=history, 76 | system_message=system_message, 77 | temperature=temperature, 78 | max_tokens=max_tokens, 79 | **kwargs, 80 | ) 81 | 82 | with self.track_latency(): 83 | response = self.client.chat.completions.create(model=self.model, **model_inputs) 84 | 85 | completion = response.choices[0].message.content.strip() 86 | prompt_tokens = self.count_tokens(model_inputs["messages"]) 87 | completion_tokens = self.count_tokens(completion) 88 | 89 | meta = { 90 | "tokens_prompt": prompt_tokens, 91 | "tokens_completion": completion_tokens, 92 | "latency": self.latency, 93 | } 94 | return Result( 95 | text=completion, 96 | model_inputs=model_inputs, 97 | provider=self, 98 | meta=meta, 99 | ) 100 | 101 | def complete_stream( 102 | self, 103 | prompt: str, 104 | history: Optional[List[dict]] = None, 105 | system_message: Union[str, List[dict], None] = None, 106 | temperature: float = 0, 107 | max_tokens: int = 300, 108 | **kwargs, 109 | ) -> StreamResult: 110 | model_inputs = self._prepare_model_inputs( 111 | prompt=prompt, 
112 | history=history, 113 | system_message=system_message, 114 | temperature=temperature, 115 | max_tokens=max_tokens, 116 | stream=True, 117 | **kwargs, 118 | ) 119 | 120 | response = self.client.chat.completions.create(model=self.model, **model_inputs) 121 | stream = self._process_stream(response) 122 | 123 | return StreamResult(stream=stream, model_inputs=model_inputs, provider=self) 124 | 125 | def _process_stream(self, response): 126 | for chunk in response: 127 | yield chunk.choices[0].delta.content 128 | 129 | # Note: Async methods are not implemented for Together AI as their Python SDK doesn't support async operations 130 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | """ 2 | Pytest configuration and fixtures for PyLLMs test suite. 3 | """ 4 | 5 | # Automatically load .env file if it exists (must be done before any other imports) 6 | import os 7 | from pathlib import Path 8 | 9 | # Try to load .env file automatically, but don't break if dotenv isn't installed 10 | try: 11 | from dotenv import load_dotenv 12 | 13 | # Look for .env file in project root (parent of tests directory) 14 | project_root = Path(__file__).parent.parent 15 | env_file = project_root / ".env" 16 | 17 | if env_file.exists(): 18 | load_dotenv(env_file) 19 | # Don't print during import to avoid cluttering output 20 | 21 | except ImportError: 22 | # Silently continue if python-dotenv is not installed 23 | pass 24 | 25 | import pytest 26 | 27 | 28 | def pytest_configure(config): 29 | """Configure pytest with custom markers.""" 30 | config.addinivalue_line( 31 | "markers", "slow: marks tests as slow (may take time due to API calls)" 32 | ) 33 | config.addinivalue_line( 34 | "markers", "requires_api_key: marks tests that require API keys" 35 | ) 36 | 37 | 38 | def pytest_collection_modifyitems(config, items): 39 | """Add markers to tests based on their characteristics.""" 40 | for item in items: 41 | # Mark tests that make API calls as slow 42 | if "completion" in item.name.lower() or "async" in item.name.lower() or "stream" in item.name.lower(): 43 | item.add_marker(pytest.mark.slow) 44 | 45 | # Mark tests that require API keys 46 | if "test_model_" in item.name or "test_async_" in item.name or "test_streaming_" in item.name: 47 | item.add_marker(pytest.mark.requires_api_key) 48 | 49 | 50 | @pytest.fixture(scope="session") 51 | def test_environment(): 52 | """Fixture to check test environment setup.""" 53 | env_info = { 54 | "has_openai_key": bool(os.getenv("OPENAI_API_KEY")), 55 | "has_anthropic_key": bool(os.getenv("ANTHROPIC_API_KEY")), 56 | "has_google_key": bool(os.getenv("GOOGLE_API_KEY")), 57 | "has_groq_key": bool(os.getenv("GROQ_API_KEY")), 58 | "has_mistral_key": bool(os.getenv("MISTRAL_API_KEY")), 59 | "has_deepseek_key": bool(os.getenv("DEEPSEEK_API_KEY")), 60 | "has_cohere_key": bool(os.getenv("COHERE_API_KEY")), 61 | "total_api_keys": sum([ 62 | bool(os.getenv("OPENAI_API_KEY")), 63 | bool(os.getenv("ANTHROPIC_API_KEY")), 64 | bool(os.getenv("GOOGLE_API_KEY")), 65 | bool(os.getenv("GROQ_API_KEY")), 66 | bool(os.getenv("MISTRAL_API_KEY")), 67 | bool(os.getenv("DEEPSEEK_API_KEY")), 68 | bool(os.getenv("COHERE_API_KEY")), 69 | ]) 70 | } 71 | return env_info 72 | 73 | 74 | @pytest.fixture(autouse=True) 75 | def setup_test_environment(): 76 | """Auto-use fixture to set up test environment.""" 77 | # Set conservative default timeouts to avoid hanging tests 78 | 
os.environ.setdefault("PYTEST_TIMEOUT", "30") 79 | yield 80 | # Cleanup if needed 81 | pass 82 | 83 | 84 | def pytest_runtest_setup(item): 85 | """Setup for each test item.""" 86 | # Skip tests if no API keys are available and test requires them 87 | if item.get_closest_marker("requires_api_key"): 88 | api_keys_available = any([ 89 | os.getenv("OPENAI_API_KEY"), 90 | os.getenv("ANTHROPIC_API_KEY"), 91 | os.getenv("GOOGLE_API_KEY"), 92 | os.getenv("GROQ_API_KEY"), 93 | os.getenv("MISTRAL_API_KEY"), 94 | os.getenv("DEEPSEEK_API_KEY"), 95 | os.getenv("COHERE_API_KEY"), 96 | ]) 97 | if not api_keys_available: 98 | pytest.skip("No API keys available for testing") 99 | 100 | 101 | def pytest_sessionstart(session): 102 | """Called after the Session object has been created.""" 103 | print("\n" + "="*60) 104 | print("PyLLMs Test Suite") 105 | print("="*60) 106 | 107 | # Check for available API keys 108 | api_keys = { 109 | "OpenAI": bool(os.getenv("OPENAI_API_KEY")), 110 | "Anthropic": bool(os.getenv("ANTHROPIC_API_KEY")), 111 | "BedrockAnthropic": bool(os.getenv("AWS_ACCESS_KEY_ID") and os.getenv("AWS_SECRET_ACCESS_KEY")), 112 | "Google": bool(os.getenv("GOOGLE_API_KEY")), 113 | "Groq": bool(os.getenv("GROQ_API_KEY")), 114 | "Mistral": bool(os.getenv("MISTRAL_API_KEY")), 115 | "DeepSeek": bool(os.getenv("DEEPSEEK_API_KEY")), 116 | "Cohere": bool(os.getenv("COHERE_API_KEY")), 117 | "Together": bool(os.getenv("TOGETHER_API_KEY")), 118 | "OpenRouter": bool(os.getenv("OPENROUTER_API_KEY")), 119 | "Reka": bool(os.getenv("REKA_API_KEY")), 120 | "AI21": bool(os.getenv("AI21_API_KEY")), 121 | "AlephAlpha": bool(os.getenv("ALEPHALPHA_API_KEY")), 122 | "HuggingfaceHub": bool(os.getenv("HUGGINFACEHUB_API_KEY")), 123 | } 124 | 125 | available_keys = [name for name, available in api_keys.items() if available] 126 | print(f"API Keys Available: {len(available_keys)}/{len(api_keys)}") 127 | 128 | if available_keys: 129 | print("Available providers:", ", ".join(available_keys)) 130 | else: 131 | print("⚠️ No API keys found - only testing provider discovery and initialization") 132 | 133 | print("="*60) 134 | 135 | 136 | def pytest_sessionfinish(session, exitstatus): 137 | """Called after whole test run finished.""" 138 | print("\n" + "="*60) 139 | print("PyLLMs Test Suite Complete") 140 | print("="*60) -------------------------------------------------------------------------------- /llms/providers/reka.py: -------------------------------------------------------------------------------- 1 | from typing import AsyncGenerator, Dict, List, Optional, Union 2 | import tiktoken 3 | 4 | from reka.client import Reka, AsyncReka 5 | 6 | from ..results.result import AsyncStreamResult, Result, StreamResult 7 | from .base_provider import BaseProvider 8 | 9 | 10 | class RekaProvider(BaseProvider): 11 | MODEL_INFO = { 12 | "reka-edge": {"prompt": 0.4, "completion": 1.0, "token_limit": 128000}, 13 | "reka-flash": {"prompt": 0.8, "completion": 2.0, "token_limit": 128000}, 14 | "reka-core": {"prompt": 3.0, "completion": 15.0, "token_limit": 128000}, 15 | } 16 | 17 | def __init__( 18 | self, 19 | api_key: Union[str, None] = None, 20 | model: Union[str, None] = None, 21 | **kwargs 22 | ): 23 | super().__init__(api_key=api_key, model=model) 24 | if model is None: 25 | model = "reka-core" 26 | self.model = model 27 | self.client = Reka(api_key=api_key) 28 | self.async_client = AsyncReka(api_key=api_key) 29 | 30 | def count_tokens(self, content: Union[str, List[dict]]) -> int: 31 | # Reka uses the same tokenizer as OpenAI 32 | enc = 
tiktoken.encoding_for_model("gpt-3.5-turbo") 33 | if isinstance(content, list): 34 | return sum([len(enc.encode(str(message))) for message in content]) 35 | else: 36 | return len(enc.encode(content)) 37 | 38 | def _prepare_model_inputs( 39 | self, 40 | prompt: str, 41 | history: Optional[List[dict]] = None, 42 | system_message: Union[str, List[dict], None] = None, 43 | temperature: float = 0, 44 | max_tokens: int = 300, 45 | stream: bool = False, 46 | **kwargs, 47 | ) -> Dict: 48 | messages = [{"content": prompt, "role": "user"}] 49 | 50 | if history: 51 | messages = [*history, *messages] 52 | 53 | if isinstance(system_message, str): 54 | messages = [{"role": "system", "content": system_message}, *messages] 55 | elif isinstance(system_message, list): 56 | messages = [*system_message, *messages] 57 | 58 | model_inputs = { 59 | "messages": messages, 60 | "temperature": temperature, 61 | "max_tokens": max_tokens, 62 | **kwargs, 63 | } 64 | return model_inputs 65 | 66 | def complete( 67 | self, 68 | prompt: str, 69 | history: Optional[List[dict]] = None, 70 | system_message: Optional[List[dict]] = None, 71 | temperature: float = 0, 72 | max_tokens: int = 300, 73 | **kwargs, 74 | ) -> Result: 75 | model_inputs = self._prepare_model_inputs( 76 | prompt=prompt, 77 | history=history, 78 | system_message=system_message, 79 | temperature=temperature, 80 | max_tokens=max_tokens, 81 | **kwargs, 82 | ) 83 | 84 | with self.track_latency(): 85 | response = self.client.chat.create(model=self.model, **model_inputs) 86 | 87 | completion = response.responses[0].message.content.strip() 88 | prompt_tokens = self.count_tokens(model_inputs["messages"]) 89 | completion_tokens = self.count_tokens(completion) 90 | 91 | meta = { 92 | "tokens_prompt": prompt_tokens, 93 | "tokens_completion": completion_tokens, 94 | "latency": self.latency, 95 | } 96 | return Result( 97 | text=completion, 98 | model_inputs=model_inputs, 99 | provider=self, 100 | meta=meta, 101 | ) 102 | 103 | async def acomplete( 104 | self, 105 | prompt: str, 106 | history: Optional[List[dict]] = None, 107 | system_message: Optional[List[dict]] = None, 108 | temperature: float = 0, 109 | max_tokens: int = 300, 110 | **kwargs, 111 | ) -> Result: 112 | model_inputs = self._prepare_model_inputs( 113 | prompt=prompt, 114 | history=history, 115 | system_message=system_message, 116 | temperature=temperature, 117 | max_tokens=max_tokens, 118 | **kwargs, 119 | ) 120 | 121 | with self.track_latency(): 122 | response = await self.async_client.chat.create(model=self.model, **model_inputs) 123 | 124 | completion = response.responses[0].message.content.strip() 125 | prompt_tokens = self.count_tokens(model_inputs["messages"]) 126 | completion_tokens = self.count_tokens(completion) 127 | 128 | meta = { 129 | "tokens_prompt": prompt_tokens, 130 | "tokens_completion": completion_tokens, 131 | "latency": self.latency, 132 | } 133 | return Result( 134 | text=completion, 135 | model_inputs=model_inputs, 136 | provider=self, 137 | meta=meta, 138 | ) 139 | 140 | def complete_stream( 141 | self, 142 | prompt: str, 143 | history: Optional[List[dict]] = None, 144 | system_message: Union[str, List[dict], None] = None, 145 | temperature: float = 0, 146 | max_tokens: int = 300, 147 | **kwargs, 148 | ) -> StreamResult: 149 | model_inputs = self._prepare_model_inputs( 150 | prompt=prompt, 151 | history=history, 152 | system_message=system_message, 153 | temperature=temperature, 154 | max_tokens=max_tokens, 155 | **kwargs, 156 | ) 157 | 158 | response = 
self.client.chat.create_stream(model=self.model, **model_inputs) 159 | stream = self._process_stream(response) 160 | 161 | return StreamResult(stream=stream, model_inputs=model_inputs, provider=self) 162 | 163 | def _process_stream(self, response): 164 | for chunk in response: 165 | yield chunk.responses[0].chunk.content 166 | 167 | async def acomplete_stream( 168 | self, 169 | prompt: str, 170 | history: Optional[List[dict]] = None, 171 | system_message: Union[str, List[dict], None] = None, 172 | temperature: float = 0, 173 | max_tokens: int = 300, 174 | **kwargs, 175 | ) -> AsyncStreamResult: 176 | model_inputs = self._prepare_model_inputs( 177 | prompt=prompt, 178 | history=history, 179 | system_message=system_message, 180 | temperature=temperature, 181 | max_tokens=max_tokens, 182 | **kwargs, 183 | ) 184 | 185 | response = await self.async_client.chat.create_stream(model=self.model, **model_inputs) 186 | stream = self._aprocess_stream(response) 187 | return AsyncStreamResult( 188 | stream=stream, model_inputs=model_inputs, provider=self 189 | ) 190 | 191 | async def _aprocess_stream(self, response): 192 | async for chunk in response: 193 | yield chunk.responses[0].chunk.content 194 | -------------------------------------------------------------------------------- /llms/providers/deepseek.py: -------------------------------------------------------------------------------- 1 | from typing import AsyncGenerator, Dict, List, Optional, Union 2 | import tiktoken 3 | 4 | from openai import AsyncOpenAI, OpenAI 5 | 6 | from ..results.result import AsyncStreamResult, Result, StreamResult 7 | from .base_provider import BaseProvider 8 | 9 | 10 | class DeepSeekProvider(BaseProvider): 11 | MODEL_INFO = { 12 | "deepseek-chat": {"prompt": 0.14, "completion": 0.28, "token_limit": 128000, "is_chat": True, "output_limit": 8192}, 13 | "deepseek-coder": {"prompt": 0.14, "completion": 0.28, "token_limit": 128000, "is_chat": True, "output_limit": 8192}, 14 | "deepseek-reasoner": {"prompt": 0.55, "completion": 2.19, "token_limit": 32768, "is_chat": True, "output_limit": 8192}, 15 | } 16 | 17 | def __init__( 18 | self, 19 | api_key: Union[str, None] = None, 20 | model: Union[str, None] = None, 21 | client_kwargs: Union[dict, None] = None, 22 | async_client_kwargs: Union[dict, None] = None, 23 | ): 24 | if model is None: 25 | model = list(self.MODEL_INFO.keys())[0] 26 | self.model = model 27 | if client_kwargs is None: 28 | client_kwargs = {} 29 | self.client = OpenAI(api_key=api_key, base_url="https://api.deepseek.com/v1", **client_kwargs) 30 | if async_client_kwargs is None: 31 | async_client_kwargs = {} 32 | self.async_client = AsyncOpenAI(api_key=api_key, base_url="https://api.deepseek.com/v1", **async_client_kwargs) 33 | 34 | @property 35 | def is_chat_model(self) -> bool: 36 | return self.MODEL_INFO[self.model]['is_chat'] 37 | 38 | def count_tokens(self, content: Union[str, List[dict]]) -> int: 39 | # DeepSeek uses the same tokenizer as OpenAI 40 | enc = tiktoken.encoding_for_model("gpt-3.5-turbo") 41 | if isinstance(content, list): 42 | formatting_token_count = 4 43 | messages = content 44 | messages_text = ["".join(message.values()) for message in messages] 45 | tokens = [enc.encode(t, disallowed_special=()) for t in messages_text] 46 | 47 | n_tokens_list = [] 48 | for token, message in zip(tokens, messages): 49 | n_tokens = len(token) + formatting_token_count 50 | if "name" in message: 51 | n_tokens += -1 52 | n_tokens_list.append(n_tokens) 53 | return sum(n_tokens_list) 54 | else: 55 | return 
len(enc.encode(content, disallowed_special=())) 56 | 57 | def _prepare_model_inputs( 58 | self, 59 | prompt: str, 60 | history: Optional[List[dict]] = None, 61 | system_message: Union[str, List[dict], None] = None, 62 | temperature: float = 0, 63 | max_tokens: int = 300, 64 | stream: bool = False, 65 | **kwargs, 66 | ) -> Dict: 67 | messages = [{"role": "user", "content": prompt}] 68 | 69 | if history: 70 | messages = [*history, *messages] 71 | 72 | if isinstance(system_message, str): 73 | messages = [{"role": "system", "content": system_message}, *messages] 74 | elif isinstance(system_message, list): 75 | messages = [*system_message, *messages] 76 | 77 | model_inputs = { 78 | "messages": messages, 79 | "temperature": temperature, 80 | "max_tokens": max_tokens, 81 | "stream": stream, 82 | **kwargs, 83 | } 84 | return model_inputs 85 | 86 | def complete( 87 | self, 88 | prompt: str, 89 | history: Optional[List[dict]] = None, 90 | system_message: Optional[List[dict]] = None, 91 | temperature: float = 0, 92 | max_tokens: int = 300, 93 | **kwargs, 94 | ) -> Result: 95 | model_inputs = self._prepare_model_inputs( 96 | prompt=prompt, 97 | history=history, 98 | system_message=system_message, 99 | temperature=temperature, 100 | max_tokens=max_tokens, 101 | **kwargs, 102 | ) 103 | 104 | with self.track_latency(): 105 | response = self.client.chat.completions.create(model=self.model, **model_inputs) 106 | 107 | completion = response.choices[0].message.content.strip() 108 | usage = response.usage 109 | 110 | meta = { 111 | "tokens_prompt": usage.prompt_tokens, 112 | "tokens_completion": usage.completion_tokens, 113 | "latency": self.latency, 114 | } 115 | return Result( 116 | text=completion, 117 | model_inputs=model_inputs, 118 | provider=self, 119 | meta=meta, 120 | ) 121 | 122 | async def acomplete( 123 | self, 124 | prompt: str, 125 | history: Optional[List[dict]] = None, 126 | system_message: Optional[List[dict]] = None, 127 | temperature: float = 0, 128 | max_tokens: int = 300, 129 | **kwargs, 130 | ) -> Result: 131 | model_inputs = self._prepare_model_inputs( 132 | prompt=prompt, 133 | history=history, 134 | system_message=system_message, 135 | temperature=temperature, 136 | max_tokens=max_tokens, 137 | **kwargs, 138 | ) 139 | 140 | with self.track_latency(): 141 | response = await self.async_client.chat.completions.create(model=self.model, **model_inputs) 142 | 143 | completion = response.choices[0].message.content.strip() 144 | usage = response.usage 145 | 146 | meta = { 147 | "tokens_prompt": usage.prompt_tokens, 148 | "tokens_completion": usage.completion_tokens, 149 | "latency": self.latency, 150 | } 151 | return Result( 152 | text=completion, 153 | model_inputs=model_inputs, 154 | provider=self, 155 | meta=meta, 156 | ) 157 | 158 | def complete_stream( 159 | self, 160 | prompt: str, 161 | history: Optional[List[dict]] = None, 162 | system_message: Union[str, List[dict], None] = None, 163 | temperature: float = 0, 164 | max_tokens: int = 300, 165 | **kwargs, 166 | ) -> StreamResult: 167 | model_inputs = self._prepare_model_inputs( 168 | prompt=prompt, 169 | history=history, 170 | system_message=system_message, 171 | temperature=temperature, 172 | max_tokens=max_tokens, 173 | stream=True, 174 | **kwargs, 175 | ) 176 | 177 | response = self.client.chat.completions.create(model=self.model, **model_inputs) 178 | stream = self._process_stream(response) 179 | 180 | return StreamResult(stream=stream, model_inputs=model_inputs, provider=self) 181 | 182 | def _process_stream(self, response): 183 | 
for chunk in response: 184 | if chunk.choices[0].delta.content is not None: 185 | yield chunk.choices[0].delta.content 186 | 187 | async def acomplete_stream( 188 | self, 189 | prompt: str, 190 | history: Optional[List[dict]] = None, 191 | system_message: Union[str, List[dict], None] = None, 192 | temperature: float = 0, 193 | max_tokens: int = 300, 194 | **kwargs, 195 | ) -> AsyncStreamResult: 196 | model_inputs = self._prepare_model_inputs( 197 | prompt=prompt, 198 | history=history, 199 | system_message=system_message, 200 | temperature=temperature, 201 | max_tokens=max_tokens, 202 | stream=True, 203 | **kwargs, 204 | ) 205 | 206 | response = await self.async_client.chat.completions.create(model=self.model, **model_inputs) 207 | stream = self._aprocess_stream(response) 208 | return AsyncStreamResult( 209 | stream=stream, model_inputs=model_inputs, provider=self 210 | ) 211 | 212 | async def _aprocess_stream(self, response): 213 | async for chunk in response: 214 | if chunk.choices[0].delta.content is not None: 215 | yield chunk.choices[0].delta.content 216 | -------------------------------------------------------------------------------- /llms/providers/groq.py: -------------------------------------------------------------------------------- 1 | from typing import AsyncGenerator, Dict, List, Optional, Union 2 | import tiktoken 3 | 4 | from openai import AsyncOpenAI, OpenAI 5 | 6 | from ..results.result import AsyncStreamResult, Result, StreamResult 7 | from .base_provider import BaseProvider 8 | 9 | 10 | class GroqProvider(BaseProvider): 11 | MODEL_INFO = { 12 | "llama-3.1-405b-reasoning": {"prompt": 0.59, "completion": 0.79, "token_limit": 131072, "is_chat": True}, 13 | "llama-3.1-70b-versatile": {"prompt": 0.59, "completion": 0.79, "token_limit": 131072, "is_chat": True}, 14 | "llama-3.1-8b-instant": {"prompt": 0.05, "completion": 0.08, "token_limit": 131072, "is_chat": True}, 15 | "gemma2-9b-it": {"prompt": 0.20, "completion": 0.20, "token_limit": 131072, "is_chat": True}, 16 | "llama-3.3-70b-versatile": {"prompt": 0.59, "completion": 0.79, "token_limit": 131072, "is_chat": True}, 17 | "deepseek-r1-distill-llama-70b": {"prompt": 0.59, "completion": 0.99, "token_limit": 131072, "is_chat": True}, 18 | "meta-llama/llama-4-maverick-17b-128e-instruct": {"prompt": 0.20, "completion": 0.60, "token_limit": 131072, "is_chat": True}, 19 | "meta-llama/llama-4-scout-17b-16e-instruct": {"prompt": 0.11, "completion": 0.34, "token_limit": 131072, "is_chat": True}, 20 | } 21 | 22 | def __init__( 23 | self, 24 | api_key: Union[str, None] = None, 25 | model: Union[str, None] = None, 26 | client_kwargs: Union[dict, None] = None, 27 | async_client_kwargs: Union[dict, None] = None, 28 | ): 29 | if model is None: 30 | model = list(self.MODEL_INFO.keys())[0] 31 | self.model = model 32 | if client_kwargs is None: 33 | client_kwargs = {} 34 | self.client = OpenAI(api_key=api_key, base_url="https://api.groq.com/openai/v1", **client_kwargs) 35 | if async_client_kwargs is None: 36 | async_client_kwargs = {} 37 | self.async_client = AsyncOpenAI(api_key=api_key, base_url="https://api.groq.com/openai/v1", **async_client_kwargs) 38 | 39 | @property 40 | def is_chat_model(self) -> bool: 41 | return self.MODEL_INFO[self.model]['is_chat'] 42 | 43 | def count_tokens(self, content: Union[str, List[dict]]) -> int: 44 | # Groq uses the same tokenizer as OpenAI 45 | enc = tiktoken.encoding_for_model("gpt-3.5-turbo") 46 | if isinstance(content, list): 47 | formatting_token_count = 4 48 | messages = content 49 | 
messages_text = ["".join(message.values()) for message in messages] 50 | tokens = [enc.encode(t, disallowed_special=()) for t in messages_text] 51 | 52 | n_tokens_list = [] 53 | for token, message in zip(tokens, messages): 54 | n_tokens = len(token) + formatting_token_count 55 | if "name" in message: 56 | n_tokens += -1 57 | n_tokens_list.append(n_tokens) 58 | return sum(n_tokens_list) 59 | else: 60 | return len(enc.encode(content, disallowed_special=())) 61 | 62 | def _prepare_model_inputs( 63 | self, 64 | prompt: str, 65 | history: Optional[List[dict]] = None, 66 | system_message: Union[str, List[dict], None] = None, 67 | temperature: float = 0, 68 | max_tokens: int = 300, 69 | stream: bool = False, 70 | **kwargs, 71 | ) -> Dict: 72 | messages = [{"role": "user", "content": prompt}] 73 | 74 | if history: 75 | messages = [*history, *messages] 76 | 77 | if isinstance(system_message, str): 78 | messages = [{"role": "system", "content": system_message}, *messages] 79 | elif isinstance(system_message, list): 80 | messages = [*system_message, *messages] 81 | 82 | model_inputs = { 83 | "messages": messages, 84 | "temperature": temperature, 85 | "max_tokens": max_tokens, 86 | "stream": stream, 87 | **kwargs, 88 | } 89 | return model_inputs 90 | 91 | def complete( 92 | self, 93 | prompt: str, 94 | history: Optional[List[dict]] = None, 95 | system_message: Optional[List[dict]] = None, 96 | temperature: float = 0, 97 | max_tokens: int = 300, 98 | **kwargs, 99 | ) -> Result: 100 | model_inputs = self._prepare_model_inputs( 101 | prompt=prompt, 102 | history=history, 103 | system_message=system_message, 104 | temperature=temperature, 105 | max_tokens=max_tokens, 106 | **kwargs, 107 | ) 108 | 109 | with self.track_latency(): 110 | response = self.client.chat.completions.create(model=self.model, **model_inputs) 111 | 112 | completion = response.choices[0].message.content.strip() 113 | usage = response.usage 114 | 115 | meta = { 116 | "tokens_prompt": usage.prompt_tokens, 117 | "tokens_completion": usage.completion_tokens, 118 | "latency": self.latency, 119 | } 120 | return Result( 121 | text=completion, 122 | model_inputs=model_inputs, 123 | provider=self, 124 | meta=meta, 125 | ) 126 | 127 | async def acomplete( 128 | self, 129 | prompt: str, 130 | history: Optional[List[dict]] = None, 131 | system_message: Optional[List[dict]] = None, 132 | temperature: float = 0, 133 | max_tokens: int = 300, 134 | **kwargs, 135 | ) -> Result: 136 | model_inputs = self._prepare_model_inputs( 137 | prompt=prompt, 138 | history=history, 139 | system_message=system_message, 140 | temperature=temperature, 141 | max_tokens=max_tokens, 142 | **kwargs, 143 | ) 144 | 145 | with self.track_latency(): 146 | response = await self.async_client.chat.completions.create(model=self.model, **model_inputs) 147 | 148 | completion = response.choices[0].message.content.strip() 149 | usage = response.usage 150 | 151 | meta = { 152 | "tokens_prompt": usage.prompt_tokens, 153 | "tokens_completion": usage.completion_tokens, 154 | "latency": self.latency, 155 | } 156 | return Result( 157 | text=completion, 158 | model_inputs=model_inputs, 159 | provider=self, 160 | meta=meta, 161 | ) 162 | 163 | def complete_stream( 164 | self, 165 | prompt: str, 166 | history: Optional[List[dict]] = None, 167 | system_message: Union[str, List[dict], None] = None, 168 | temperature: float = 0, 169 | max_tokens: int = 300, 170 | **kwargs, 171 | ) -> StreamResult: 172 | model_inputs = self._prepare_model_inputs( 173 | prompt=prompt, 174 | history=history, 175 | 
system_message=system_message, 176 | temperature=temperature, 177 | max_tokens=max_tokens, 178 | stream=True, 179 | **kwargs, 180 | ) 181 | 182 | response = self.client.chat.completions.create(model=self.model, **model_inputs) 183 | stream = self._process_stream(response) 184 | 185 | return StreamResult(stream=stream, model_inputs=model_inputs, provider=self) 186 | 187 | def _process_stream(self, response): 188 | for chunk in response: 189 | if chunk.choices[0].delta.content is not None: 190 | yield chunk.choices[0].delta.content 191 | 192 | async def acomplete_stream( 193 | self, 194 | prompt: str, 195 | history: Optional[List[dict]] = None, 196 | system_message: Union[str, List[dict], None] = None, 197 | temperature: float = 0, 198 | max_tokens: int = 300, 199 | **kwargs, 200 | ) -> AsyncStreamResult: 201 | model_inputs = self._prepare_model_inputs( 202 | prompt=prompt, 203 | history=history, 204 | system_message=system_message, 205 | temperature=temperature, 206 | max_tokens=max_tokens, 207 | stream=True, 208 | **kwargs, 209 | ) 210 | 211 | response = await self.async_client.chat.completions.create(model=self.model, **model_inputs) 212 | stream = self._aprocess_stream(response) 213 | return AsyncStreamResult( 214 | stream=stream, model_inputs=model_inputs, provider=self 215 | ) 216 | 217 | async def _aprocess_stream(self, response): 218 | async for chunk in response: 219 | if chunk.choices[0].delta.content is not None: 220 | yield chunk.choices[0].delta.content 221 | -------------------------------------------------------------------------------- /llms/providers/ollama.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Any, Dict, Generator, List, Optional, Union, AsyncGenerator 3 | 4 | from ollama import Client, AsyncClient 5 | 6 | from ..results.result import Result, StreamResult, AsyncStreamResult 7 | from .base_provider import BaseProvider 8 | 9 | 10 | def _get_model_info(ollama_host: Optional[str] = "http://localhost:11434"): 11 | model_info = {} 12 | try: 13 | pulled_models = Client(host=ollama_host).list().get("models", []) 14 | for model in pulled_models: 15 | name = model["name"] 16 | # Ollama models are free to use locally 17 | model_info[name] = { 18 | "prompt": 0.0, 19 | "completion": 0.0, 20 | "token_limit": 4096 # Default token limit 21 | } 22 | 23 | if not pulled_models: 24 | raise ValueError("Could not retrieve any models from Ollama") 25 | except Exception as e: 26 | # Log the error but continue with empty model info 27 | pass 28 | #print(f"Warning: Could not connect to Ollama server: {str(e)}") 29 | 30 | return model_info 31 | 32 | 33 | class OllamaProvider(BaseProvider): 34 | MODEL_INFO = _get_model_info() 35 | 36 | def count_tokens(self, content: Union[str, List[Dict[str, Any]]]) -> int: 37 | """Estimate token count using simple word-based heuristic""" 38 | if isinstance(content, list): 39 | # For chat messages, concatenate all content 40 | text = " ".join(msg["content"] for msg in content) 41 | else: 42 | text = content 43 | # Rough estimation: split on whitespace and punctuation 44 | return len(text.split()) 45 | 46 | def __init__( 47 | self, 48 | model: Optional[str] = None, 49 | ollama_host: Optional[str] = "http://localhost:11434", 50 | ollama_client_options: Optional[dict] = None 51 | ): 52 | self.model = model 53 | if self.model is None: 54 | self.model = list(self.MODEL_INFO.keys())[0] 55 | 56 | if ollama_client_options is None: 57 | ollama_client_options = {} 58 | 59 | self.client = 
Client(host=ollama_host, **ollama_client_options) 60 | self.async_client = AsyncClient(host=ollama_host, **ollama_client_options) 61 | self.is_chat_model = True 62 | 63 | def _prepare_model_inputs( 64 | self, 65 | prompt: str, 66 | history: Optional[List[dict]] = None, 67 | system_message: Union[str, List[dict], None] = None, 68 | stream: bool = False, 69 | max_tokens: Optional[int] = None, # Add but don't use 70 | temperature: Optional[float] = None, # Add but don't use 71 | **kwargs 72 | ) -> Dict: 73 | # Remove unsupported parameters 74 | kwargs.pop('max_tokens', None) 75 | kwargs.pop('temperature', None) 76 | if self.is_chat_model: 77 | messages = [{"role": "user", "content": prompt}] 78 | 79 | if history: 80 | messages = history + messages 81 | 82 | 83 | if isinstance(system_message, str): 84 | messages = [{"role": "system", "content": system_message}, *messages] 85 | elif isinstance(system_message, list): 86 | messages = [*system_message, *messages] 87 | 88 | model_inputs = { 89 | "messages": messages, 90 | "stream": stream, 91 | **kwargs, 92 | } 93 | else: 94 | if history: 95 | raise ValueError( 96 | f"history argument is not supported for {self.model} model" 97 | ) 98 | 99 | if system_message: 100 | raise ValueError( 101 | f"system_message argument is not supported for {self.model} model" 102 | ) 103 | 104 | model_inputs = { 105 | "prompt": prompt, 106 | "stream": stream, 107 | **kwargs, 108 | } 109 | 110 | return model_inputs 111 | 112 | def complete( 113 | self, 114 | prompt: str, 115 | history: Optional[List[dict]] = None, 116 | system_message: Optional[List[dict]] = None, 117 | **kwargs 118 | ) -> Result: 119 | try: 120 | model_inputs = self._prepare_model_inputs( 121 | prompt=prompt, 122 | history=history, 123 | system_message=system_message, 124 | **kwargs 125 | ) 126 | 127 | with self.track_latency(): 128 | response = self.client.chat(model=self.model, **model_inputs) 129 | 130 | message = response["message"] 131 | completion = message["content"].strip() 132 | except Exception as e: 133 | raise RuntimeError(f"Ollama completion failed: {str(e)}") 134 | 135 | meta = { 136 | "tokens_prompt": response["prompt_eval_count"], 137 | "tokens_completion": response["eval_count"], 138 | "latency": self.latency, 139 | } 140 | 141 | return Result( 142 | text=completion, 143 | model_inputs=model_inputs, 144 | provider=self, 145 | meta=meta, 146 | ) 147 | 148 | def complete_stream( 149 | self, 150 | prompt: str, 151 | history: Optional[List[dict]] = None, 152 | system_message: Optional[List[dict]] = None, 153 | temperature: float = 0, 154 | max_tokens: int = 300, 155 | safe_prompt: bool = False, 156 | random_seed: Union[int, None] = None, 157 | **kwargs, 158 | ) -> StreamResult: 159 | model_inputs = self._prepare_model_inputs( 160 | prompt=prompt, 161 | history=history, 162 | system_message=system_message, 163 | stream=True, 164 | **kwargs 165 | ) 166 | 167 | with self.track_latency(): 168 | response = self.client.chat(model=self.model, **model_inputs) 169 | stream = self._process_stream(response=response) 170 | 171 | return StreamResult(stream=stream, model_inputs=model_inputs, provider=self) 172 | 173 | def _process_stream(self, response: Generator) -> Generator: 174 | chunk_generator = (chunk["message"]["content"] for chunk in response) 175 | 176 | while not (first_text := next(chunk_generator)): 177 | continue 178 | 179 | yield first_text.lstrip() 180 | for chunk in chunk_generator: 181 | if chunk is not None: 182 | yield chunk 183 | 184 | async def _aprocess_stream(self, response) -> 
AsyncGenerator: 185 | while True: 186 | first_completion = (await response.__anext__())["message"]["content"] 187 | if first_completion: 188 | yield first_completion.lstrip() 189 | break 190 | 191 | async def acomplete( 192 | self, 193 | prompt: str, 194 | history: Optional[List[dict]] = None, 195 | system_message: Optional[List[dict]] = None, 196 | **kwargs 197 | ) -> Result: 198 | try: 199 | model_inputs = self._prepare_model_inputs( 200 | prompt=prompt, 201 | history=history, 202 | system_message=system_message, 203 | **kwargs 204 | ) 205 | 206 | with self.track_latency(): 207 | response = await self.async_client.chat(model=self.model, **model_inputs) 208 | 209 | message = response["message"] 210 | completion = "" 211 | completion = message["content"].strip() 212 | 213 | meta = { 214 | "tokens_prompt": response["prompt_eval_count"], 215 | "tokens_completion": response["eval_count"], 216 | "latency": self.latency, 217 | } 218 | except Exception as e: 219 | raise RuntimeError(f"Ollama completion failed: {str(e)}") 220 | 221 | return Result( 222 | text=completion, 223 | model_inputs=model_inputs, 224 | provider=self, 225 | meta=meta, 226 | ) 227 | 228 | async def acomplete_stream( 229 | self, 230 | prompt: str, 231 | history: Optional[List[dict]] = None, 232 | system_message: Optional[List[dict]] = None, 233 | temperature: float = 0, 234 | max_tokens: int = 300, 235 | safe_prompt: bool = False, 236 | random_seed: Union[int, None] = None, 237 | **kwargs 238 | ): 239 | model_inputs = self._prepare_model_inputs( 240 | prompt=prompt, 241 | history=history, 242 | system_message=system_message, 243 | stream=True, 244 | **kwargs 245 | ) 246 | 247 | with self.track_latency(): 248 | response = self.async_client.chat(model=self.model, **model_inputs) 249 | stream = self._aprocess_stream(response=response) 250 | 251 | return AsyncStreamResult(stream=stream, model_inputs=model_inputs, provider=self) 252 | -------------------------------------------------------------------------------- /llms/providers/mistral.py: -------------------------------------------------------------------------------- 1 | import tiktoken 2 | from typing import Dict, Union, Optional, List, Generator, AsyncGenerator 3 | from mistralai import Mistral 4 | 5 | from ..results.result import AsyncStreamResult, Result, StreamResult 6 | from .base_provider import BaseProvider 7 | 8 | 9 | class MistralProvider(BaseProvider): 10 | MODEL_INFO = { 11 | "mistral-tiny": {"prompt": 0.25, "completion": 0.25, "token_limit": 32_000}, 12 | # new endpoint for mistral-tiny, mistral-tiny will be deprecated in ~June 2024 13 | "open-mistral-7b": {"prompt": 0.25, "completion": 0.25, "token_limit": 32_000}, 14 | "mistral-small": {"prompt": 0.7, "completion": 0.7, "token_limit": 32_000}, 15 | # new endpoint for mistral-small, mistral-small will be deprecated in ~June 2024 16 | "open-mixtral-8x7b": {"prompt": 0.7, "completion": 0.7, "token_limit": 32_000}, 17 | "mistral-small-latest": {"prompt": 2.0, "completion": 6.0, "token_limit": 32_000}, 18 | "mistral-medium-latest": {"prompt": 2.7, "completion": 8.1, "token_limit": 32_000}, 19 | "mistral-large-latest": {"prompt": 3.0, "completion": 9.0, "token_limit": 32_000}, 20 | "open-mistral-nemo": {"prompt": 0.3, "completion": 0.3, "token_limit": 32_000}, 21 | } 22 | 23 | def __init__( 24 | self, 25 | api_key: Union[str, None] = None, 26 | model: Union[str, None] = None, 27 | client_kwargs: Union[dict, None] = None, 28 | async_client_kwargs: Union[dict, None] = None, 29 | ): 30 | 31 | if model is None: 32 | model 
= list(self.MODEL_INFO.keys())[0] 33 | self.model = model 34 | 35 | if client_kwargs is None: 36 | client_kwargs = {} 37 | self.client = Mistral(api_key=api_key, **client_kwargs) 38 | 39 | # Use same client for both sync and async 40 | self.async_client = self.client 41 | 42 | def count_tokens(self, content: Union[str, List[Dict[str, str]]]) -> int: 43 | # TODO: update after Mistarl support count token in their SDK 44 | # use gpt 3.5 turbo for estimation now 45 | enc = tiktoken.encoding_for_model("gpt-3.5-turbo") 46 | if isinstance(content, list): 47 | formatting_token_count = 4 48 | messages = content 49 | messages_text = [f"{message['role']}{message['content']}" for message in messages] 50 | tokens = [enc.encode(t, disallowed_special=()) for t in messages_text] 51 | 52 | n_tokens_list = [] 53 | for token, message in zip(tokens, messages): 54 | n_tokens = len(token) + formatting_token_count 55 | n_tokens_list.append(n_tokens) 56 | return sum(n_tokens_list) 57 | else: 58 | return len(enc.encode(content, disallowed_special=())) 59 | 60 | def _prepare_model_inputs( 61 | self, 62 | prompt: str, 63 | history: Optional[List[dict]] = None, 64 | temperature: float = 0, 65 | max_tokens: int = 300, 66 | stop_sequences: Optional[List[str]] = None, 67 | system_message: Union[str, None] = None, 68 | safe_prompt: bool = False, 69 | random_seed: Union[int, None] = None, 70 | **kwargs, 71 | ) -> Dict: 72 | if stop_sequences: 73 | raise ValueError("Parameter `stop` is not supported") 74 | 75 | messages = [{"role": "user", "content": prompt}] 76 | if history: 77 | messages = [{"role": msg["role"], "content": msg["content"]} for msg in history] + messages 78 | 79 | if system_message is None: 80 | pass 81 | elif isinstance(system_message, str): 82 | messages = [{"role": "system", "content": system_message}, *messages] 83 | 84 | model_inputs = { 85 | "messages": messages, 86 | "temperature": temperature, 87 | "max_tokens": max_tokens, 88 | "safe_prompt": safe_prompt, 89 | "random_seed": random_seed, 90 | **kwargs, 91 | } 92 | 93 | return model_inputs 94 | 95 | def complete( 96 | self, 97 | prompt: str, 98 | history: Optional[List[dict]] = None, 99 | system_message: Optional[List[dict]] = None, 100 | temperature: float = 0, 101 | max_tokens: int = 300, 102 | safe_prompt: bool = False, 103 | random_seed: Union[int, None] = None, 104 | **kwargs, 105 | ) -> Result: 106 | model_inputs = self._prepare_model_inputs( 107 | prompt=prompt, 108 | history=history, 109 | system_message=system_message, 110 | temperature=temperature, 111 | max_tokens=max_tokens, 112 | safe_prompt=safe_prompt, 113 | random_seed=random_seed, 114 | **kwargs, 115 | ) 116 | 117 | with self.track_latency(): 118 | response = self.client.chat.complete(model=self.model, **model_inputs) 119 | 120 | completion = response.choices[0].message.content 121 | usage = response.usage 122 | 123 | meta = { 124 | "tokens_prompt": usage.prompt_tokens, 125 | "tokens_completion": usage.completion_tokens, 126 | "latency": self.latency, 127 | } 128 | 129 | return Result( 130 | text=completion, 131 | model_inputs=model_inputs, 132 | provider=self, 133 | meta=meta, 134 | ) 135 | 136 | async def acomplete( 137 | self, 138 | prompt: str, 139 | history: Optional[List[dict]] = None, 140 | system_message: Optional[List[dict]] = None, 141 | temperature: float = 0, 142 | max_tokens: int = 300, 143 | safe_prompt: bool = False, 144 | random_seed: Union[int, None] = None, 145 | **kwargs, 146 | ) -> Result: 147 | 148 | model_inputs = self._prepare_model_inputs( 149 | 
prompt=prompt, 150 | history=history, 151 | system_message=system_message, 152 | temperature=temperature, 153 | max_tokens=max_tokens, 154 | safe_prompt=safe_prompt, 155 | random_seed=random_seed, 156 | **kwargs, 157 | ) 158 | with self.track_latency(): 159 | response = await self.async_client.chat.complete(model=self.model, **model_inputs) 160 | 161 | completion = response.choices[0].message.content 162 | usage = response.usage 163 | 164 | meta = { 165 | "tokens_prompt": usage.prompt_tokens, 166 | "tokens_completion": usage.completion_tokens, 167 | "latency": self.latency, 168 | } 169 | 170 | return Result( 171 | text=completion, 172 | model_inputs=model_inputs, 173 | provider=self, 174 | meta=meta, 175 | ) 176 | 177 | def complete_stream( 178 | self, 179 | prompt: str, 180 | history: Optional[List[dict]] = None, 181 | system_message: Optional[List[dict]] = None, 182 | temperature: float = 0, 183 | max_tokens: int = 300, 184 | safe_prompt: bool = False, 185 | random_seed: Union[int, None] = None, 186 | **kwargs, 187 | ) -> StreamResult: 188 | 189 | model_inputs = self._prepare_model_inputs( 190 | prompt=prompt, 191 | history=history, 192 | system_message=system_message, 193 | temperature=temperature, 194 | max_tokens=max_tokens, 195 | safe_prompt=safe_prompt, 196 | random_seed=random_seed, 197 | **kwargs, 198 | ) 199 | 200 | model_inputs["stream"] = True 201 | response = self.client.chat.complete(model=self.model, **model_inputs) 202 | stream = self._process_stream(response) 203 | return StreamResult(stream=stream, model_inputs=model_inputs, provider=self) 204 | 205 | def _process_stream(self, response: Generator) -> Generator: 206 | chunk_generator = ( 207 | chunk.choices[0].delta.content for chunk in response 208 | ) 209 | 210 | while not (first_text := next(chunk_generator)): 211 | continue 212 | yield first_text.lstrip() 213 | for chunk in chunk_generator: 214 | if chunk is not None: 215 | yield chunk 216 | 217 | async def acomplete_stream( 218 | self, 219 | prompt: str, 220 | history: Optional[List[dict]] = None, 221 | system_message: Optional[List[dict]] = None, 222 | temperature: float = 0, 223 | max_tokens: int = 300, 224 | safe_prompt: bool = False, 225 | random_seed: Union[int, None] = None, 226 | **kwargs, 227 | ) -> AsyncStreamResult: 228 | 229 | model_inputs = self._prepare_model_inputs( 230 | prompt=prompt, 231 | history=history, 232 | system_message=system_message, 233 | temperature=temperature, 234 | max_tokens=max_tokens, 235 | safe_prompt=safe_prompt, 236 | random_seed=random_seed, 237 | **kwargs, 238 | ) 239 | 240 | with self.track_latency(): 241 | model_inputs["stream"] = True 242 | response = self.async_client.chat.complete(model=self.model, **model_inputs) 243 | stream = self._aprocess_stream(response) 244 | return AsyncStreamResult( 245 | stream=stream, model_inputs=model_inputs, provider=self 246 | ) 247 | 248 | async def _aprocess_stream(self, response) -> AsyncGenerator: 249 | while True: 250 | first_completion = (await response.__anext__()).choices[0].delta.content 251 | if first_completion: 252 | yield first_completion.lstrip() 253 | break 254 | 255 | async for chunk in response: 256 | completion = chunk.choices[0].delta.content 257 | if completion is not None: 258 | yield completion 259 | -------------------------------------------------------------------------------- /llms/providers/openrouter.py: -------------------------------------------------------------------------------- 1 | from typing import AsyncGenerator, Dict, List, Optional, Union 2 | import 
tiktoken 3 | 4 | from openai import AsyncOpenAI, OpenAI 5 | 6 | from ..results.result import AsyncStreamResult, Result, StreamResult 7 | from .base_provider import BaseProvider 8 | 9 | 10 | class OpenRouterProvider(BaseProvider): 11 | MODEL_INFO = { 12 | "nvidia/llama-3.1-nemotron-70b-instruct": {"prompt": 0.35, "completion": 0.4, "token_limit": 131072, "is_chat": True}, 13 | "x-ai/grok-2": {"prompt": 5.0, "completion": 10.0, "token_limit": 32768, "is_chat": True}, 14 | "nousresearch/hermes-3-llama-3.1-405b:free": {"prompt": 0.0, "completion": 0.0, "token_limit": 8192, "is_chat": True}, 15 | "google/gemini-flash-1.5-exp": {"prompt": 0.0, "completion": 0.0, "token_limit": 1000000, "is_chat": True}, 16 | "liquid/lfm-40b": {"prompt": 0.0, "completion": 0.0, "token_limit": 32768, "is_chat": True}, 17 | "mistralai/ministral-8b": {"prompt": 0.1, "completion": 0.1, "token_limit": 128000, "is_chat": True}, 18 | "qwen/qwen-2.5-72b-instruct": {"prompt": 0.35, "completion": 0.4, "token_limit": 131072, "is_chat": True}, 19 | "openai/o1": {"prompt": 15.0, "completion": 60.0, "token_limit": 200000, "is_chat": True}, 20 | "google/gemini-2.0-flash-thinking-exp:free": {"prompt": 0.0, "completion": 0.0, "token_limit": 40000, "is_chat": True}, 21 | "x-ai/grok-2-1212": {"prompt": 2.0, "completion": 10.0, "token_limit": 131072, "is_chat": True}, 22 | "google/gemini-exp-1206:free": {"prompt": 0.0, "completion": 0.0, "token_limit": 2100000, "is_chat": True}, 23 | "google/gemini-2.0-flash-exp:free": {"prompt": 0.0, "completion": 0.0, "token_limit": 1050000, "is_chat": True}, 24 | "deepseek/deepseek-r1-distill-llama-70b": {"prompt": 0.23, "completion": 0.69, "token_limit": 131000, "is_chat": True}, 25 | "moonshotai/kimi-k2": { 26 | "prompt": 0.14, 27 | "completion": 2.49, 28 | "token_limit": 131072, 29 | "is_chat": True 30 | }, 31 | } 32 | 33 | def __init__( 34 | self, 35 | api_key: Union[str, None] = None, 36 | model: Union[str, None] = None, 37 | client_kwargs: Union[dict, None] = None, 38 | async_client_kwargs: Union[dict, None] = None, 39 | ): 40 | if model is None: 41 | model = list(self.MODEL_INFO.keys())[0] 42 | self.model = model 43 | if client_kwargs is None: 44 | client_kwargs = {} 45 | self.client = OpenAI(api_key=api_key, base_url="https://openrouter.ai/api/v1", **client_kwargs) 46 | if async_client_kwargs is None: 47 | async_client_kwargs = {} 48 | self.async_client = AsyncOpenAI(api_key=api_key, base_url="https://openrouter.ai/api/v1", **async_client_kwargs) 49 | 50 | @property 51 | def is_chat_model(self) -> bool: 52 | return self.MODEL_INFO[self.model]['is_chat'] 53 | 54 | def count_tokens(self, content: Union[str, List[dict]]) -> int: 55 | # OpenRouter uses the same tokenizer as OpenAI 56 | enc = tiktoken.encoding_for_model("gpt-3.5-turbo") 57 | if isinstance(content, list): 58 | formatting_token_count = 4 59 | messages = content 60 | messages_text = ["".join(message.values()) for message in messages] 61 | tokens = [enc.encode(t, disallowed_special=()) for t in messages_text] 62 | 63 | n_tokens_list = [] 64 | for token, message in zip(tokens, messages): 65 | n_tokens = len(token) + formatting_token_count 66 | if "name" in message: 67 | n_tokens += -1 68 | n_tokens_list.append(n_tokens) 69 | return sum(n_tokens_list) 70 | else: 71 | return len(enc.encode(content, disallowed_special=())) 72 | 73 | def _prepare_model_inputs( 74 | self, 75 | prompt: str, 76 | history: Optional[List[dict]] = None, 77 | system_message: Union[str, List[dict], None] = None, 78 | temperature: float = 0, 79 | max_tokens: 
int = 300, 80 | stream: bool = False, 81 | **kwargs, 82 | ) -> Dict: 83 | messages = [{"role": "user", "content": prompt}] 84 | 85 | if history: 86 | messages = [*history, *messages] 87 | 88 | if isinstance(system_message, str): 89 | messages = [{"role": "system", "content": system_message}, *messages] 90 | elif isinstance(system_message, list): 91 | messages = [*system_message, *messages] 92 | 93 | model_inputs = { 94 | "messages": messages, 95 | "temperature": temperature, 96 | "max_tokens": max_tokens, 97 | "stream": stream, 98 | "extra_headers": { 99 | "HTTP-Referer": kwargs.get("site_url", ""), 100 | "X-Title": kwargs.get("app_name", ""), 101 | }, 102 | **kwargs, 103 | } 104 | return model_inputs 105 | 106 | def complete( 107 | self, 108 | prompt: str, 109 | history: Optional[List[dict]] = None, 110 | system_message: Optional[List[dict]] = None, 111 | temperature: float = 0, 112 | max_tokens: int = 300, 113 | **kwargs, 114 | ) -> Result: 115 | model_inputs = self._prepare_model_inputs( 116 | prompt=prompt, 117 | history=history, 118 | system_message=system_message, 119 | temperature=temperature, 120 | max_tokens=max_tokens, 121 | **kwargs, 122 | ) 123 | 124 | with self.track_latency(): 125 | response = self.client.chat.completions.create(model=self.model, **model_inputs) 126 | 127 | if not response or not hasattr(response, 'choices') or not response.choices: 128 | raise ValueError("Unexpected response structure from OpenRouter API") 129 | 130 | completion = response.choices[0].message.content.strip() if response.choices[0].message else "" 131 | usage = response.usage if hasattr(response, 'usage') else None 132 | 133 | meta = { 134 | "tokens_prompt": usage.prompt_tokens if usage else 0, 135 | "tokens_completion": usage.completion_tokens if usage else 0, 136 | "latency": self.latency, 137 | } 138 | return Result( 139 | text=completion, 140 | model_inputs=model_inputs, 141 | provider=self, 142 | meta=meta, 143 | ) 144 | 145 | async def acomplete( 146 | self, 147 | prompt: str, 148 | history: Optional[List[dict]] = None, 149 | system_message: Optional[List[dict]] = None, 150 | temperature: float = 0, 151 | max_tokens: int = 300, 152 | **kwargs, 153 | ) -> Result: 154 | model_inputs = self._prepare_model_inputs( 155 | prompt=prompt, 156 | history=history, 157 | system_message=system_message, 158 | temperature=temperature, 159 | max_tokens=max_tokens, 160 | **kwargs, 161 | ) 162 | 163 | with self.track_latency(): 164 | response = await self.async_client.chat.completions.create(model=self.model, **model_inputs) 165 | 166 | completion = response.choices[0].message.content.strip() 167 | usage = response.usage 168 | 169 | meta = { 170 | "tokens_prompt": usage.prompt_tokens, 171 | "tokens_completion": usage.completion_tokens, 172 | "latency": self.latency, 173 | } 174 | return Result( 175 | text=completion, 176 | model_inputs=model_inputs, 177 | provider=self, 178 | meta=meta, 179 | ) 180 | 181 | def complete_stream( 182 | self, 183 | prompt: str, 184 | history: Optional[List[dict]] = None, 185 | system_message: Union[str, List[dict], None] = None, 186 | temperature: float = 0, 187 | max_tokens: int = 300, 188 | **kwargs, 189 | ) -> StreamResult: 190 | model_inputs = self._prepare_model_inputs( 191 | prompt=prompt, 192 | history=history, 193 | system_message=system_message, 194 | temperature=temperature, 195 | max_tokens=max_tokens, 196 | stream=True, 197 | **kwargs, 198 | ) 199 | 200 | response = self.client.chat.completions.create(model=self.model, **model_inputs) 201 | stream = 
self._process_stream(response) 202 | 203 | return StreamResult(stream=stream, model_inputs=model_inputs, provider=self) 204 | 205 | def _process_stream(self, response): 206 | for chunk in response: 207 | if chunk.choices[0].delta.content is not None: 208 | yield chunk.choices[0].delta.content 209 | 210 | async def acomplete_stream( 211 | self, 212 | prompt: str, 213 | history: Optional[List[dict]] = None, 214 | system_message: Union[str, List[dict], None] = None, 215 | temperature: float = 0, 216 | max_tokens: int = 300, 217 | **kwargs, 218 | ) -> AsyncStreamResult: 219 | model_inputs = self._prepare_model_inputs( 220 | prompt=prompt, 221 | history=history, 222 | system_message=system_message, 223 | temperature=temperature, 224 | max_tokens=max_tokens, 225 | stream=True, 226 | **kwargs, 227 | ) 228 | 229 | response = await self.async_client.chat.completions.create(model=self.model, **model_inputs) 230 | stream = self._aprocess_stream(response) 231 | return AsyncStreamResult( 232 | stream=stream, model_inputs=model_inputs, provider=self 233 | ) 234 | 235 | async def _aprocess_stream(self, response): 236 | async for chunk in response: 237 | if chunk.choices[0].delta.content is not None: 238 | yield chunk.choices[0].delta.content 239 | -------------------------------------------------------------------------------- /llms/results/result.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import json 3 | from typing import AsyncGenerator, Dict, Generator, List, Optional 4 | from warnings import warn 5 | 6 | from llms.providers.base_provider import BaseProvider 7 | 8 | 9 | class Result: 10 | def __init__( 11 | self, 12 | text: str, 13 | model_inputs: Dict, 14 | provider: BaseProvider, 15 | meta: Optional[Dict] = None, 16 | function_call: Optional[Dict] = None, 17 | ): 18 | self._meta = meta or {} 19 | self.text = text 20 | self.provider = provider 21 | self.model_inputs = model_inputs 22 | self.function_call = function_call or {} 23 | 24 | @property 25 | def tokens_completion(self) -> int: 26 | if tokens_completion := self._meta.get("tokens_completion"): 27 | return tokens_completion 28 | else: 29 | tokens_completion = self.provider.count_tokens(self.text) 30 | self._meta["tokens_completion"] = tokens_completion 31 | return tokens_completion 32 | 33 | @property 34 | def tokens_prompt(self) -> int: 35 | if tokens_prompt := self._meta.get("tokens_prompt"): 36 | return tokens_prompt 37 | else: 38 | prompt = self.model_inputs.get("prompt") or self.model_inputs.get("messages") 39 | tokens_prompt = self.provider.count_tokens(prompt) 40 | self._meta["tokens_prompt"] = tokens_prompt 41 | return tokens_prompt 42 | 43 | @property 44 | def tokens(self) -> int: 45 | return self.tokens_completion + self.tokens_prompt 46 | 47 | @property 48 | def cost(self) -> float: 49 | if cost := self._meta.get("cost"): 50 | return cost 51 | else: 52 | cost = self.provider.compute_cost( 53 | prompt_tokens=self.tokens_prompt, completion_tokens=self.tokens_completion 54 | ) 55 | self._meta["cost"] = cost 56 | return cost 57 | 58 | @property 59 | def meta(self) -> Dict: 60 | return { 61 | "model": self.provider.model, 62 | "tokens": self.tokens, 63 | "tokens_prompt": self.tokens_prompt, 64 | "tokens_completion": self.tokens_completion, 65 | "cost": self.cost, 66 | "latency": self._meta.get("latency"), 67 | } 68 | 69 | def to_json(self): 70 | model_inputs = self.model_inputs 71 | # remove https related params 72 | model_inputs.pop("headers", None) 73 | 
model_inputs.pop("request_timeout", None) 74 | model_inputs.pop("aiosession", None) 75 | return json.dumps( 76 | { 77 | "text": self.text, 78 | "meta": self.meta, 79 | "model_inputs": model_inputs, 80 | "provider": str(self.provider), 81 | "function_call": self.function_call 82 | } 83 | ) 84 | 85 | 86 | class Results: 87 | def __init__(self, results: List[Result]): 88 | self._results = results 89 | 90 | @property 91 | def text(self): 92 | return [result.text for result in self._results] 93 | 94 | @property 95 | def meta(self): 96 | return [result.meta for result in self._results] 97 | 98 | def to_json(self): 99 | return json.dumps([result.to_json() for result in self._results]) 100 | 101 | 102 | class StreamResult: 103 | def __init__( 104 | self, 105 | stream: Generator, 106 | model_inputs: Dict, 107 | provider: BaseProvider, 108 | meta: Optional[Dict] = None, 109 | ): 110 | self._stream = stream 111 | self._meta = meta or {} 112 | self.provider = provider 113 | self.model_inputs = model_inputs 114 | 115 | self._streamed_text = [] 116 | 117 | def __iter__(self): 118 | warn( 119 | "Looping through result will be deprecated, please loop through result.stream instead.", 120 | DeprecationWarning, 121 | stacklevel=2, 122 | ) 123 | yield from self.stream 124 | 125 | @property 126 | def stream(self): 127 | if not inspect.getgeneratorstate(self._stream) == "GEN_CLOSED": 128 | for item in self._stream: 129 | self._streamed_text.append(item) 130 | yield item 131 | else: 132 | yield from iter(self._streamed_text) 133 | 134 | @property 135 | def text(self) -> str: 136 | _ = all(self.stream) 137 | return "".join(self._streamed_text) 138 | 139 | @property 140 | def tokens_completion(self) -> int: 141 | if tokens_completion := self._meta.get("tokens_completion"): 142 | return tokens_completion 143 | else: 144 | tokens_completion = self.provider.count_tokens(self.text) 145 | self._meta["tokens_completion"] = tokens_completion 146 | return tokens_completion 147 | 148 | @property 149 | def tokens_prompt(self) -> int: 150 | if tokens_prompt := self._meta.get("tokens_prompt"): 151 | return tokens_prompt 152 | else: 153 | prompt = self.model_inputs.get("prompt") or self.model_inputs.get("messages") 154 | tokens_prompt = self.provider.count_tokens(prompt) 155 | self._meta["tokens_prompt"] = tokens_prompt 156 | return tokens_prompt 157 | 158 | @property 159 | def tokens(self) -> int: 160 | return self.tokens_completion + self.tokens_prompt 161 | 162 | @property 163 | def cost(self) -> float: 164 | if cost := self._meta.get("cost"): 165 | return cost 166 | else: 167 | cost = self.provider.compute_cost( 168 | prompt_tokens=self.tokens_prompt, completion_tokens=self.tokens_completion 169 | ) 170 | self._meta["cost"] = cost 171 | return cost 172 | 173 | @property 174 | def meta(self) -> Dict: 175 | return { 176 | "model": self.provider.model, 177 | "tokens": self.tokens, 178 | "tokens_prompt": self.tokens_prompt, 179 | "tokens_completion": self.tokens_completion, 180 | "cost": self.cost, 181 | } 182 | 183 | def to_json(self): 184 | model_inputs = self.model_inputs 185 | # remove https related params 186 | model_inputs.pop("headers", None) 187 | model_inputs.pop("request_timeout", None) 188 | return json.dumps( 189 | { 190 | "text": self.text, 191 | "meta": self.meta, 192 | "model_inputs": model_inputs, 193 | "provider": str(self.provider), 194 | } 195 | ) 196 | 197 | 198 | class AsyncIteratorWrapper: 199 | def __init__(self, obj): 200 | self._it = iter(obj) 201 | 202 | def __aiter__(self): 203 | return self 204 | 
205 | async def __anext__(self): 206 | try: 207 | value = next(self._it) 208 | except StopIteration: 209 | raise StopAsyncIteration 210 | return value 211 | 212 | 213 | class AsyncStreamResult: 214 | def __init__( 215 | self, 216 | stream: AsyncGenerator, 217 | model_inputs: Dict, 218 | provider: BaseProvider, 219 | meta: Optional[Dict] = None, 220 | ): 221 | self._stream = stream 222 | self._meta = meta or {} 223 | self.provider = provider 224 | self.model_inputs = model_inputs 225 | 226 | self._stream_exhausted = False 227 | self._streamed_text = [] 228 | 229 | def __aiter__(self): 230 | warn( 231 | "Looping through result will be deprecated, please loop through result.stream instead.", 232 | DeprecationWarning, 233 | stacklevel=2, 234 | ) 235 | return self 236 | 237 | async def __anext__(self): 238 | return await self._stream.__anext__() 239 | 240 | @property 241 | async def stream(self): 242 | if not self._stream_exhausted: 243 | async for item in self._stream: 244 | self._streamed_text.append(item) 245 | yield item 246 | self._stream_exhausted = True 247 | else: 248 | async for item in AsyncIteratorWrapper(self._streamed_text): 249 | yield item 250 | 251 | @property 252 | def text(self): 253 | if not self._stream_exhausted: 254 | raise RuntimeError("Please finish streaming the result.") 255 | return "".join(self._streamed_text) 256 | 257 | @property 258 | def tokens_completion(self) -> int: 259 | if tokens_completion := self._meta.get("tokens_completion"): 260 | return tokens_completion 261 | else: 262 | tokens_completion = self.provider.count_tokens(self.text) 263 | self._meta["tokens_completion"] = tokens_completion 264 | return tokens_completion 265 | 266 | @property 267 | def tokens_prompt(self) -> int: 268 | if tokens_prompt := self._meta.get("tokens_prompt"): 269 | return tokens_prompt 270 | else: 271 | prompt = self.model_inputs.get("prompt") or self.model_inputs.get("messages") 272 | tokens_prompt = self.provider.count_tokens(prompt) 273 | self._meta["tokens_prompt"] = tokens_prompt 274 | return tokens_prompt 275 | 276 | @property 277 | def tokens(self) -> int: 278 | return self.tokens_completion + self.tokens_prompt 279 | 280 | @property 281 | def cost(self) -> float: 282 | if cost := self._meta.get("cost"): 283 | return cost 284 | else: 285 | cost = self.provider.compute_cost( 286 | prompt_tokens=self.tokens_prompt, completion_tokens=self.tokens_completion 287 | ) 288 | self._meta["cost"] = cost 289 | return cost 290 | 291 | @property 292 | def meta(self) -> Dict: 293 | return { 294 | "model": self.provider.model, 295 | "tokens": self.tokens, 296 | "tokens_prompt": self.tokens_prompt, 297 | "tokens_completion": self.tokens_completion, 298 | "cost": self.cost, 299 | } 300 | 301 | def to_json(self): 302 | model_inputs = self.model_inputs 303 | # remove https related params 304 | model_inputs.pop("headers", None) 305 | model_inputs.pop("request_timeout", None) 306 | model_inputs.pop("aiosession", None) 307 | return json.dumps( 308 | { 309 | "text": self.text, 310 | "meta": self.meta, 311 | "model_inputs": model_inputs, 312 | "provider": str(self.provider), 313 | } 314 | ) 315 | -------------------------------------------------------------------------------- /llms/providers/anthropic.py: -------------------------------------------------------------------------------- 1 | # llms/providers/anthropic.py 2 | 3 | from typing import AsyncGenerator, Dict, Generator, List, Optional, Union 4 | 5 | import anthropic 6 | 7 | from ..results.result import AsyncStreamResult, Result, 
StreamResult 8 | from .base_provider import BaseProvider 9 | 10 | 11 | class AnthropicProvider(BaseProvider): 12 | MODEL_INFO = { 13 | # Legacy model 14 | "claude-2.1": {"prompt": 8.00, "completion": 24.00, "token_limit": 200_000, "output_limit": 4_096}, 15 | 16 | # Claude 3 family 17 | "claude-3-5-sonnet-20240620": {"prompt": 3.00, "completion": 15, "token_limit": 200_000, "output_limit": 4_096}, 18 | "claude-3-5-sonnet-20241022": {"prompt": 3.00, "completion": 15, "token_limit": 200_000, "output_limit": 4_096}, 19 | "claude-3-5-haiku-20241022": {"prompt": 0.80, "completion": 4, "token_limit": 200_000, "output_limit": 4_096}, 20 | 21 | # Claude 3.7 family 22 | "claude-3-7-sonnet-20250219": {"prompt": 3.00, "completion": 15, "token_limit": 200_000, "output_limit": 4_096}, 23 | 24 | # Claude 4 family (latest) 25 | "claude-sonnet-4-20250514": {"prompt": 3.00, "completion": 15, "token_limit": 200_000, "output_limit": 4_096}, 26 | "claude-3-5-haiku-20241022": {"prompt": 0.80, "completion": 4, "token_limit": 200_000, "output_limit": 4_096}, 27 | "claude-opus-4-1-20250805": { 28 | "prompt": 15.00, 29 | "completion": 75.00, 30 | "token_limit": 200_000, 31 | "output_limit": 4_096, 32 | }, 33 | 34 | "claude-opus-4-20250514": {"prompt": 15.00, "completion": 75, "token_limit": 200_000, "output_limit": 4_096}, 35 | } 36 | 37 | def __init__( 38 | self, 39 | api_key: Union[str, None] = None, 40 | model: Union[str, None] = None, 41 | client_kwargs: Union[dict, None] = None, 42 | async_client_kwargs: Union[dict, None] = None, 43 | ): 44 | if model is None: 45 | model = list(self.MODEL_INFO.keys())[0] 46 | self.model = model 47 | 48 | if client_kwargs is None: 49 | client_kwargs = {} 50 | self.client = anthropic.Anthropic(api_key=api_key, **client_kwargs) 51 | if async_client_kwargs is None: 52 | async_client_kwargs = {} 53 | self.async_client = anthropic.AsyncAnthropic(api_key=api_key, **async_client_kwargs) 54 | 55 | def count_tokens(self, content: str | Dict) -> int: 56 | """Count tokens using Anthropic's native token counting API.""" 57 | 58 | if isinstance(content, str): 59 | # For string content, format as a single user message 60 | messages = [{"role": "user", "content": content}] 61 | elif isinstance(content, list): 62 | # If it's already a list of messages, use directly 63 | messages = content 64 | elif isinstance(content, dict): 65 | # If it's a single message dict, wrap in list 66 | messages = [content] 67 | else: 68 | raise ValueError(f"Unsupported content type: {type(content)}") 69 | 70 | try: 71 | response = self.client.messages.count_tokens( 72 | model=self.model, 73 | messages=messages 74 | ) 75 | return response.input_tokens 76 | except Exception as e: 77 | # Fallback to tiktoken approximation if API fails 78 | import tiktoken 79 | enc = tiktoken.encoding_for_model("gpt-3.5-turbo") 80 | 81 | if isinstance(content, str): 82 | return len(enc.encode(content, disallowed_special=())) 83 | 84 | # Handle message format 85 | formatting_token_count = 4 86 | total = 0 87 | for message in messages: 88 | if isinstance(message.get("content"), str): 89 | total += len(enc.encode(message["content"], disallowed_special=())) + formatting_token_count 90 | return total 91 | 92 | 93 | 94 | def _prepare_message_inputs( 95 | self, 96 | prompt: str, 97 | history: Optional[List[dict]] = None, 98 | temperature: float = 0, 99 | max_tokens: int = 300, 100 | stop_sequences: Optional[List[str]] = None, 101 | ai_prompt: str = "", 102 | system_message: Union[str, None] = None, 103 | **kwargs, 104 | ) -> Dict: 105 | 
history = history or [] 106 | system_message = system_message or "" 107 | max_tokens = kwargs.pop("max_tokens_to_sample", max_tokens) 108 | messages = [*history, {"role": "user", "content": prompt}] 109 | if ai_prompt: 110 | messages.append({"role": "assistant", "content": ai_prompt}) 111 | 112 | if system_message and self.model.startswith("claude-instant"): 113 | raise ValueError("System message is not supported for Claude instant") 114 | model_inputs = { 115 | "messages": messages, 116 | "system": system_message, 117 | "temperature": temperature, 118 | "max_tokens": max_tokens, 119 | "stop_sequences": stop_sequences, 120 | } 121 | 122 | # Add thinking mode if specified 123 | thinking = kwargs.pop('thinking', None) 124 | if thinking is not None: 125 | model_inputs["thinking"] = { 126 | "type": "enabled", 127 | "budget_tokens": thinking if isinstance(thinking, int) else 32000 128 | } 129 | return model_inputs 130 | 131 | def _prepare_model_inputs( 132 | self, 133 | prompt: str, 134 | history: Optional[List[dict]] = None, 135 | temperature: float = 0, 136 | max_tokens: int = 300, 137 | stop_sequences: Optional[List[str]] = None, 138 | ai_prompt: str = "", 139 | system_message: Union[str, None] = None, 140 | stream: bool = False, 141 | **kwargs, 142 | ) -> Dict: 143 | return self._prepare_message_inputs( 144 | prompt=prompt, 145 | history=history, 146 | temperature=temperature, 147 | max_tokens=max_tokens, 148 | stop_sequences=stop_sequences, 149 | ai_prompt=ai_prompt, 150 | system_message=system_message, 151 | stream=stream, 152 | **kwargs, 153 | ) 154 | 155 | def complete( 156 | self, 157 | prompt: str, 158 | history: Optional[List[dict]] = None, 159 | temperature: float = 0, 160 | max_tokens: int = 300, 161 | stop_sequences: Optional[List[str]] = None, 162 | ai_prompt: str = "", 163 | system_message: Union[str, None] = None, 164 | **kwargs, 165 | ) -> Result: 166 | """ 167 | Args: 168 | history: messages in OpenAI format, 169 | each dict must include role and content key. 170 | ai_prompt: prefix of AI response, for finer control on the output. 171 | """ 172 | 173 | model_inputs = self._prepare_model_inputs( 174 | prompt=prompt, 175 | history=history, 176 | temperature=temperature, 177 | max_tokens=max_tokens, 178 | stop_sequences=stop_sequences, 179 | ai_prompt=ai_prompt, 180 | system_message=system_message, 181 | **kwargs, 182 | ) 183 | 184 | meta = {} 185 | with self.track_latency(): 186 | response = self.client.messages.create(model=self.model, **model_inputs) 187 | if "thinking" in model_inputs: 188 | text_block = next((b for b in response.content if b.type == "text"), None) 189 | completion = text_block.text if text_block else "" 190 | else: 191 | completion = response.content[0].text 192 | meta["tokens_prompt"] = response.usage.input_tokens 193 | meta["tokens_completion"] = response.usage.output_tokens 194 | 195 | meta["latency"] = self.latency 196 | return Result( 197 | text=completion, 198 | model_inputs=model_inputs, 199 | provider=self, 200 | meta=meta, 201 | ) 202 | 203 | async def acomplete( 204 | self, 205 | prompt: str, 206 | history: Optional[List[dict]] = None, 207 | temperature: float = 0, 208 | max_tokens: int = 300, 209 | stop_sequences: Optional[List[str]] = None, 210 | ai_prompt: str = "", 211 | system_message: Union[str, None] = None, 212 | **kwargs, 213 | ): 214 | """ 215 | Args: 216 | history: messages in OpenAI format, 217 | each dict must include role and content key. 218 | ai_prompt: prefix of AI response, for finer control on the output. 
219 | """ 220 | 221 | model_inputs = self._prepare_model_inputs( 222 | prompt=prompt, 223 | history=history, 224 | temperature=temperature, 225 | max_tokens=max_tokens, 226 | stop_sequences=stop_sequences, 227 | ai_prompt=ai_prompt, 228 | system_message=system_message, 229 | **kwargs, 230 | ) 231 | 232 | with self.track_latency(): 233 | response = await self.async_client.messages.create(model=self.model, **model_inputs) 234 | if "thinking" in model_inputs: 235 | text_block = next((b for b in response.content if b.type == "text"), None) 236 | completion = text_block.text if text_block else "" 237 | else: 238 | completion = response.content[0].text 239 | 240 | return Result( 241 | text=completion, 242 | model_inputs=model_inputs, 243 | provider=self, 244 | meta={"latency": self.latency}, 245 | ) 246 | 247 | def complete_stream( 248 | self, 249 | prompt: str, 250 | history: Optional[List[dict]] = None, 251 | temperature: float = 0, 252 | max_tokens: int = 300, 253 | stop_sequences: Optional[List[str]] = None, 254 | ai_prompt: str = "", 255 | system_message: Union[str, None] = None, 256 | **kwargs, 257 | ) -> StreamResult: 258 | """ 259 | Args: 260 | history: messages in OpenAI format, 261 | each dict must include role and content key. 262 | ai_prompt: prefix of AI response, for finer control on the output. 263 | """ 264 | 265 | model_inputs = self._prepare_model_inputs( 266 | prompt=prompt, 267 | history=history, 268 | temperature=temperature, 269 | max_tokens=max_tokens, 270 | stop_sequences=stop_sequences, 271 | ai_prompt=ai_prompt, 272 | system_message=system_message, 273 | stream=True, 274 | **kwargs, 275 | ) 276 | with self.track_latency(): 277 | response = self.client.messages.stream(model=self.model, **model_inputs) 278 | stream = self._process_message_stream(response) 279 | 280 | return StreamResult(stream=stream, model_inputs=model_inputs, provider=self) 281 | 282 | def _process_message_stream(self, response) -> Generator: 283 | with response as stream_manager: 284 | for text in stream_manager.text_stream: 285 | yield text 286 | 287 | def _process_stream(self, response: Generator) -> Generator: 288 | first_completion = next(response).completion 289 | yield first_completion.lstrip() 290 | 291 | for data in response: 292 | yield data.completion 293 | -------------------------------------------------------------------------------- /llms/providers/google_genai.py: -------------------------------------------------------------------------------- 1 | # https://googleapis.github.io/python-genai/ 2 | 3 | import os, math 4 | from typing import Dict, Generator, AsyncGenerator 5 | 6 | from google import genai 7 | from google.genai import types 8 | 9 | from ..results.result import Result, StreamResult, AsyncStreamResult 10 | from .base_provider import BaseProvider 11 | 12 | 13 | class GoogleGenAIProvider(BaseProvider): 14 | # cost is per million tokens 15 | MODEL_INFO = { 16 | # Gemini 2.5 family - Enhanced thinking and reasoning 17 | "gemini-2.5-pro": {"prompt": 5.0, "completion": 15.0, "token_limit": 2000000, "uses_characters": True}, 18 | "gemini-2.5-flash": {"prompt": 0.1, "completion": 0.4, "token_limit": 2000000, "uses_characters": True}, 19 | "gemini-2.5-flash-lite-preview-06-17": {"prompt": 0.05, "completion": 0.2, "token_limit": 2000000, "uses_characters": True}, 20 | 21 | # Gemini 2.0 family - Next generation features and speed 22 | "gemini-2.0-flash": {"prompt": 0.075, "completion": 0.3, "token_limit": 2000000, "uses_characters": True}, 23 | "gemini-2.0-flash-lite": {"prompt": 0.0375, 
"completion": 0.15, "token_limit": 1000000, "uses_characters": True}, 24 | 25 | # Gemini 1.5 family - Stable and reliable models 26 | "gemini-1.5-pro": {"prompt": 3.5, "completion": 10.5, "token_limit": 2000000, "uses_characters": True}, 27 | "gemini-1.5-flash": {"prompt": 0.075, "completion": 0.3, "token_limit": 1000000, "uses_characters": True}, 28 | "gemini-1.5-flash-8b": {"prompt": 0.0375, "completion": 0.15, "token_limit": 1000000, "uses_characters": True}, 29 | } 30 | 31 | def __init__(self, api_key=None, model=None, use_vertexai=False, project=None, location="us-central1", **kwargs): 32 | """ 33 | Initialize Google GenAI Provider with support for both Gemini API and Vertex AI. 34 | 35 | Args: 36 | api_key: API key for Gemini API (not needed for Vertex AI) 37 | model: Model name to use 38 | use_vertexai: Whether to use Vertex AI instead of Gemini API 39 | project: Google Cloud project ID (required for Vertex AI) 40 | location: Google Cloud location (default: us-central1) 41 | **kwargs: Additional arguments 42 | """ 43 | if model is None: 44 | model = list(self.MODEL_INFO.keys())[0] 45 | 46 | self.model = model 47 | self.use_vertexai = use_vertexai 48 | self.project = project 49 | self.location = location 50 | 51 | # Initialize the appropriate client 52 | if use_vertexai: 53 | # For Vertex AI, try to get project from parameter, environment, or let SDK auto-detect 54 | if not project: 55 | project = os.getenv("GOOGLE_CLOUD_PROJECT") 56 | # If still no project, let the SDK try to auto-detect from gcloud config 57 | # The SDK should be able to detect it from Application Default Credentials 58 | 59 | if project: 60 | self.client = genai.Client( 61 | vertexai=True, 62 | project=project, 63 | location=location 64 | ) 65 | else: 66 | # Try without explicit project - let SDK auto-detect from gcloud config 67 | try: 68 | self.client = genai.Client( 69 | vertexai=True, 70 | location=location 71 | ) 72 | # If successful, try to get the project from gcloud config for display 73 | try: 74 | import subprocess 75 | result = subprocess.run(['gcloud', 'config', 'get-value', 'project'], 76 | capture_output=True, text=True, timeout=5) 77 | if result.returncode == 0: 78 | project = result.stdout.strip() 79 | except: 80 | project = "auto-detected" 81 | except Exception as e: 82 | raise ValueError( 83 | f"Could not initialize Vertex AI client. Please either:\n" 84 | f"1. Set GOOGLE_CLOUD_PROJECT environment variable, or\n" 85 | f"2. Pass project parameter, or\n" 86 | f"3. Configure gcloud with: gcloud config set project YOUR_PROJECT_ID\n" 87 | f"Error: {e}" 88 | ) 89 | else: 90 | if api_key is None: 91 | api_key = os.getenv("GOOGLE_API_KEY") 92 | if not api_key: 93 | raise ValueError("api_key parameter is required for Gemini API. Set GOOGLE_API_KEY environment variable or pass api_key parameter.") 94 | 95 | self.client = genai.Client(api_key=api_key) 96 | 97 | def count_tokens(self, content): 98 | """ 99 | Count tokens in the given content. 100 | For Google GenAI, we'll use a simple approximation since the exact 101 | token counting API might not be readily available in streaming context. 
102 | """ 103 | if isinstance(content, str): 104 | # Simple approximation: ~4 characters per token for most languages 105 | return max(1, len(content) // 4) 106 | elif isinstance(content, (list, dict)): 107 | # Convert to string and count 108 | import json 109 | content_str = json.dumps(content) if isinstance(content, dict) else str(content) 110 | return max(1, len(content_str) // 4) 111 | else: 112 | return max(1, len(str(content)) // 4) 113 | 114 | def _prepare_model_inputs( 115 | self, 116 | prompt: str, 117 | temperature: float = 0.01, 118 | max_tokens: int = 300, 119 | stream: bool = False, 120 | **kwargs, 121 | ) -> Dict: 122 | temperature = max(temperature, 0.01) 123 | max_output_tokens = kwargs.pop("max_output_tokens", max_tokens) 124 | 125 | # Create config using the modern API 126 | config = types.GenerateContentConfig( 127 | temperature=temperature, 128 | max_output_tokens=max_output_tokens, 129 | **kwargs, 130 | ) 131 | return {"config": config, "contents": prompt} 132 | 133 | def complete( 134 | self, 135 | prompt: str, 136 | temperature: float = 0.01, 137 | max_tokens: int = 300, 138 | context: str = None, 139 | examples: dict = {}, 140 | **kwargs, 141 | ) -> Result: 142 | model_inputs = self._prepare_model_inputs( 143 | prompt=prompt, 144 | temperature=temperature, 145 | max_tokens=max_tokens, 146 | **kwargs, 147 | ) 148 | 149 | with self.track_latency(): 150 | response = self.client.models.generate_content( 151 | model=self.model, 152 | contents=model_inputs["contents"], 153 | config=model_inputs["config"], 154 | ) 155 | 156 | completion = response.text or "" 157 | 158 | # Calculate tokens and cost 159 | prompt_tokens = len(prompt) 160 | completion_tokens = len(completion) 161 | 162 | cost_per_token = self.MODEL_INFO[self.model] 163 | cost = ( 164 | (prompt_tokens * cost_per_token["prompt"]) 165 | + (completion_tokens * cost_per_token["completion"]) 166 | ) / 1_000_000 167 | 168 | # fast approximation. We could call count_message_tokens() but this will add latency 169 | prompt_tokens = math.ceil((prompt_tokens+1) / 4) 170 | completion_tokens = math.ceil((completion_tokens+1) / 4) 171 | total_tokens = math.ceil(prompt_tokens + completion_tokens) 172 | 173 | meta = { 174 | "model": self.model, 175 | "tokens": total_tokens, 176 | "tokens_prompt": prompt_tokens, 177 | "tokens_completion": completion_tokens, 178 | "cost": cost, 179 | "latency": self.latency, 180 | } 181 | return Result( 182 | text=completion, 183 | model_inputs=model_inputs, 184 | provider=self, 185 | meta=meta, 186 | ) 187 | 188 | def complete_stream( 189 | self, 190 | prompt: str, 191 | temperature: float = 0.01, 192 | max_tokens: int = 300, 193 | context: str = None, 194 | examples: dict = {}, 195 | **kwargs, 196 | ) -> StreamResult: 197 | """ 198 | Stream completion for Google GenAI provider. 
199 | 200 | Args: 201 | prompt: The text prompt to complete 202 | temperature: Controls randomness (min 0.01 for Google) 203 | max_tokens: Maximum tokens to generate 204 | context: Additional context (unused in this implementation) 205 | examples: Examples dict (unused in this implementation) 206 | **kwargs: Additional parameters passed to the model 207 | """ 208 | model_inputs = self._prepare_model_inputs( 209 | prompt=prompt, 210 | temperature=temperature, 211 | max_tokens=max_tokens, 212 | stream=True, 213 | **kwargs, 214 | ) 215 | 216 | with self.track_latency(): 217 | response = self.client.models.generate_content_stream( 218 | model=self.model, 219 | contents=model_inputs["contents"], 220 | config=model_inputs["config"], 221 | ) 222 | stream = self._process_stream(response) 223 | 224 | return StreamResult(stream=stream, model_inputs=model_inputs, provider=self) 225 | 226 | def _process_stream(self, response) -> Generator: 227 | """ 228 | Process the streaming response from Google GenAI. 229 | 230 | Args: 231 | response: The streaming response from Google's generate_content_stream 232 | 233 | Yields: 234 | str: Individual text chunks from the stream 235 | """ 236 | for chunk in response: 237 | if chunk.text: 238 | yield chunk.text 239 | 240 | async def acomplete_stream( 241 | self, 242 | prompt: str, 243 | temperature: float = 0.01, 244 | max_tokens: int = 300, 245 | context: str = None, 246 | examples: dict = {}, 247 | **kwargs, 248 | ) -> AsyncStreamResult: 249 | """ 250 | Async stream completion for Google GenAI provider. 251 | 252 | Args: 253 | prompt: The text prompt to complete 254 | temperature: Controls randomness (min 0.01 for Google) 255 | max_tokens: Maximum tokens to generate 256 | context: Additional context (unused in this implementation) 257 | examples: Examples dict (unused in this implementation) 258 | **kwargs: Additional parameters passed to the model 259 | """ 260 | model_inputs = self._prepare_model_inputs( 261 | prompt=prompt, 262 | temperature=temperature, 263 | max_tokens=max_tokens, 264 | stream=True, 265 | **kwargs, 266 | ) 267 | 268 | with self.track_latency(): 269 | response = await self.client.aio.models.generate_content_stream( 270 | model=self.model, 271 | contents=model_inputs["contents"], 272 | config=model_inputs["config"], 273 | ) 274 | stream = self._aprocess_stream(response) 275 | 276 | return AsyncStreamResult(stream=stream, model_inputs=model_inputs, provider=self) 277 | 278 | async def _aprocess_stream(self, response) -> AsyncGenerator: 279 | """ 280 | Process the async streaming response from Google GenAI. 281 | 282 | Args: 283 | response: The async streaming response from Google's generate_content_stream 284 | 285 | Yields: 286 | str: Individual text chunks from the stream 287 | """ 288 | async for chunk in response: 289 | if chunk.text: 290 | yield chunk.text 291 | 292 | 293 | class GoogleVertexAIProvider(GoogleGenAIProvider): 294 | """ 295 | Dedicated Google Vertex AI provider that always uses Vertex AI. 296 | This is a convenience class for users who prefer explicit separation. 297 | """ 298 | 299 | def __init__(self, project=None, location="us-central1", model=None, **kwargs): 300 | """ 301 | Initialize Google Vertex AI Provider. 
302 | 303 | Args: 304 | project: Google Cloud project ID (auto-detected from gcloud if not provided) 305 | location: Google Cloud location (default: us-central1) 306 | model: Model name to use 307 | **kwargs: Additional arguments 308 | """ 309 | # Always use Vertex AI, ignore any api_key parameter 310 | super().__init__( 311 | api_key=None, 312 | model=model, 313 | use_vertexai=True, 314 | project=project, 315 | location=location, 316 | **kwargs 317 | ) 318 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyLLMs 2 | 3 | [![PyPI version](https://badge.fury.io/py/pyllms.svg)](https://badge.fury.io/py/pyllms) 4 | [![License: MIT](https://img.shields.io/badge/License-MIT-green.svg)](https://opensource.org/license/mit/) 5 | [![](https://dcbadge.vercel.app/api/server/aDNg6E9szy?compact=true&style=flat)](https://discord.gg/aDNg6E9szy) 6 | [![Twitter](https://img.shields.io/twitter/follow/KagiHQ?style=social)](https://twitter.com/KagiHQ) 7 | 8 | PyLLMs is a minimal Python library to connect to various Language Models (LLMs) with a built-in model performance benchmark. 9 | 10 | ## Table of Contents 11 | 12 | - [Features](#features) 13 | - [Installation](#installation) 14 | - [Quick Start](#quick-start) 15 | - [Usage](#usage) 16 | - [Basic Usage](#basic-usage) 17 | - [Multi-model Usage](#multi-model-usage) 18 | - [Async Support](#async-support) 19 | - [Streaming Support](#streaming-support) 20 | - [Chat History and System Message](#chat-history-and-system-message) 21 | - [Other Methods](#other-methods) 22 | - [Configuration](#configuration) 23 | - [Model Benchmarks](#model-benchmarks) 24 | - [Supported Models](#supported-models) 25 | - [Advanced Usage](#advanced-usage) 26 | - [Using OpenAI API on Azure](#using-openai-api-on-azure) 27 | - [Using Google Vertex LLM models](#using-google-vertex-llm-models) 28 | - [Using Local Ollama LLM models](#using-local-ollama-llm-models) 29 | - [Contributing](#contributing) 30 | - [License](#license) 31 | 32 | ## Features 33 | 34 | - Connect to top LLMs in a few lines of code 35 | - Response meta includes tokens processed, cost, and latency standardized across models 36 | - Multi-model support: Get completions from different models simultaneously 37 | - LLM benchmark: Evaluate models on quality, speed, and cost 38 | - Async and streaming support for compatible models 39 | 40 | ## Installation 41 | 42 | Install the package using pip: 43 | 44 | ```bash 45 | pip install pyllms 46 | ``` 47 | 48 | ## Quick Start 49 | 50 | ```python 51 | import llms 52 | 53 | model = llms.init('gpt-4o') 54 | result = model.complete("What is 5+5?") 55 | 56 | print(result.text) 57 | ``` 58 | 59 | ## Usage 60 | 61 | ### Basic Usage 62 | 63 | ```python 64 | import llms 65 | 66 | model = llms.init('gpt-4o') 67 | result = model.complete( 68 | "What is the capital of the country where Mozart was born?", 69 | temperature=0.1, 70 | max_tokens=200 71 | ) 72 | 73 | print(result.text) 74 | print(result.meta) 75 | ``` 76 | 77 | ### Multi-model Usage 78 | 79 | ```python 80 | models = llms.init(model=['gpt-3.5-turbo', 'claude-instant-v1']) 81 | result = models.complete('What is the capital of the country where Mozart was born?') 82 | 83 | print(result.text) 84 | print(result.meta) 85 | ``` 86 | 87 | ### Async Support 88 | 89 | ```python 90 | result = await model.acomplete("What is the capital of the country where Mozart was born?") 91 | ``` 92 | 93 | ### Streaming Support 
94 | 95 | ```python 96 | model = llms.init('claude-v1') 97 | result = model.complete_stream("Write an essay on the Civil War") 98 | for chunk in result.stream: 99 | if chunk is not None: 100 | print(chunk, end='') 101 | ``` 102 | 103 | ### Chat History and System Message 104 | 105 | ```python 106 | history = [] 107 | history.append({"role": "user", "content": user_input}) 108 | history.append({"role": "assistant", "content": result.text}) 109 | 110 | model.complete(prompt=prompt, history=history) 111 | 112 | # For OpenAI chat models 113 | model.complete(prompt=prompt, system_message=system, history=history) 114 | ``` 115 | 116 | ### Other Methods 117 | 118 | ```python 119 | count = model.count_tokens('The quick brown fox jumped over the lazy dog') 120 | ``` 121 | 122 | ## Configuration 123 | 124 | PyLLMs will attempt to read API keys and the default model from environment variables. You can set them like this: 125 | 126 | ```bash 127 | export OPENAI_API_KEY="your_api_key_here" 128 | export ANTHROPIC_API_KEY="your_api_key_here" 129 | export AI21_API_KEY="your_api_key_here" 130 | export COHERE_API_KEY="your_api_key_here" 131 | export ALEPHALPHA_API_KEY="your_api_key_here" 132 | export HUGGINFACEHUB_API_KEY="your_api_key_here" 133 | export GOOGLE_API_KEY="your_api_key_here" 134 | export MISTRAL_API_KEY="your_api_key_here" 135 | export REKA_API_KEY="your_api_key_here" 136 | export TOGETHER_API_KEY="your_api_key_here" 137 | export GROQ_API_KEY="your_api_key_here" 138 | export DEEPSEEK_API_KEY="your_api_key_here" 139 | 140 | export LLMS_DEFAULT_MODEL="gpt-3.5-turbo" 141 | ``` 142 | 143 | Alternatively, you can pass initialization values to the `init()` method: 144 | 145 | ```python 146 | model = llms.init(openai_api_key='your_api_key_here', model='gpt-4') 147 | ``` 148 | 149 | ## Model Benchmarks 150 | 151 | PyLLMs includes an automated benchmark system. The quality of models is evaluated using a powerful model (e.g., GPT-4) on a range of predefined questions, or you can supply your own. 152 | 153 | ```python 154 | models = llms.init(model=['claude-3-haiku-20240307', 'gpt-4o-mini', 'claude-3-5-sonnet-20240620', 'gpt-4o', 'mistral-large-latest', 'open-mistral-nemo', 'gpt-4', 'gpt-3.5-turbo', 'deepseek-coder', 'deepseek-chat', 'llama-3.1-8b-instant', 'llama-3.1-70b-versatile']) 155 | 156 | gpt4 = llms.init('gpt-4o') 157 | 158 | models.benchmark(evaluator=gpt4) 159 | ``` 160 | 161 | Check [Kagi LLM Benchmarking Project](https://help.kagi.com/kagi/ai/llm-benchmark.html) for the latest benchmarks!
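Because the benchmark exercises several providers in a single run, every corresponding API key has to be present in the environment before `llms.init()` is called. Below is a minimal sketch of one way to set that up, assuming the keys live in a local `.env` file and that the optional `python-dotenv` package is installed; the model list is purely illustrative and can be any subset of the supported models.

```python
import llms
from dotenv import load_dotenv  # optional dependency, assumed installed

# Load OPENAI_API_KEY, ANTHROPIC_API_KEY, etc. from ./.env into the environment
load_dotenv()

models = llms.init(model=['gpt-4o-mini', 'claude-3-5-haiku-20241022'])
evaluator = llms.init('gpt-4o')
models.benchmark(evaluator=evaluator)
```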
162 | 163 | To evaluate models on your own prompts: 164 | 165 | ```python 166 | models.benchmark(prompts=[("What is the capital of Finland?", "Helsinki")], evaluator=gpt4) 167 | ``` 168 | 169 | ## Supported Models 170 | 171 | To get a full list of supported models: 172 | 173 | ```python 174 | model = llms.init() 175 | model.list() # list all models 176 | 177 | model.list("gpt") # lists only models with 'gpt' in name/provider name 178 | ``` 179 | 180 | Currently supported models (may be outdated): 181 | 182 | | **Provider** | **Models** | 183 | | ------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | 184 | | OpenAIProvider | gpt-3.5-turbo, gpt-3.5-turbo-1106, gpt-3.5-turbo-instruct, gpt-4, gpt-4-1106-preview, gpt-4-turbo-preview, gpt-4-turbo, gpt-4o, gpt-4o-mini, gpt-4o-2024-08-06, gpt-4.1, gpt-4.1-mini, gpt-4.1-nano, gpt-4.5-preview, chatgpt-4o-latest, o1-preview, o1-mini, o1, o1-pro, o3-mini, o3, o3-pro, o4-mini | 185 | | AnthropicProvider | claude-2.1, claude-3-5-sonnet-20240620, claude-3-5-sonnet-20241022, claude-3-5-haiku-20241022, claude-3-7-sonnet-20250219, claude-sonnet-4-20250514, claude-opus-4-20250514 | 186 | | BedrockAnthropicProvider | anthropic.claude-instant-v1, anthropic.claude-v1, anthropic.claude-v2, anthropic.claude-3-haiku-20240307-v1:0, anthropic.claude-3-sonnet-20240229-v1:0, anthropic.claude-3-5-sonnet-20240620-v1:0 | 187 | | AI21Provider | j2-grande-instruct, j2-jumbo-instruct | 188 | | CohereProvider | command, command-nightly | 189 | | AlephAlphaProvider | luminous-base, luminous-extended, luminous-supreme, luminous-supreme-control | 190 | | HuggingfaceHubProvider | hf_pythia, hf_falcon40b, hf_falcon7b, hf_mptinstruct, hf_mptchat, hf_llava, hf_dolly, hf_vicuna | 191 | | GoogleGenAIProvider | gemini-2.5-pro, gemini-2.5-flash, gemini-2.5-flash-lite-preview-06-17, gemini-2.0-flash, gemini-2.0-flash-lite, gemini-1.5-pro, gemini-1.5-flash, gemini-1.5-flash-8b | 192 | | GoogleVertexAIProvider | gemini-2.5-pro, gemini-2.5-flash, gemini-2.5-flash-lite-preview-06-17, gemini-2.0-flash, gemini-2.0-flash-lite, gemini-1.5-pro, gemini-1.5-flash, gemini-1.5-flash-8b | 193 | | OllamaProvider | vanilj/Phi-4:latest, falcon3:10b, smollm2:latest, llama3.2:3b-instruct-q8_0, qwen2:1.5b, mistral:7b-instruct-v0.2-q4_K_S, phi3:latest, phi3:3.8b, phi:latest, tinyllama:latest, magicoder:latest, deepseek-coder:6.7b, deepseek-coder:latest, dolphin-phi:latest, stablelm-zephyr:latest | 194 | | DeepSeekProvider | deepseek-chat, deepseek-coder | 195 | | GroqProvider | llama-3.1-405b-reasoning, llama-3.1-70b-versatile, llama-3.1-8b-instant, gemma2-9b-it | 196 | | RekaProvider | reka-edge, reka-flash, reka-core | 197 | | TogetherProvider | meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo | 198 | | OpenRouterProvider | nvidia/llama-3.1-nemotron-70b-instruct, x-ai/grok-2, nousresearch/hermes-3-llama-3.1-405b:free, google/gemini-flash-1.5-exp, liquid/lfm-40b, mistralai/ministral-8b, qwen/qwen-2.5-72b-instruct | 199 | | MistralProvider | mistral-tiny, open-mistral-7b, mistral-small, open-mixtral-8x7b, mistral-small-latest, mistral-medium-latest, mistral-large-latest, open-mistral-nemo | 200 | 201 | ## Advanced Usage 202 | 203 | ### Using OpenAI API on Azure 204 | 205 | ```python 206 | import llms 207 | AZURE_API_BASE = "{insert here}" 208 | 
AZURE_API_KEY = "{insert here}"
209 | 
210 | model = llms.init('gpt-4')
211 | 
212 | azure_args = {
213 |     "engine": "gpt-4",  # Azure deployment_id
214 |     "api_base": AZURE_API_BASE,
215 |     "api_type": "azure",
216 |     "api_version": "2023-05-15",
217 |     "api_key": AZURE_API_KEY,
218 | }
219 | 
220 | azure_result = model.complete("What is 5+5?", **azure_args)
221 | ```
222 | 
223 | ### Using Google AI Models
224 | 
225 | PyLLMs supports Google's AI models through two providers:
226 | 
227 | #### Option 1: Gemini API (GoogleGenAI)
228 | 
229 | Uses the Gemini API directly, with API key authentication:
230 | 
231 | ```python
232 | # Set your API key first, e.g. in your shell:
233 | #   export GOOGLE_API_KEY="your_api_key_here"
234 | 
235 | # Use any Gemini model
236 | model = llms.init('gemini-2.5-flash')
237 | result = model.complete("Hello!")
238 | ```
239 | 
240 | #### Option 2: Vertex AI (GoogleVertexAI)
241 | 
242 | Uses Google Cloud Vertex AI with Application Default Credentials:
243 | 
244 | 1. Set up a GCP account and create a project
245 | 2. Enable Vertex AI APIs in your GCP project
246 | 3. Install the gcloud CLI tool
247 | 4. Set up Application Default Credentials:
248 | ```bash
249 | gcloud auth application-default login
250 | gcloud config set project YOUR_PROJECT_ID
251 | ```
252 | 
253 | Then use models through Vertex AI:
254 | 
255 | ```python
256 | # Option A: Direct provider usage for Vertex AI
257 | from llms.providers.google_genai import GoogleVertexAIProvider
258 | provider = GoogleVertexAIProvider()
259 | result = provider.complete("Hello!")
260 | 
261 | # Option B: Unified provider with Vertex AI flag
262 | from llms.providers.google_genai import GoogleGenAIProvider
263 | provider = GoogleGenAIProvider(use_vertexai=True)
264 | result = provider.complete("Hello!")
265 | ```
266 | 
267 | **Note:** Both providers support the same model names. If both `GOOGLE_API_KEY` and gcloud credentials are configured, `llms.init('gemini-2.5-flash')` will use both providers simultaneously.
268 | 
269 | ### Using Local Ollama LLM models
270 | 
271 | 1. Ensure Ollama is running and you've pulled the desired model
272 | 2. Get the name of the LLM you want to use
273 | 3. Initialize PyLLMs:
274 | 
275 | ```python
276 | model = llms.init("tinyllama:latest")
277 | result = model.complete("Hello!")
278 | ```
279 | 
280 | ## Contributing
281 | 
282 | Contributions are welcome! Please feel free to submit a Pull Request.
283 | 
284 | ## License
285 | 
286 | This project is licensed under the MIT License. See the [LICENSE](LICENSE) file for details.
287 | 
--------------------------------------------------------------------------------
/tests/test_all_models.py:
--------------------------------------------------------------------------------
1 | """
2 | Comprehensive test suite for all PyLLMs models.
3 | 
4 | This test suite dynamically discovers all available providers and models
5 | using the existing LLMS infrastructure, checks for API keys in environment
6 | variables, and runs tests only for providers with valid API keys available.
7 | 
8 | No models or providers are hardcoded - everything is discovered dynamically.
9 | """
10 | 
11 | import os
12 | import pytest
13 | from typing import Dict, List, Tuple, Any
14 | 
15 | # Import the main LLMS class
16 | import llms
17 | from llms.llms import LLMS
18 | 
19 | 
20 | def get_available_providers() -> Dict[str, Any]:
21 |     """
22 |     Get all providers that have API keys available or don't need them.
23 |     Uses the existing LLMS._provider_map infrastructure.
24 | 25 | Returns: 26 | Dict mapping provider names to their Provider objects if API key is available 27 | """ 28 | available_providers = {} 29 | 30 | for provider_name, provider_config in LLMS._provider_map.items(): 31 | if not provider_config.needs_api_key: 32 | # Providers that don't need API keys (like Ollama, Google Vertex) 33 | available_providers[provider_name] = provider_config 34 | elif provider_config.custom_credential_check: 35 | # Providers with custom credential checking (like BedrockAnthropic) 36 | if provider_config.custom_credential_check(): 37 | available_providers[provider_name] = provider_config 38 | elif provider_config.api_key_name and os.getenv(provider_config.api_key_name): 39 | # Providers that need API keys and have them available 40 | available_providers[provider_name] = provider_config 41 | 42 | return available_providers 43 | 44 | 45 | def get_all_models() -> List[Tuple[str, str, Any]]: 46 | """ 47 | Dynamically discover all models from available providers. 48 | Uses the existing LLMS infrastructure. 49 | 50 | Returns: 51 | List of tuples: (provider_name, model_name, provider_class) 52 | """ 53 | available_providers = get_available_providers() 54 | all_models = [] 55 | 56 | for provider_name, provider_config in available_providers.items(): 57 | provider_class = provider_config.provider 58 | 59 | # Get all models from the provider's MODEL_INFO 60 | if hasattr(provider_class, 'MODEL_INFO'): 61 | models = list(provider_class.MODEL_INFO.keys()) 62 | for model_name in models: 63 | all_models.append((provider_name, model_name, provider_class)) 64 | 65 | return all_models 66 | 67 | 68 | @pytest.fixture(scope="session") 69 | def available_providers(): 70 | """Fixture that returns all available providers.""" 71 | return get_available_providers() 72 | 73 | 74 | @pytest.fixture(scope="session") 75 | def all_model_combinations(): 76 | """Fixture that returns all available model combinations.""" 77 | return get_all_models() 78 | 79 | 80 | class TestModelDiscovery: 81 | """Test that we can discover models and providers correctly using LLMS infrastructure.""" 82 | 83 | def test_provider_discovery(self, available_providers): 84 | """Test that we can discover available providers from LLMS._provider_map.""" 85 | assert len(available_providers) > 0, "No providers with API keys found" 86 | print(f"\nFound {len(available_providers)} available providers:") 87 | 88 | for provider_name, provider_config in available_providers.items(): 89 | if not provider_config.needs_api_key: 90 | api_key_status = "No API key needed" 91 | elif provider_config.custom_credential_check: 92 | api_key_status = "Custom credentials" 93 | else: 94 | api_key_status = f"API key: {provider_config.api_key_name}" 95 | print(f" ✓ {provider_name} ({api_key_status})") 96 | 97 | assert hasattr(provider_config, 'provider') 98 | assert hasattr(provider_config, 'needs_api_key') 99 | 100 | if provider_config.needs_api_key and not provider_config.custom_credential_check: 101 | assert provider_config.api_key_name is not None 102 | assert os.getenv(provider_config.api_key_name) is not None 103 | 104 | def test_model_discovery(self, all_model_combinations): 105 | """Test that we can discover models from providers using MODEL_INFO.""" 106 | assert len(all_model_combinations) > 0, "No models discovered" 107 | print(f"\nFound {len(all_model_combinations)} total models") 108 | 109 | models_by_provider = {} 110 | for provider_name, model_name, provider_class in all_model_combinations: 111 | if provider_name not in 
models_by_provider: 112 | models_by_provider[provider_name] = [] 113 | models_by_provider[provider_name].append(model_name) 114 | 115 | assert isinstance(provider_name, str) 116 | assert isinstance(model_name, str) 117 | assert hasattr(provider_class, 'MODEL_INFO') 118 | assert model_name in provider_class.MODEL_INFO 119 | 120 | for provider_name, models in models_by_provider.items(): 121 | print(f" {provider_name}: {len(models)} models") 122 | 123 | 124 | class TestModelInitialization: 125 | """Test that models can be initialized correctly using llms.init().""" 126 | 127 | @pytest.mark.parametrize("provider_name,model_name,provider_class", get_all_models()) 128 | def test_model_initialization_via_llms_init(self, provider_name, model_name, provider_class): 129 | """Test that each model can be initialized through llms.init().""" 130 | try: 131 | model = llms.init(model_name) 132 | assert model is not None 133 | assert len(model._providers) == 1 134 | assert model._models == [model_name] 135 | assert hasattr(model, 'complete') 136 | assert hasattr(model, 'count_tokens') 137 | except Exception as e: 138 | pytest.fail(f"Failed to initialize model {model_name} from {provider_name}: {e}") 139 | 140 | @pytest.mark.parametrize("provider_name,model_name,provider_class", get_all_models()) 141 | def test_provider_direct_initialization(self, provider_name, model_name, provider_class): 142 | """Test that each provider can be initialized directly.""" 143 | provider_config = LLMS._provider_map[provider_name] 144 | 145 | try: 146 | if provider_name == "BedrockAnthropic": 147 | # Special case for BedrockAnthropic which uses AWS credentials 148 | provider = provider_class( 149 | model=model_name, 150 | aws_access_key=os.getenv("AWS_ACCESS_KEY_ID"), 151 | aws_secret_key=os.getenv("AWS_SECRET_ACCESS_KEY"), 152 | aws_region=os.getenv("AWS_DEFAULT_REGION", "us-east-1") 153 | ) 154 | elif provider_config.needs_api_key and provider_config.api_key_name: 155 | api_key = os.getenv(provider_config.api_key_name) 156 | provider = provider_class(api_key=api_key, model=model_name) 157 | else: 158 | provider = provider_class(model=model_name) 159 | 160 | assert provider is not None 161 | assert provider.model == model_name 162 | assert hasattr(provider, 'complete') 163 | assert hasattr(provider, 'count_tokens') 164 | except Exception as e: 165 | pytest.fail(f"Failed to initialize provider {provider_name} with model {model_name}: {e}") 166 | 167 | 168 | class TestBasicModelFunctionality: 169 | """Test basic functionality of each model.""" 170 | 171 | @pytest.mark.parametrize("provider_name,model_name,provider_class", get_all_models()) 172 | def test_model_completion(self, provider_name, model_name, provider_class): 173 | """Test that each model can complete a simple prompt.""" 174 | 175 | # Skip embedding and rerank models as they have different interfaces 176 | if any(keyword in model_name.lower() for keyword in ['embed', 'rerank']): 177 | pytest.skip(f"Skipping {model_name} - embedding/rerank model with different interface") 178 | 179 | try: 180 | model = llms.init(model_name) 181 | 182 | # Simple test prompt that should work across all text models 183 | prompt = "What is 2+2? Answer with just the number." 
184 | 185 | # Set reasonable parameters that work for all models including thinking models 186 | result = model.complete( 187 | prompt, 188 | max_tokens=2048, 189 | temperature=0 190 | ) 191 | 192 | assert result is not None 193 | assert hasattr(result, 'text') 194 | assert len(result.text.strip()) > 0 195 | assert hasattr(result, 'meta') 196 | 197 | print(f"✓ {provider_name}/{model_name}: '{result.text.strip()}'") 198 | 199 | except Exception as e: 200 | pytest.fail(f"Model {provider_name}/{model_name} failed: {e}") 201 | 202 | @pytest.mark.parametrize("provider_name,model_name,provider_class", get_all_models()) 203 | def test_token_counting(self, provider_name, model_name, provider_class): 204 | """Test token counting functionality for each model.""" 205 | 206 | if any(keyword in model_name.lower() for keyword in ['embed', 'rerank']): 207 | pytest.skip(f"Skipping {model_name} - embedding/rerank model") 208 | 209 | try: 210 | model = llms.init(model_name) 211 | 212 | test_text = "Hello, world! This is a test." 213 | token_count = model.count_tokens(test_text) 214 | 215 | assert isinstance(token_count, int) 216 | assert token_count > 0 217 | 218 | print(f"✓ {provider_name}/{model_name}: {token_count} tokens for '{test_text}'") 219 | 220 | except NotImplementedError as e: 221 | # Some providers legitimately don't support token counting 222 | if "Count tokens is currently not supported" in str(e): 223 | pytest.skip(f"Skipping {provider_name}/{model_name} - token counting not supported by provider") 224 | else: 225 | pytest.fail(f"Unexpected NotImplementedError for {provider_name}/{model_name}: {e}") 226 | except Exception as e: 227 | pytest.fail(f"Token counting failed for {provider_name}/{model_name}: {e}") 228 | 229 | 230 | class TestAsyncFunctionality: 231 | """Test async functionality where supported.""" 232 | 233 | @pytest.mark.asyncio 234 | @pytest.mark.parametrize("provider_name,model_name,provider_class", 235 | [model for model in get_all_models() if not any(kw in model[1].lower() for kw in ['embed', 'rerank'])][:3]) # Test first 3 non-embed models 236 | async def test_async_completion(self, provider_name, model_name, provider_class): 237 | """Test async completion for supported models.""" 238 | 239 | try: 240 | model = llms.init(model_name) 241 | 242 | prompt = "What is 3+3? Answer with just the number." 
243 | 244 | result = await model.acomplete( 245 | prompt, 246 | max_tokens=10, 247 | temperature=0 248 | ) 249 | 250 | assert result is not None 251 | assert hasattr(result, 'text') 252 | assert len(result.text.strip()) > 0 253 | 254 | print(f"✓ Async {provider_name}/{model_name}: '{result.text.strip()}'") 255 | 256 | except Exception as e: 257 | pytest.fail(f"Async test failed for {provider_name}/{model_name}: {e}") 258 | 259 | 260 | class TestStreamingFunctionality: 261 | """Test streaming functionality where supported.""" 262 | 263 | @pytest.mark.parametrize("provider_name,model_name,provider_class", 264 | [model for model in get_all_models() if not any(kw in model[1].lower() for kw in ['embed', 'rerank'])][:2]) # Test first 2 non-embed models 265 | def test_streaming_completion(self, provider_name, model_name, provider_class): 266 | """Test streaming completion for supported models.""" 267 | 268 | try: 269 | model = llms.init(model_name) 270 | 271 | prompt = "Count: 1, 2, 3" 272 | 273 | result = model.complete_stream( 274 | prompt, 275 | max_tokens=20, 276 | temperature=0 277 | ) 278 | 279 | assert result is not None 280 | assert hasattr(result, 'stream') 281 | 282 | # Collect stream chunks 283 | chunks = [] 284 | for chunk in result.stream: 285 | if chunk is not None: 286 | chunks.append(chunk) 287 | if len(chunks) >= 3: # Don't collect too many to avoid rate limits 288 | break 289 | 290 | assert len(chunks) > 0 291 | full_text = ''.join(chunks) 292 | assert len(full_text.strip()) > 0 293 | 294 | print(f"✓ Stream {provider_name}/{model_name}: '{full_text.strip()}'") 295 | 296 | except Exception as e: 297 | pytest.fail(f"Streaming test failed for {provider_name}/{model_name}: {e}") 298 | 299 | 300 | class TestModelInformation: 301 | """Test that model information is correctly defined.""" 302 | 303 | @pytest.mark.parametrize("provider_name,model_name,provider_class", get_all_models()) 304 | def test_model_info_structure(self, provider_name, model_name, provider_class): 305 | """Test that model info has required fields.""" 306 | model_info = provider_class.MODEL_INFO[model_name] 307 | 308 | # All models should have prompt and completion pricing 309 | assert 'prompt' in model_info, f"{provider_name}/{model_name} missing 'prompt' pricing" 310 | assert 'completion' in model_info, f"{provider_name}/{model_name} missing 'completion' pricing" 311 | assert isinstance(model_info['prompt'], (int, float)), f"{provider_name}/{model_name} prompt pricing not numeric" 312 | assert isinstance(model_info['completion'], (int, float)), f"{provider_name}/{model_name} completion pricing not numeric" 313 | 314 | # Token limit should be specified 315 | assert 'token_limit' in model_info, f"{provider_name}/{model_name} missing 'token_limit'" 316 | assert isinstance(model_info['token_limit'], int), f"{provider_name}/{model_name} token_limit not integer" 317 | assert model_info['token_limit'] >= 0, f"{provider_name}/{model_name} token_limit negative" 318 | 319 | 320 | class TestMultiModelFunctionality: 321 | """Test multi-model functionality.""" 322 | 323 | def test_multi_model_init(self, all_model_combinations): 324 | """Test initializing multiple models at once.""" 325 | # Get first 3 available models to avoid overwhelming APIs 326 | available_models = [model[1] for model in all_model_combinations if not any(kw in model[1].lower() for kw in ['embed', 'rerank'])][:3] 327 | 328 | if len(available_models) < 2: 329 | pytest.skip("Need at least 2 models for multi-model test") 330 | 331 | try: 332 | models = 
llms.init(model=available_models) 333 | assert models is not None 334 | assert len(models._providers) == len(available_models) 335 | assert models._models == available_models 336 | 337 | except Exception as e: 338 | pytest.fail(f"Multi-model initialization failed: {e}") 339 | 340 | 341 | def test_no_hardcoded_models(): 342 | """Ensure we're not hardcoding models anywhere in tests.""" 343 | # This test ensures we're discovering models dynamically 344 | all_models = get_all_models() 345 | assert len(all_models) > 0, "No models discovered - discovery mechanism failed" 346 | 347 | # Verify we're discovering models dynamically (not hardcoded) 348 | providers_with_models = set(model[0] for model in all_models) 349 | assert len(providers_with_models) >= 1, "No providers found with models" 350 | 351 | # Verify that for each available provider, we're finding their models 352 | available_providers = get_available_providers() 353 | for provider_name in providers_with_models: 354 | assert provider_name in available_providers, f"Found models for {provider_name} but provider not in available_providers" 355 | 356 | if len(providers_with_models) == 1: 357 | print(f"✓ Dynamic discovery working: {len(all_models)} models from 1 provider: {list(providers_with_models)[0]}") 358 | else: 359 | print(f"✓ Dynamic discovery working: {len(all_models)} models from {len(providers_with_models)} providers") 360 | 361 | 362 | def test_llms_list_method(): 363 | """Test that LLMS.list() method works correctly.""" 364 | # Create a minimal LLMS instance to test the list method 365 | # We'll override _initialize_providers to avoid needing API keys 366 | class TestLLMS(LLMS): 367 | def _initialize_providers(self, kwargs): 368 | # Skip provider initialization for list testing 369 | self._providers = [] 370 | 371 | temp_llms = TestLLMS() 372 | all_models_list = temp_llms.list() 373 | 374 | assert len(all_models_list) > 0, "LLMS.list() returned no models" 375 | 376 | # Verify structure 377 | for model_info in all_models_list[:5]: # Check first 5 378 | assert 'provider' in model_info 379 | assert 'name' in model_info 380 | assert 'cost' in model_info 381 | 382 | print(f"✓ LLMS.list() returned {len(all_models_list)} models") 383 | 384 | 385 | if __name__ == "__main__": 386 | # When run directly, show available providers and models 387 | print("=== PyLLMs Dynamic Model Discovery ===") 388 | 389 | available = get_available_providers() 390 | print(f"\nAvailable Providers ({len(available)}):") 391 | for name, config in available.items(): 392 | api_key_status = "✓ No API key needed" if not config.needs_api_key else f"✓ {config.api_key_name}" 393 | print(f" {api_key_status} {name}") 394 | 395 | all_models = get_all_models() 396 | print(f"\nTotal Models Available: {len(all_models)}") 397 | 398 | # Group by provider 399 | by_provider = {} 400 | for provider_name, model_name, _ in all_models: 401 | if provider_name not in by_provider: 402 | by_provider[provider_name] = [] 403 | by_provider[provider_name].append(model_name) 404 | 405 | for provider_name, models in by_provider.items(): 406 | print(f" {provider_name}: {len(models)} models") 407 | 408 | print("\nRun tests with: pytest tests/test_all_models.py -v") 409 | print("Run with output: pytest tests/test_all_models.py -v -s") 410 | print("Run specific test: pytest tests/test_all_models.py::TestModelDiscovery -v") -------------------------------------------------------------------------------- /llms/providers/openai.py: 
-------------------------------------------------------------------------------- 1 | from typing import AsyncGenerator, Dict, Generator, List, Optional, Union 2 | import tiktoken 3 | 4 | from openai import AsyncOpenAI, OpenAI 5 | import json 6 | 7 | from ..results.result import AsyncStreamResult, Result, StreamResult 8 | from .base_provider import BaseProvider 9 | 10 | 11 | class OpenAIProvider(BaseProvider): 12 | # cost is per million tokens 13 | MODEL_INFO = { 14 | "gpt-5-nano": {"prompt": 0.05, "completion": 0.40, "token_limit": 128000, "is_chat": True, "output_limit": 4096}, 15 | "gpt-4.1-nano": {"prompt": 0.10, "completion": 0.40, "token_limit": 128000, "is_chat": True, "output_limit": 16384}, 16 | "gpt-5-mini": {"prompt": 0.25, "completion": 2.00, "token_limit": 128000, "is_chat": True, "output_limit": 4096}, 17 | "gpt-4.1-mini": {"prompt": 0.40, "completion": 1.60, "token_limit": 128000, "is_chat": True, "output_limit": 16384}, 18 | "gpt-4o-mini": {"prompt": 0.15, "completion": 0.60, "token_limit": 128000, "is_chat": True, "output_limit": 4096}, 19 | "gpt-4o-mini-audio-preview": {"prompt": 0.15, "completion": 0.60, "token_limit": 128000, "is_chat": True, "output_limit": 4096}, 20 | "o1-mini": {"prompt": 1.10, "completion": 4.40, "token_limit": 128000, "is_chat": True, "output_limit": 4096, "use_max_completion_tokens": True}, 21 | "o3-mini": {"prompt": 1.10, "completion": 4.40, "token_limit": 128000, "is_chat": True, "output_limit": 4096, "use_max_completion_tokens": True}, 22 | "o4-mini": {"prompt": 1.10, "completion": 4.40, "token_limit": 128000, "is_chat": True, "output_limit": 4096, "use_max_completion_tokens": True}, 23 | "gpt-4.1": {"prompt": 2.00, "completion": 8.00, "token_limit": 128000, "is_chat": True, "output_limit": 16384}, 24 | "o3": {"prompt": 2.00, "completion": 8.00, "token_limit": 200000, "is_chat": True, "output_limit": 100000, "use_max_completion_tokens": True}, 25 | "o4-mini-deep-research": {"prompt": 2.00, "completion": 8.00, "token_limit": 128000, "is_chat": True, "output_limit": 4096}, 26 | "gpt-5": {"prompt": 1.25, "completion": 10.00, "token_limit": 128000, "is_chat": True, "output_limit": 4096, "use_responses_api": True}, 27 | "gpt-5-chat-latest": {"prompt": 1.25, "completion": 10.00, "token_limit": 128000, "is_chat": True, "output_limit": 4096}, 28 | "gpt-4o": {"prompt": 2.50, "completion": 10.00, "token_limit": 128000, "is_chat": True, "output_limit": 4096}, 29 | "o3-deep-research": {"prompt": 10.00, "completion": 40.00, "token_limit": 200000, "is_chat": True, "output_limit": 100000, "use_max_completion_tokens": True}, 30 | "o1": {"prompt": 15.00, "completion": 60.00, "token_limit": 200000, "is_chat": True, "output_limit": 100000, "use_max_completion_tokens": True}, 31 | "o1-preview": {"prompt": 15.00, "completion": 60.00, "token_limit": 128000, "is_chat": True, "output_limit": 4096, "use_max_completion_tokens": True}, 32 | "o3-pro": {"prompt": 20.00, "completion": 80.00, "token_limit": 200000, "is_chat": True, "output_limit": 100000, "use_max_completion_tokens": True, "use_responses_api": True}, 33 | "o1-pro": {"prompt": 150.00, "completion": 600.00, "token_limit": 200000, "is_chat": True, "output_limit": 100000, "use_max_completion_tokens": True, "use_responses_api": True}, 34 | } 35 | 36 | def __init__( 37 | self, 38 | api_key: Union[str, None] = None, 39 | model: Union[str, None] = None, 40 | client_kwargs: Union[dict, None] = None, 41 | async_client_kwargs: Union[dict, None] = None, 42 | ): 43 | if model is None: 44 | model = 
list(self.MODEL_INFO.keys())[0] 45 | self.model = model 46 | if client_kwargs is None: 47 | client_kwargs = {} 48 | self.client = OpenAI(api_key=api_key, **client_kwargs) 49 | if async_client_kwargs is None: 50 | async_client_kwargs = {} 51 | self.async_client = AsyncOpenAI(api_key=api_key, **async_client_kwargs) 52 | 53 | @property 54 | def is_chat_model(self) -> bool: 55 | return self.MODEL_INFO[self.model]['is_chat'] 56 | 57 | @property 58 | def uses_responses_api(self) -> bool: 59 | return self.MODEL_INFO[self.model].get('use_responses_api', False) 60 | 61 | def count_tokens(self, content: Union[str, List[dict]]) -> int: 62 | try: 63 | enc = tiktoken.encoding_for_model(self.model) 64 | except KeyError: 65 | # For new models not yet in tiktoken, use gpt-4 as fallback 66 | enc = tiktoken.encoding_for_model("gpt-4") 67 | 68 | if isinstance(content, list): 69 | # When field name is present, ChatGPT will ignore the role token. 70 | # Adopted from OpenAI cookbook 71 | # https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb 72 | # every message follows {role/name}\n{content}\n 73 | formatting_token_count = 4 74 | 75 | messages = content 76 | messages_text = ["".join(message.values()) for message in messages] 77 | tokens = [enc.encode(t, disallowed_special=()) for t in messages_text] 78 | 79 | n_tokens_list = [] 80 | for token, message in zip(tokens, messages): 81 | n_tokens = len(token) + formatting_token_count 82 | if "name" in message: 83 | n_tokens += -1 84 | n_tokens_list.append(n_tokens) 85 | return sum(n_tokens_list) 86 | else: 87 | return len(enc.encode(content, disallowed_special=())) 88 | 89 | def _prepare_model_inputs( 90 | self, 91 | prompt: str, 92 | history: Optional[List[dict]] = None, 93 | system_message: Union[str, List[dict], None] = None, 94 | temperature: float = 0, 95 | max_tokens: int = 300, 96 | stream: bool = False, 97 | reasoning_effort: Optional[str] = None, 98 | **kwargs, 99 | ) -> Dict: 100 | if self.is_chat_model: 101 | messages = [{"role": "user", "content": prompt}] 102 | 103 | if history: 104 | messages = [*history, *messages] 105 | 106 | if isinstance(system_message, str): 107 | messages = [{"role": "system", "content": system_message}, *messages] 108 | 109 | # users can input multiple full system message in dict form 110 | elif isinstance(system_message, list): 111 | messages = [*system_message, *messages] 112 | 113 | model_inputs = { 114 | "messages": messages, 115 | "stream": stream, 116 | **({'reasoning_effort': reasoning_effort} if reasoning_effort else {}), 117 | **kwargs, 118 | } 119 | 120 | # Use max_completion_tokens for models that require it 121 | if self.MODEL_INFO[self.model].get("use_max_completion_tokens", False): 122 | model_inputs["max_completion_tokens"] = max_tokens 123 | else: 124 | model_inputs["max_tokens"] = max_tokens 125 | model_inputs["temperature"] = temperature 126 | 127 | else: 128 | if history: 129 | raise ValueError( 130 | f"history argument is not supported for {self.model} model" 131 | ) 132 | 133 | if system_message: 134 | raise ValueError( 135 | f"system_message argument is not supported for {self.model} model" 136 | ) 137 | 138 | model_inputs = { 139 | "prompt": prompt, 140 | "temperature": temperature, 141 | "max_tokens": max_tokens, 142 | "stream": stream, 143 | **kwargs, 144 | } 145 | return model_inputs 146 | 147 | def complete( 148 | self, 149 | prompt: str, 150 | history: Optional[List[dict]] = None, 151 | system_message: Optional[List[dict]] = None, 152 | temperature: 
float = 0, 153 | max_tokens: int = 300, 154 | **kwargs, 155 | ) -> Result: 156 | """ 157 | Args: 158 | history: messages in OpenAI format, each dict must include role and content key. 159 | system_message: system messages in OpenAI format, must have role and content key. 160 | It can has name key to include few-shots examples. 161 | """ 162 | 163 | model_inputs = self._prepare_model_inputs( 164 | prompt=prompt, 165 | history=history, 166 | system_message=system_message, 167 | temperature=temperature, 168 | max_tokens=max_tokens, 169 | **kwargs, 170 | ) 171 | 172 | with self.track_latency(): 173 | if self.uses_responses_api: 174 | # Convert messages format for Responses API 175 | input_messages = model_inputs.pop("messages") 176 | # Handle any reasoning_effort parameter 177 | reasoning = {} 178 | if "reasoning_effort" in model_inputs: 179 | reasoning["effort"] = model_inputs.pop("reasoning_effort") 180 | 181 | # Prepare parameters for Responses API 182 | responses_params = { 183 | "model": self.model, 184 | "input": input_messages 185 | } 186 | 187 | # Temperature is not supported for some models with Responses API 188 | # Only add it if the model supports it 189 | 190 | # For Responses API, max_tokens should be converted to max_output_tokens 191 | if max_tokens is not None: 192 | responses_params["max_output_tokens"] = max_tokens 193 | 194 | # Add any other supported parameters 195 | for key, value in model_inputs.items(): 196 | if key not in ["messages", "max_completion_tokens", "max_tokens", "temperature", "reasoning_effort"]: 197 | responses_params[key] = value 198 | 199 | # Add reasoning if present 200 | if reasoning: 201 | responses_params["reasoning"] = reasoning 202 | 203 | response = self.client.responses.create(**responses_params) 204 | elif self.is_chat_model: 205 | response = self.client.chat.completions.create(model=self.model, **model_inputs) 206 | else: 207 | response = self.client.completions.create(model=self.model, **model_inputs) 208 | 209 | function_call = {} 210 | completion = "" 211 | 212 | if self.uses_responses_api: 213 | # Extract text from Responses API 214 | # Find the output_text in the response 215 | for item in response.output: 216 | if item.type == "message" and hasattr(item, "content"): 217 | for content_item in item.content: 218 | if content_item.type == "output_text": 219 | completion = content_item.text.strip() 220 | break 221 | 222 | # Handle function calls if present 223 | if hasattr(response, 'output') and hasattr(response.output, 'function_calls'): 224 | function_call = { 225 | "name": response.output.function_calls[0].name, 226 | "arguments": response.output.function_calls[0].arguments 227 | } 228 | 229 | # Usage has different field names in Responses API 230 | usage = { 231 | "prompt_tokens": response.usage.input_tokens, 232 | "completion_tokens": response.usage.output_tokens, 233 | "total_tokens": response.usage.total_tokens 234 | } 235 | else: 236 | is_func_call = response.choices[0].finish_reason == "function_call" 237 | if self.is_chat_model: 238 | if is_func_call: 239 | function_call = { 240 | "name": response.choices[0].message.function_call.name, 241 | "arguments": json.loads(response.choices[0].message.function_call.arguments) 242 | } 243 | else: 244 | completion = response.choices[0].message.content.strip() 245 | else: 246 | completion = response.choices[0].text.strip() 247 | usage = response.usage 248 | 249 | meta = { 250 | "tokens_prompt": usage["prompt_tokens"] if isinstance(usage, dict) else usage.prompt_tokens, 251 | 
"tokens_completion": usage["completion_tokens"] if isinstance(usage, dict) else usage.completion_tokens, 252 | "latency": self.latency, 253 | } 254 | return Result( 255 | text=completion, 256 | model_inputs=model_inputs, 257 | provider=self, 258 | meta=meta, 259 | function_call=function_call, 260 | ) 261 | 262 | async def acomplete( 263 | self, 264 | prompt: str, 265 | history: Optional[List[dict]] = None, 266 | system_message: Optional[List[dict]] = None, 267 | temperature: float = 0, 268 | max_tokens: int = 300, 269 | **kwargs, 270 | ) -> Result: 271 | """ 272 | Args: 273 | history: messages in OpenAI format, each dict must include role and content key. 274 | system_message: system messages in OpenAI format, must have role and content key. 275 | It can has name key to include few-shots examples. 276 | """ 277 | model_inputs = self._prepare_model_inputs( 278 | prompt=prompt, 279 | history=history, 280 | system_message=system_message, 281 | temperature=temperature, 282 | max_tokens=max_tokens, 283 | **kwargs, 284 | ) 285 | 286 | with self.track_latency(): 287 | if self.uses_responses_api: 288 | # Convert messages format for Responses API 289 | input_messages = model_inputs.pop("messages") 290 | # Handle any reasoning_effort parameter 291 | reasoning = {} 292 | if "reasoning_effort" in model_inputs: 293 | reasoning["effort"] = model_inputs.pop("reasoning_effort") 294 | 295 | # Prepare parameters for Responses API 296 | responses_params = { 297 | "model": self.model, 298 | "input": input_messages 299 | } 300 | 301 | # Temperature is not supported for some models with Responses API 302 | # Only add it if the model supports it 303 | 304 | # For Responses API, max_tokens should be converted to max_output_tokens 305 | if max_tokens is not None: 306 | responses_params["max_output_tokens"] = max_tokens 307 | 308 | # Add any other supported parameters 309 | for key, value in model_inputs.items(): 310 | if key not in ["messages", "max_completion_tokens", "max_tokens", "temperature", "reasoning_effort"]: 311 | responses_params[key] = value 312 | 313 | # Add reasoning if present 314 | if reasoning: 315 | responses_params["reasoning"] = reasoning 316 | 317 | response = await self.async_client.responses.create(**responses_params) 318 | # Find the output_text in the response 319 | completion = "" 320 | for item in response.output: 321 | if item.type == "message" and hasattr(item, "content"): 322 | for content_item in item.content: 323 | if content_item.type == "output_text": 324 | completion = content_item.text.strip() 325 | break 326 | 327 | # Usage has different field names in Responses API 328 | usage = { 329 | "prompt_tokens": response.usage.input_tokens, 330 | "completion_tokens": response.usage.output_tokens, 331 | "total_tokens": response.usage.total_tokens 332 | } 333 | elif self.is_chat_model: 334 | response = await self.async_client.chat.completions.create(model=self.model, **model_inputs) 335 | completion = response.choices[0].message.content.strip() 336 | usage = response.usage 337 | else: 338 | response = await self.async_client.completions.create(model=self.model, **model_inputs) 339 | completion = response.choices[0].text.strip() 340 | usage = response.usage 341 | 342 | # Handle usage consistently 343 | if isinstance(usage, dict): 344 | meta = { 345 | "tokens_prompt": usage["prompt_tokens"], 346 | "tokens_completion": usage["completion_tokens"], 347 | "latency": self.latency, 348 | } 349 | else: 350 | meta = { 351 | "tokens_prompt": usage.prompt_tokens, 352 | "tokens_completion": 
usage.completion_tokens, 353 | "latency": self.latency, 354 | } 355 | return Result( 356 | text=completion, 357 | model_inputs=model_inputs, 358 | provider=self, 359 | meta=meta, 360 | ) 361 | 362 | def complete_stream( 363 | self, 364 | prompt: str, 365 | history: Optional[List[dict]] = None, 366 | system_message: Union[str, List[dict], None] = None, 367 | temperature: float = 0, 368 | max_tokens: int = 300, 369 | **kwargs, 370 | ) -> StreamResult: 371 | """ 372 | Args: 373 | history: messages in OpenAI format, each dict must include role and content key. 374 | system_message: system messages in OpenAI format, must have role and content key. 375 | It can has name key to include few-shots examples. 376 | """ 377 | model_inputs = self._prepare_model_inputs( 378 | prompt=prompt, 379 | history=history, 380 | system_message=system_message, 381 | temperature=temperature, 382 | max_tokens=max_tokens, 383 | stream=True, 384 | **kwargs, 385 | ) 386 | 387 | if self.uses_responses_api: 388 | # Responses API doesn't support streaming in the same way 389 | # For now, we'll use the chat completions API for streaming 390 | response = self.client.chat.completions.create(model=self.model, **model_inputs) 391 | elif self.is_chat_model: 392 | response = self.client.chat.completions.create(model=self.model, **model_inputs) 393 | else: 394 | response = self.client.completions.create(model=self.model, **model_inputs) 395 | stream = self._process_stream(response) 396 | 397 | return StreamResult(stream=stream, model_inputs=model_inputs, provider=self) 398 | 399 | def _process_stream(self, response: Generator) -> Generator: 400 | if self.is_chat_model: 401 | chunk_generator = ( 402 | chunk.choices[0].delta.content for chunk in response 403 | ) 404 | else: 405 | chunk_generator = ( 406 | chunk.choices[0].text for chunk in response 407 | ) 408 | 409 | while not (first_text := next(chunk_generator)): 410 | continue 411 | yield first_text.lstrip() 412 | for chunk in chunk_generator: 413 | if chunk is not None: 414 | yield chunk 415 | 416 | async def acomplete_stream( 417 | self, 418 | prompt: str, 419 | history: Optional[List[dict]] = None, 420 | system_message: Union[str, List[dict], None] = None, 421 | temperature: float = 0, 422 | max_tokens: int = 300, 423 | **kwargs, 424 | ) -> AsyncStreamResult: 425 | """ 426 | Args: 427 | history: messages in OpenAI format, each dict must include role and content key. 428 | system_message: system messages in OpenAI format, must have role and content key. 429 | It can has name key to include few-shots examples. 
430 | """ 431 | model_inputs = self._prepare_model_inputs( 432 | prompt=prompt, 433 | history=history, 434 | system_message=system_message, 435 | temperature=temperature, 436 | max_tokens=max_tokens, 437 | stream=True, 438 | **kwargs, 439 | ) 440 | 441 | with self.track_latency(): 442 | if self.uses_responses_api: 443 | # Responses API doesn't support streaming in the same way 444 | # For now, we'll use the chat completions API for streaming 445 | response = await self.async_client.chat.completions.create(model=self.model, **model_inputs) 446 | elif self.is_chat_model: 447 | response = await self.async_client.chat.completions.create(model=self.model, **model_inputs) 448 | else: 449 | response = await self.async_client.completions.create(model=self.model, **model_inputs) 450 | stream = self._aprocess_stream(response) 451 | return AsyncStreamResult( 452 | stream=stream, model_inputs=model_inputs, provider=self 453 | ) 454 | 455 | async def _aprocess_stream(self, response: AsyncGenerator) -> AsyncGenerator: 456 | if self.is_chat_model: 457 | while True: 458 | first_completion = (await response.__anext__()).choices[0].delta.content 459 | if first_completion: 460 | yield first_completion.lstrip() 461 | break 462 | 463 | async for chunk in response: 464 | completion = chunk.choices[0].delta.content 465 | if completion is not None: 466 | yield completion 467 | else: 468 | while True: 469 | first_completion = (await response.__anext__()).choices[0].text 470 | if first_completion: 471 | yield first_completion.lstrip() 472 | break 473 | 474 | async for chunk in response: 475 | completion = chunk.choices[0].text 476 | if completion is not None: 477 | yield completion 478 | --------------------------------------------------------------------------------