├── .env.sample ├── .github └── workflows │ ├── black.yml │ └── run_pytest.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── aisuite ├── __init__.py ├── client.py ├── framework │ ├── __init__.py │ ├── chat_completion_response.py │ ├── choice.py │ ├── message.py │ └── provider_interface.py ├── provider.py ├── providers │ ├── __init__.py │ ├── anthropic_provider.py │ ├── aws_provider.py │ ├── azure_provider.py │ ├── cerebras_provider.py │ ├── cohere_provider.py │ ├── deepseek_provider.py │ ├── fireworks_provider.py │ ├── google_provider.py │ ├── groq_provider.py │ ├── huggingface_provider.py │ ├── message_converter.py │ ├── mistral_provider.py │ ├── nebius_provider.py │ ├── ollama_provider.py │ ├── openai_provider.py │ ├── sambanova_provider.py │ ├── together_provider.py │ ├── watsonx_provider.py │ └── xai_provider.py └── utils │ └── tools.py ├── examples ├── AISuiteDemo.ipynb ├── DeepseekPost.ipynb ├── QnA_with_pdf.ipynb ├── aisuite_tool_abstraction.ipynb ├── chat-ui │ ├── .streamlit │ │ └── config.toml │ ├── README.md │ ├── chat.py │ └── config.yaml ├── client.ipynb ├── llm_reasoning.ipynb ├── simple_tool_calling.ipynb └── tool_calling_abstraction.ipynb ├── guides ├── README.md ├── anthropic.md ├── aws.md ├── azure.md ├── cerebras.md ├── cohere.md ├── deepseek.md ├── google.md ├── groq.md ├── huggingface.md ├── mistral.md ├── nebius.md ├── openai.md ├── sambanova.md ├── watsonx.md └── xai.md ├── poetry.lock ├── pyproject.toml └── tests ├── __init__.py ├── client ├── __init__.py ├── test_client.py └── test_prerelease.py ├── providers ├── __init__.py ├── test_anthropic_converter.py ├── test_aws_converter.py ├── test_azure_provider.py ├── test_cerebras_provider.py ├── test_cohere_provider.py ├── test_deepseek_provider.py ├── test_google_converter.py ├── test_google_provider.py ├── test_groq_provider.py ├── test_mistral_provider.py ├── test_nebius_provider.py ├── test_ollama_provider.py ├── test_sambanova_provider.py └── test_watsonx_provider.py └── utils └── test_tool_manager.py /.env.sample: -------------------------------------------------------------------------------- 1 | # OpenAI API Key 2 | OPENAI_API_KEY= 3 | 4 | # Anthropic API Key 5 | ANTHROPIC_API_KEY= 6 | 7 | # AWS SDK credentials 8 | AWS_ACCESS_KEY_ID= 9 | AWS_SECRET_ACCESS_KEY= 10 | AWS_REGION= 11 | 12 | # Azure 13 | AZURE_API_KEY= 14 | 15 | # Cerebras 16 | CEREBRAS_API_KEY= 17 | 18 | # Google Cloud 19 | GOOGLE_APPLICATION_CREDENTIALS=./google-adc 20 | GOOGLE_REGION= 21 | GOOGLE_PROJECT_ID= 22 | 23 | # Hugging Face token 24 | HF_TOKEN= 25 | 26 | # Fireworks 27 | FIREWORKS_API_KEY= 28 | 29 | # Mistral 30 | MISTRAL_API_KEY= 31 | 32 | # Together AI 33 | TOGETHER_API_KEY= 34 | 35 | # WatsonX 36 | WATSONX_SERVICE_URL= 37 | WATSONX_API_KEY= 38 | WATSONX_PROJECT_ID= 39 | 40 | # xAI 41 | XAI_API_KEY= 42 | 43 | # Sambanova 44 | SAMBANOVA_API_KEY= 45 | -------------------------------------------------------------------------------- /.github/workflows/black.yml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | lint: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/checkout@v3 10 | - uses: psf/black@stable -------------------------------------------------------------------------------- /.github/workflows/run_pytest.yml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | build_and_test: 7 | 
runs-on: ubuntu-latest 8 | strategy: 9 | matrix: 10 | python-version: [ "3.10", "3.11", "3.12" ] 11 | steps: 12 | - uses: actions/checkout@v4 13 | - name: Set up Python ${{ matrix.python-version }} 14 | uses: actions/setup-python@v5 15 | with: 16 | python-version: ${{ matrix.python-version }} 17 | - name: Install dependencies 18 | run: | 19 | python -m pip install --upgrade pip 20 | pip install poetry 21 | poetry install --all-extras --with test 22 | - name: Test with pytest 23 | run: poetry run pytest -m "not integration" 24 | 25 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | .vscode/ 3 | __pycache__/ 4 | env/ 5 | .env 6 | .google-adc 7 | *.whl 8 | 9 | # Testing 10 | .coverage 11 | 12 | # pyenv 13 | .python-version 14 | 15 | .DS_Store 16 | **/.DS_Store 17 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | # Using this mirror lets us use mypyc-compiled black, which is about 2x faster 3 | - repo: https://github.com/psf/black-pre-commit-mirror 4 | rev: 24.4.2 5 | hooks: 6 | - id: black 7 | # It is recommended to specify the latest version of Python 8 | # supported by your project here, or alternatively use 9 | # pre-commit's default_language_version, see 10 | # https://pre-commit.com/#top_level-default_language_version 11 | language_version: python3.12 12 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Andrew Ng 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and 6 | associated documentation files (the "Software"), to deal in the Software without restriction, including 7 | without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 | copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the 9 | following conditions: 10 | 11 | The above copyright notice and this permission notice shall be included in all copies or substantial 12 | portions of the Software. 13 | 14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT 15 | LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 16 | IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, 17 | WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE 18 | SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 19 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # aisuite 2 | 3 | [![PyPI](https://img.shields.io/pypi/v/aisuite)](https://pypi.org/project/aisuite/) 4 | [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) 5 | 6 | Simple, unified interface to multiple Generative AI providers. 7 | 8 | `aisuite` makes it easy for developers to use multiple LLM through a standardized interface. 
Using an interface similar to OpenAI's, `aisuite` makes it easy to interact with the most popular LLMs and compare the results. It is a thin wrapper around Python client libraries, and allows creators to seamlessly swap out and test responses from different LLM providers without changing their code. Today, the library is primarily focused on chat completions. We will expand it to cover more use cases in the near future. 9 | 10 | Currently supported providers are: 11 | - Anthropic 12 | - AWS 13 | - Azure 14 | - Cerebras 15 | - Google 16 | - Groq 17 | - HuggingFace - Ollama 18 | - Mistral 19 | - OpenAI 20 | - Sambanova 21 | - Watsonx 22 | 23 | To maximize stability, `aisuite` uses either the HTTP endpoint or the SDK for making calls to the provider. 24 | 25 | ## Installation 26 | 27 | You can install just the base `aisuite` package, or install a provider's package along with `aisuite`. 28 | 29 | This installs just the base package without installing any provider's SDK. 30 | 31 | ```shell 32 | pip install aisuite 33 | ``` 34 | 35 | This installs `aisuite` along with Anthropic's library. 36 | 37 | ```shell 38 | pip install 'aisuite[anthropic]' 39 | ``` 40 | 41 | This installs all the provider-specific libraries. 42 | 43 | ```shell 44 | pip install 'aisuite[all]' 45 | ``` 46 | 47 | ## Set up 48 | 49 | To get started, you will need API keys for the providers you intend to use. You'll need to 50 | install the provider-specific library either separately or when installing aisuite. 51 | 52 | The API keys can be set as environment variables, or can be passed as config to the aisuite Client constructor. 53 | You can use tools like [`python-dotenv`](https://pypi.org/project/python-dotenv/) or [`direnv`](https://direnv.net/) to manage the environment variables. Please take a look at the `examples` folder to see usage. 54 | 55 | Here is a short example of using `aisuite` to generate chat completion responses from gpt-4o and claude-3-5-sonnet. 56 | 57 | Set the API keys. 58 | 59 | ```shell 60 | export OPENAI_API_KEY="your-openai-api-key" 61 | export ANTHROPIC_API_KEY="your-anthropic-api-key" 62 | ``` 63 | 64 | Use the Python client. 65 | 66 | ```python 67 | import aisuite as ai 68 | client = ai.Client() 69 | 70 | models = ["openai:gpt-4o", "anthropic:claude-3-5-sonnet-20240620"] 71 | 72 | messages = [ 73 | {"role": "system", "content": "Respond in Pirate English."}, 74 | {"role": "user", "content": "Tell me a joke."}, 75 | ] 76 | 77 | for model in models: 78 | response = client.chat.completions.create( 79 | model=model, 80 | messages=messages, 81 | temperature=0.75 82 | ) 83 | print(response.choices[0].message.content) 84 | 85 | ``` 86 | 87 | Note that the model name in the `create()` call uses the format `<provider>:<model-name>`. 88 | `aisuite` will call the appropriate provider with the right parameters based on the provider value. 89 | For a list of provider values, you can look at the directory `aisuite/providers/`. The supported providers are named in the format `<provider>_provider.py` in that directory. We welcome providers adding support to this library by adding an implementation file in this directory. Please see the section below for how to contribute. 90 | 91 | For more examples, check out the `examples` directory where you will find several notebooks that you can run to experiment with the interface. 92 | 93 | ## Adding support for a provider 94 | 95 | We have made it easy for a provider or volunteer to add support for a new platform.
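Before diving into the naming details below, here is a minimal sketch of what a new provider module might look like. The `ExampleProvider` name, the `EXAMPLE_API_KEY` variable, and the placeholder API call are hypothetical; they only illustrate the pattern the existing modules in `aisuite/providers/` follow (subclass `Provider`, accept `**config`, and return a `ChatCompletionResponse`):

```python
# Hypothetical sketch: aisuite/providers/example_provider.py
import os

from aisuite.provider import Provider, LLMError
from aisuite.framework import ChatCompletionResponse


class ExampleProvider(Provider):
    def __init__(self, **config):
        # Read the API key from the config or an environment variable,
        # as the other providers do.
        self.api_key = config.get("api_key", os.getenv("EXAMPLE_API_KEY"))
        if not self.api_key:
            raise ValueError("Example API key is missing.")

    def chat_completions_create(self, model, messages, **kwargs):
        try:
            # Call the platform's API here, then map its reply onto the
            # OpenAI-style response object used throughout aisuite.
            response = ChatCompletionResponse()
            response.choices[0].message.content = "<text returned by the platform>"
            return response
        except Exception as e:
            raise LLMError(f"An error occurred: {e}")
```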
96 | 97 | ### Naming Convention for Provider Modules 98 | 99 | We follow a convention-based approach for loading providers, which relies on strict naming conventions for both the module name and the class name. The format is based on the model identifier in the form `provider:model`. 100 | 101 | - The provider's module file must be named in the format `_provider.py`. 102 | - The class inside this module must follow the format: the provider name with the first letter capitalized, followed by the suffix `Provider`. 103 | 104 | #### Examples 105 | 106 | - **Hugging Face**: 107 | The provider class should be defined as: 108 | 109 | ```python 110 | class HuggingfaceProvider(BaseProvider) 111 | ``` 112 | 113 | in providers/huggingface_provider.py. 114 | 115 | - **OpenAI**: 116 | The provider class should be defined as: 117 | 118 | ```python 119 | class OpenaiProvider(BaseProvider) 120 | ``` 121 | 122 | in providers/openai_provider.py 123 | 124 | This convention simplifies the addition of new providers and ensures consistency across provider implementations. 125 | 126 | ## Tool Calling 127 | 128 | `aisuite` provides a simple abstraction for tool/function calling that works across supported providers. This is in addition to the regular abstraction of passing JSON spec of the tool to the `tools` parameter. The tool calling abstraction makes it easy to use tools with different LLMs without changing your code. 129 | 130 | There are two ways to use tools with `aisuite`: 131 | 132 | ### 1. Manual Tool Handling 133 | 134 | This is the default behavior when `max_turns` is not specified. 135 | You can pass tools in the OpenAI tool format: 136 | 137 | ```python 138 | def will_it_rain(location: str, time_of_day: str): 139 | """Check if it will rain in a location at a given time today. 140 | 141 | Args: 142 | location (str): Name of the city 143 | time_of_day (str): Time of the day in HH:MM format. 144 | """ 145 | return "YES" 146 | 147 | tools = [{ 148 | "type": "function", 149 | "function": { 150 | "name": "will_it_rain", 151 | "description": "Check if it will rain in a location at a given time today", 152 | "parameters": { 153 | "type": "object", 154 | "properties": { 155 | "location": { 156 | "type": "string", 157 | "description": "Name of the city" 158 | }, 159 | "time_of_day": { 160 | "type": "string", 161 | "description": "Time of the day in HH:MM format." 162 | } 163 | }, 164 | "required": ["location", "time_of_day"] 165 | } 166 | } 167 | }] 168 | 169 | response = client.chat.completions.create( 170 | model="openai:gpt-4o", 171 | messages=messages, 172 | tools=tools 173 | ) 174 | ``` 175 | 176 | ### 2. Automatic Tool Execution 177 | 178 | When `max_turns` is specified, you can pass a list of callable Python functions as the `tools` parameter. `aisuite` will automatically handle the tool calling flow: 179 | 180 | ```python 181 | def will_it_rain(location: str, time_of_day: str): 182 | """Check if it will rain in a location at a given time today. 183 | 184 | Args: 185 | location (str): Name of the city 186 | time_of_day (str): Time of the day in HH:MM format. 187 | """ 188 | return "YES" 189 | 190 | client = ai.Client() 191 | messages = [{ 192 | "role": "user", 193 | "content": "I live in San Francisco. Can you check for weather " 194 | "and plan an outdoor picnic for me at 2pm?" 
195 | }] 196 | 197 | # Automatic tool execution with max_turns 198 | response = client.chat.completions.create( 199 | model="openai:gpt-4o", 200 | messages=messages, 201 | tools=[will_it_rain], 202 | max_turns=2 # Maximum number of back-and-forth tool calls 203 | ) 204 | print(response.choices[0].message.content) 205 | ``` 206 | 207 | When `max_turns` is specified, `aisuite` will: 208 | 1. Send your message to the LLM 209 | 2. Execute any tool calls the LLM requests 210 | 3. Send the tool results back to the LLM 211 | 4. Repeat until the conversation is complete or `max_turns` is reached 212 | 213 | In addition to `response.choices[0].message`, there is an additional field `response.choices[0].intermediate_messages`, which contains the full list of messages, including the tool interactions. This can be used to continue the conversation with the model. 214 | For more detailed examples of tool calling, check out the `examples/tool_calling_abstraction.ipynb` notebook. 215 | 216 | ## License 217 | 218 | aisuite is released under the MIT License. You are free to use, modify, and distribute the code for both commercial and non-commercial purposes. 219 | 220 | ## Contributing 221 | 222 | If you would like to contribute, please read our [Contributing Guide](https://github.com/andrewyng/aisuite/blob/main/CONTRIBUTING.md) and join our [Discord](https://discord.gg/T6Nvn8ExSb) server! 223 | -------------------------------------------------------------------------------- /aisuite/__init__.py: -------------------------------------------------------------------------------- 1 | from .client import Client 2 | from .framework.message import Message 3 | from .utils.tools import Tools 4 | -------------------------------------------------------------------------------- /aisuite/framework/__init__.py: -------------------------------------------------------------------------------- 1 | from .provider_interface import ProviderInterface 2 | from .chat_completion_response import ChatCompletionResponse 3 | from .message import Message 4 | -------------------------------------------------------------------------------- /aisuite/framework/chat_completion_response.py: -------------------------------------------------------------------------------- 1 | from aisuite.framework.choice import Choice 2 | 3 | 4 | class ChatCompletionResponse: 5 | """Used to conform to the response model of OpenAI""" 6 | 7 | def __init__(self): 8 | self.choices = [Choice()] # Adjust as needed for more choices 9 | -------------------------------------------------------------------------------- /aisuite/framework/choice.py: -------------------------------------------------------------------------------- 1 | from aisuite.framework.message import Message 2 | from typing import Literal, Optional, List 3 | 4 | 5 | class Choice: 6 | def __init__(self): 7 | self.finish_reason: Optional[Literal["stop", "tool_calls"]] = None 8 | self.message = Message( 9 | content=None, 10 | tool_calls=None, 11 | role="assistant", 12 | refusal=None, 13 | reasoning_content=None, 14 | ) 15 | self.intermediate_messages: List[Message] = [] 16 | -------------------------------------------------------------------------------- /aisuite/framework/message.py: -------------------------------------------------------------------------------- 1 | """Interface to hold contents of API responses when they do not conform to the OpenAI style response""" 2 | 3 | from pydantic import BaseModel 4 | from typing import Literal, Optional, List 5 | 6 | 7 | class Function(BaseModel): 8 |
arguments: str 9 | name: str 10 | 11 | 12 | class ChatCompletionMessageToolCall(BaseModel): 13 | id: str 14 | function: Function 15 | type: Literal["function"] 16 | 17 | 18 | class Message(BaseModel): 19 | content: Optional[str] = None 20 | reasoning_content: Optional[str] = None 21 | tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None 22 | role: Optional[Literal["user", "assistant", "system", "tool"]] = None 23 | refusal: Optional[str] = None 24 | -------------------------------------------------------------------------------- /aisuite/framework/provider_interface.py: -------------------------------------------------------------------------------- 1 | """The shared interface for model providers.""" 2 | 3 | 4 | # TODO(rohit): Remove this. This interface is obsolete in favor of Provider. 5 | class ProviderInterface: 6 | """Defines the expected behavior for provider-specific interfaces.""" 7 | 8 | def chat_completion_create(self, messages=None, model=None, temperature=0) -> None: 9 | """Create a chat completion using the specified messages, model, and temperature. 10 | 11 | This method must be implemented by subclasses to perform completions. 12 | 13 | Args: 14 | ---- 15 | messages (list): The chat history. 16 | model (str): The identifier of the model to be used in the completion. 17 | temperature (float): The temperature to use in the completion. 18 | 19 | Raises: 20 | ------ 21 | NotImplementedError: If this method has not been implemented by a subclass. 22 | 23 | """ 24 | raise NotImplementedError( 25 | "Provider Interface has not implemented chat_completion_create()" 26 | ) 27 | -------------------------------------------------------------------------------- /aisuite/provider.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from pathlib import Path 3 | import importlib 4 | import os 5 | import functools 6 | 7 | 8 | class LLMError(Exception): 9 | """Custom exception for LLM errors.""" 10 | 11 | def __init__(self, message): 12 | super().__init__(message) 13 | 14 | 15 | class Provider(ABC): 16 | @abstractmethod 17 | def chat_completions_create(self, model, messages): 18 | """Abstract method for chat completion calls, to be implemented by each provider.""" 19 | pass 20 | 21 | 22 | class ProviderFactory: 23 | """Factory to dynamically load provider instances based on naming conventions.""" 24 | 25 | PROVIDERS_DIR = Path(__file__).parent / "providers" 26 | 27 | @classmethod 28 | def create_provider(cls, provider_key, config): 29 | """Dynamically load and create an instance of a provider based on the naming convention.""" 30 | # Convert provider_key to the expected module and class names 31 | provider_class_name = f"{provider_key.capitalize()}Provider" 32 | provider_module_name = f"{provider_key}_provider" 33 | 34 | module_path = f"aisuite.providers.{provider_module_name}" 35 | 36 | # Lazily load the module 37 | try: 38 | module = importlib.import_module(module_path) 39 | except ImportError as e: 40 | raise ImportError( 41 | f"Could not import module {module_path}: {str(e)}. 
Please ensure the provider is supported by doing ProviderFactory.get_supported_providers()" 42 | ) 43 | 44 | # Instantiate the provider class 45 | provider_class = getattr(module, provider_class_name) 46 | return provider_class(**config) 47 | 48 | @classmethod 49 | @functools.cache 50 | def get_supported_providers(cls): 51 | """List all supported provider names based on files present in the providers directory.""" 52 | provider_files = Path(cls.PROVIDERS_DIR).glob("*_provider.py") 53 | return {file.stem.replace("_provider", "") for file in provider_files} 54 | -------------------------------------------------------------------------------- /aisuite/providers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrewyng/aisuite/ddba58493e6e43b2ff7d507779e3c7735c09c33e/aisuite/providers/__init__.py -------------------------------------------------------------------------------- /aisuite/providers/anthropic_provider.py: -------------------------------------------------------------------------------- 1 | # Anthropic provider 2 | # Links: 3 | # Tool calling docs - https://docs.anthropic.com/en/docs/build-with-claude/tool-use 4 | 5 | import anthropic 6 | import json 7 | from aisuite.provider import Provider 8 | from aisuite.framework import ChatCompletionResponse 9 | from aisuite.framework.message import Message, ChatCompletionMessageToolCall, Function 10 | 11 | # Define a constant for the default max_tokens value 12 | DEFAULT_MAX_TOKENS = 4096 13 | 14 | 15 | class AnthropicMessageConverter: 16 | # Role constants 17 | ROLE_USER = "user" 18 | ROLE_ASSISTANT = "assistant" 19 | ROLE_TOOL = "tool" 20 | ROLE_SYSTEM = "system" 21 | 22 | # Finish reason mapping 23 | FINISH_REASON_MAPPING = { 24 | "end_turn": "stop", 25 | "max_tokens": "length", 26 | "tool_use": "tool_calls", 27 | } 28 | 29 | def convert_request(self, messages): 30 | """Convert framework messages to Anthropic format.""" 31 | system_message = self._extract_system_message(messages) 32 | converted_messages = [self._convert_single_message(msg) for msg in messages] 33 | return system_message, converted_messages 34 | 35 | def convert_response(self, response): 36 | """Normalize the response from the Anthropic API to match OpenAI's response format.""" 37 | normalized_response = ChatCompletionResponse() 38 | normalized_response.choices[0].finish_reason = self._get_finish_reason(response) 39 | normalized_response.usage = self._get_usage_stats(response) 40 | normalized_response.choices[0].message = self._get_message(response) 41 | return normalized_response 42 | 43 | def _convert_single_message(self, msg): 44 | """Convert a single message to Anthropic format.""" 45 | if isinstance(msg, dict): 46 | return self._convert_dict_message(msg) 47 | return self._convert_message_object(msg) 48 | 49 | def _convert_dict_message(self, msg): 50 | """Convert a dictionary message to Anthropic format.""" 51 | if msg["role"] == self.ROLE_TOOL: 52 | return self._create_tool_result_message(msg["tool_call_id"], msg["content"]) 53 | elif msg["role"] == self.ROLE_ASSISTANT and "tool_calls" in msg: 54 | return self._create_assistant_tool_message( 55 | msg["content"], msg["tool_calls"] 56 | ) 57 | return {"role": msg["role"], "content": msg["content"]} 58 | 59 | def _convert_message_object(self, msg): 60 | """Convert a Message object to Anthropic format.""" 61 | if msg.role == self.ROLE_TOOL: 62 | return self._create_tool_result_message(msg.tool_call_id, msg.content) 63 | elif msg.role == 
self.ROLE_ASSISTANT and msg.tool_calls: 64 | return self._create_assistant_tool_message(msg.content, msg.tool_calls) 65 | return {"role": msg.role, "content": msg.content} 66 | 67 | def _create_tool_result_message(self, tool_call_id, content): 68 | """Create a tool result message in Anthropic format.""" 69 | return { 70 | "role": self.ROLE_USER, 71 | "content": [ 72 | { 73 | "type": "tool_result", 74 | "tool_use_id": tool_call_id, 75 | "content": content, 76 | } 77 | ], 78 | } 79 | 80 | def _create_assistant_tool_message(self, content, tool_calls): 81 | """Create an assistant message with tool calls in Anthropic format.""" 82 | message_content = [] 83 | if content: 84 | message_content.append({"type": "text", "text": content}) 85 | 86 | for tool_call in tool_calls: 87 | tool_input = ( 88 | tool_call["function"]["arguments"] 89 | if isinstance(tool_call, dict) 90 | else tool_call.function.arguments 91 | ) 92 | message_content.append( 93 | { 94 | "type": "tool_use", 95 | "id": ( 96 | tool_call["id"] if isinstance(tool_call, dict) else tool_call.id 97 | ), 98 | "name": ( 99 | tool_call["function"]["name"] 100 | if isinstance(tool_call, dict) 101 | else tool_call.function.name 102 | ), 103 | "input": json.loads(tool_input), 104 | } 105 | ) 106 | 107 | return {"role": self.ROLE_ASSISTANT, "content": message_content} 108 | 109 | def _extract_system_message(self, messages): 110 | """Extract system message if present, otherwise return empty list.""" 111 | # TODO: This is a temporary solution to extract the system message. 112 | # User can pass multiple system messages, which can mingled with other messages. 113 | # This needs to be fixed to handle this case. 114 | if messages and messages[0]["role"] == "system": 115 | system_message = messages[0]["content"] 116 | messages.pop(0) 117 | return system_message 118 | return [] 119 | 120 | def _get_finish_reason(self, response): 121 | """Get the normalized finish reason.""" 122 | return self.FINISH_REASON_MAPPING.get(response.stop_reason, "stop") 123 | 124 | def _get_usage_stats(self, response): 125 | """Get the usage statistics.""" 126 | return { 127 | "prompt_tokens": response.usage.input_tokens, 128 | "completion_tokens": response.usage.output_tokens, 129 | "total_tokens": response.usage.input_tokens + response.usage.output_tokens, 130 | } 131 | 132 | def _get_message(self, response): 133 | """Get the appropriate message based on response type.""" 134 | if response.stop_reason == "tool_use": 135 | tool_message = self.convert_response_with_tool_use(response) 136 | if tool_message: 137 | return tool_message 138 | 139 | return Message( 140 | content=response.content[0].text, 141 | role="assistant", 142 | tool_calls=None, 143 | refusal=None, 144 | ) 145 | 146 | def convert_response_with_tool_use(self, response): 147 | """Convert Anthropic tool use response to the framework's format.""" 148 | tool_call = next( 149 | (content for content in response.content if content.type == "tool_use"), 150 | None, 151 | ) 152 | 153 | if tool_call: 154 | function = Function( 155 | name=tool_call.name, arguments=json.dumps(tool_call.input) 156 | ) 157 | tool_call_obj = ChatCompletionMessageToolCall( 158 | id=tool_call.id, function=function, type="function" 159 | ) 160 | text_content = next( 161 | ( 162 | content.text 163 | for content in response.content 164 | if content.type == "text" 165 | ), 166 | "", 167 | ) 168 | 169 | return Message( 170 | content=text_content or None, 171 | tool_calls=[tool_call_obj] if tool_call else None, 172 | role="assistant", 173 | 
refusal=None, 174 | ) 175 | return None 176 | 177 | def convert_tool_spec(self, openai_tools): 178 | """Convert OpenAI tool specification to Anthropic format.""" 179 | anthropic_tools = [] 180 | 181 | for tool in openai_tools: 182 | if tool.get("type") != "function": 183 | continue 184 | 185 | function = tool["function"] 186 | anthropic_tool = { 187 | "name": function["name"], 188 | "description": function["description"], 189 | "input_schema": { 190 | "type": "object", 191 | "properties": function["parameters"]["properties"], 192 | "required": function["parameters"].get("required", []), 193 | }, 194 | } 195 | anthropic_tools.append(anthropic_tool) 196 | 197 | return anthropic_tools 198 | 199 | 200 | class AnthropicProvider(Provider): 201 | def __init__(self, **config): 202 | """Initialize the Anthropic provider with the given configuration.""" 203 | self.client = anthropic.Anthropic(**config) 204 | self.converter = AnthropicMessageConverter() 205 | 206 | def chat_completions_create(self, model, messages, **kwargs): 207 | """Create a chat completion using the Anthropic API.""" 208 | kwargs = self._prepare_kwargs(kwargs) 209 | system_message, converted_messages = self.converter.convert_request(messages) 210 | 211 | response = self.client.messages.create( 212 | model=model, system=system_message, messages=converted_messages, **kwargs 213 | ) 214 | return self.converter.convert_response(response) 215 | 216 | def _prepare_kwargs(self, kwargs): 217 | """Prepare kwargs for the API call.""" 218 | kwargs = kwargs.copy() 219 | kwargs.setdefault("max_tokens", DEFAULT_MAX_TOKENS) 220 | 221 | if "tools" in kwargs: 222 | kwargs["tools"] = self.converter.convert_tool_spec(kwargs["tools"]) 223 | 224 | return kwargs 225 | -------------------------------------------------------------------------------- /aisuite/providers/azure_provider.py: -------------------------------------------------------------------------------- 1 | import urllib.request 2 | import json 3 | import os 4 | 5 | from aisuite.provider import Provider 6 | from aisuite.framework import ChatCompletionResponse 7 | from aisuite.framework.message import Message, ChatCompletionMessageToolCall, Function 8 | 9 | # Azure provider is based on the documentation here - 10 | # https://learn.microsoft.com/en-us/azure/machine-learning/reference-model-inference-api?view=azureml-api-2&source=recommendations&tabs=python 11 | # Azure AI Model Inference API is used. 12 | # From the documentation - 13 | # """ 14 | # The Azure AI Model Inference is an API that exposes a common set of capabilities for foundational models 15 | # and that can be used by developers to consume predictions from a diverse set of models in a uniform and consistent way. 16 | # Developers can talk with different models deployed in Azure AI Foundry portal without changing the underlying code they are using. 17 | # 18 | # The Azure AI Model Inference API is available in the following models: 19 | # 20 | # Models deployed to serverless API endpoints: 21 | # Cohere Embed V3 family of models 22 | # Cohere Command R family of models 23 | # Meta Llama 2 chat family of models 24 | # Meta Llama 3 instruct family of models 25 | # Mistral-Small 26 | # Mistral-Large 27 | # Jais family of models 28 | # Jamba family of models 29 | # Phi-3 family of models 30 | # 31 | # Models deployed to managed inference: 32 | # Meta Llama 3 instruct family of models 33 | # Phi-3 family of models 34 | # Mixtral famility of models 35 | # 36 | # The API is compatible with Azure OpenAI model deployments. 
37 | # """ 38 | 39 | 40 | class AzureMessageConverter: 41 | @staticmethod 42 | def convert_request(messages): 43 | """Convert messages to Azure format.""" 44 | transformed_messages = [] 45 | for message in messages: 46 | if isinstance(message, Message): 47 | transformed_messages.append(message.model_dump(mode="json")) 48 | else: 49 | transformed_messages.append(message) 50 | return transformed_messages 51 | 52 | @staticmethod 53 | def convert_response(resp_json) -> ChatCompletionResponse: 54 | """Normalize the response from the Azure API to match OpenAI's response format.""" 55 | completion_response = ChatCompletionResponse() 56 | choice = resp_json["choices"][0] 57 | message = choice["message"] 58 | 59 | # Set basic message content 60 | completion_response.choices[0].message.content = message.get("content") 61 | completion_response.choices[0].message.role = message.get("role", "assistant") 62 | 63 | # Handle tool calls if present 64 | if "tool_calls" in message and message["tool_calls"] is not None: 65 | tool_calls = [] 66 | for tool_call in message["tool_calls"]: 67 | new_tool_call = ChatCompletionMessageToolCall( 68 | id=tool_call["id"], 69 | type=tool_call["type"], 70 | function={ 71 | "name": tool_call["function"]["name"], 72 | "arguments": tool_call["function"]["arguments"], 73 | }, 74 | ) 75 | tool_calls.append(new_tool_call) 76 | completion_response.choices[0].message.tool_calls = tool_calls 77 | 78 | return completion_response 79 | 80 | 81 | class AzureProvider(Provider): 82 | def __init__(self, **config): 83 | self.base_url = config.get("base_url") or os.getenv("AZURE_BASE_URL") 84 | self.api_key = config.get("api_key") or os.getenv("AZURE_API_KEY") 85 | self.api_version = config.get("api_version") or os.getenv("AZURE_API_VERSION") 86 | if not self.api_key: 87 | raise ValueError("For Azure, api_key is required.") 88 | if not self.base_url: 89 | raise ValueError( 90 | "For Azure, base_url is required. 
Check your deployment page for a URL like this - https://..models.ai.azure.com" 91 | ) 92 | self.transformer = AzureMessageConverter() 93 | 94 | def chat_completions_create(self, model, messages, **kwargs): 95 | url = f"{self.base_url}/chat/completions" 96 | 97 | if self.api_version: 98 | url = f"{url}?api-version={self.api_version}" 99 | 100 | # Remove 'stream' from kwargs if present 101 | kwargs.pop("stream", None) 102 | 103 | # Transform messages using converter 104 | transformed_messages = self.transformer.convert_request(messages) 105 | 106 | # Prepare the request payload 107 | data = {"messages": transformed_messages} 108 | 109 | # Add tools if provided 110 | if "tools" in kwargs: 111 | data["tools"] = kwargs["tools"] 112 | kwargs.pop("tools") 113 | 114 | # Add tool_choice if provided 115 | if "tool_choice" in kwargs: 116 | data["tool_choice"] = kwargs["tool_choice"] 117 | kwargs.pop("tool_choice") 118 | 119 | # Add remaining kwargs 120 | data.update(kwargs) 121 | 122 | body = json.dumps(data).encode("utf-8") 123 | headers = {"Content-Type": "application/json", "Authorization": self.api_key} 124 | 125 | try: 126 | req = urllib.request.Request(url, body, headers) 127 | with urllib.request.urlopen(req) as response: 128 | result = response.read() 129 | resp_json = json.loads(result) 130 | return self.transformer.convert_response(resp_json) 131 | 132 | except urllib.error.HTTPError as error: 133 | error_message = f"The request failed with status code: {error.code}\n" 134 | error_message += f"Headers: {error.info()}\n" 135 | error_message += error.read().decode("utf-8", "ignore") 136 | raise Exception(error_message) 137 | -------------------------------------------------------------------------------- /aisuite/providers/cerebras_provider.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cerebras.cloud.sdk as cerebras 3 | from aisuite.provider import Provider, LLMError 4 | from aisuite.providers.message_converter import OpenAICompliantMessageConverter 5 | 6 | 7 | class CerebrasMessageConverter(OpenAICompliantMessageConverter): 8 | """ 9 | Cerebras-specific message converter if needed. 10 | """ 11 | 12 | pass 13 | 14 | 15 | class CerebrasProvider(Provider): 16 | def __init__(self, **config): 17 | self.client = cerebras.Cerebras(**config) 18 | self.transformer = CerebrasMessageConverter() 19 | 20 | def chat_completions_create(self, model, messages, **kwargs): 21 | """ 22 | Makes a request to the Cerebras chat completions endpoint using the official client. 23 | """ 24 | try: 25 | response = self.client.chat.completions.create( 26 | model=model, 27 | messages=messages, 28 | **kwargs, # Pass any additional arguments to the Cerebras API. 29 | ) 30 | return self.transformer.convert_response(response.model_dump()) 31 | 32 | # Re-raise Cerebras API-specific exceptions. 33 | except cerebras.cloud.sdk.PermissionDeniedError as e: 34 | raise 35 | except cerebras.cloud.sdk.AuthenticationError as e: 36 | raise 37 | except cerebras.cloud.sdk.RateLimitError as e: 38 | raise 39 | 40 | # Wrap all other exceptions in LLMError. 
41 | except Exception as e: 42 | raise LLMError(f"An error occurred: {e}") 43 | -------------------------------------------------------------------------------- /aisuite/providers/cohere_provider.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cohere 3 | import json 4 | from aisuite.framework import ChatCompletionResponse 5 | from aisuite.framework.message import Message, ChatCompletionMessageToolCall, Function 6 | from aisuite.provider import Provider, LLMError 7 | 8 | 9 | class CohereMessageConverter: 10 | """ 11 | Cohere-specific message converter 12 | """ 13 | 14 | def convert_request(self, messages): 15 | """Convert framework messages to Cohere format.""" 16 | converted_messages = [] 17 | 18 | for message in messages: 19 | if isinstance(message, dict): 20 | role = message.get("role") 21 | content = message.get("content") 22 | tool_calls = message.get("tool_calls") 23 | tool_plan = message.get("tool_plan") 24 | else: 25 | role = message.role 26 | content = message.content 27 | tool_calls = message.tool_calls 28 | tool_plan = getattr(message, "tool_plan", None) 29 | 30 | # Convert to Cohere's format 31 | if role == "tool": 32 | # Handle tool response messages 33 | converted_message = { 34 | "role": role, 35 | "tool_call_id": ( 36 | message.get("tool_call_id") 37 | if isinstance(message, dict) 38 | else message.tool_call_id 39 | ), 40 | "content": self._convert_tool_content(content), 41 | } 42 | elif role == "assistant" and tool_calls: 43 | # Handle assistant messages with tool calls 44 | converted_message = { 45 | "role": role, 46 | "tool_calls": [ 47 | { 48 | "id": tc.id if not isinstance(tc, dict) else tc["id"], 49 | "function": { 50 | "name": ( 51 | tc.function.name 52 | if not isinstance(tc, dict) 53 | else tc["function"]["name"] 54 | ), 55 | "arguments": ( 56 | tc.function.arguments 57 | if not isinstance(tc, dict) 58 | else tc["function"]["arguments"] 59 | ), 60 | }, 61 | "type": "function", 62 | } 63 | for tc in tool_calls 64 | ], 65 | "tool_plan": tool_plan, 66 | } 67 | if content: 68 | converted_message["content"] = content 69 | else: 70 | # Handle regular messages 71 | converted_message = {"role": role, "content": content} 72 | 73 | converted_messages.append(converted_message) 74 | 75 | return converted_messages 76 | 77 | def _convert_tool_content(self, content): 78 | """Convert tool response content to Cohere's expected format.""" 79 | if isinstance(content, str): 80 | try: 81 | # Try to parse as JSON first 82 | data = json.loads(content) 83 | return [{"type": "document", "document": {"data": json.dumps(data)}}] 84 | except json.JSONDecodeError: 85 | # If not JSON, return as plain text 86 | return content 87 | elif isinstance(content, list): 88 | # If content is already in Cohere's format, return as is 89 | return content 90 | else: 91 | # For other types, convert to string 92 | return str(content) 93 | 94 | @staticmethod 95 | def convert_response(response_data) -> ChatCompletionResponse: 96 | """Convert Cohere's response to our standard format.""" 97 | normalized_response = ChatCompletionResponse() 98 | 99 | # Set usage information 100 | normalized_response.usage = { 101 | "prompt_tokens": response_data.usage.tokens.input_tokens, 102 | "completion_tokens": response_data.usage.tokens.output_tokens, 103 | "total_tokens": response_data.usage.tokens.input_tokens 104 | + response_data.usage.tokens.output_tokens, 105 | } 106 | 107 | # Handle tool calls 108 | if response_data.finish_reason == "TOOL_CALL": 109 | tool_call 
= response_data.message.tool_calls[0] 110 | function = Function( 111 | name=tool_call.function.name, arguments=tool_call.function.arguments 112 | ) 113 | tool_call_obj = ChatCompletionMessageToolCall( 114 | id=tool_call.id, function=function, type="function" 115 | ) 116 | normalized_response.choices[0].message = Message( 117 | content=response_data.message.tool_plan, # Use tool_plan as content 118 | tool_calls=[tool_call_obj], 119 | role="assistant", 120 | refusal=None, 121 | ) 122 | normalized_response.choices[0].finish_reason = "tool_calls" 123 | else: 124 | # Handle regular text response 125 | normalized_response.choices[0].message.content = ( 126 | response_data.message.content[0].text 127 | ) 128 | normalized_response.choices[0].finish_reason = "stop" 129 | 130 | return normalized_response 131 | 132 | 133 | class CohereProvider(Provider): 134 | def __init__(self, **config): 135 | """ 136 | Initialize the Cohere provider with the given configuration. 137 | Pass the entire configuration dictionary to the Cohere client constructor. 138 | """ 139 | # Ensure API key is provided either in config or via environment variable 140 | config.setdefault("api_key", os.getenv("CO_API_KEY")) 141 | if not config["api_key"]: 142 | raise ValueError( 143 | "Cohere API key is missing. Please provide it in the config or set the CO_API_KEY environment variable." 144 | ) 145 | self.client = cohere.ClientV2(**config) 146 | self.transformer = CohereMessageConverter() 147 | 148 | def chat_completions_create(self, model, messages, **kwargs): 149 | """ 150 | Makes a request to Cohere using the official client. 151 | """ 152 | try: 153 | # Transform messages using converter 154 | transformed_messages = self.transformer.convert_request(messages) 155 | 156 | # Make the request to Cohere 157 | response = self.client.chat( 158 | model=model, messages=transformed_messages, **kwargs 159 | ) 160 | 161 | return self.transformer.convert_response(response) 162 | except Exception as e: 163 | raise LLMError(f"An error occurred: {e}") 164 | -------------------------------------------------------------------------------- /aisuite/providers/deepseek_provider.py: -------------------------------------------------------------------------------- 1 | import openai 2 | import os 3 | from aisuite.provider import Provider, LLMError 4 | 5 | 6 | class DeepseekProvider(Provider): 7 | def __init__(self, **config): 8 | """ 9 | Initialize the DeepSeek provider with the given configuration. 10 | Pass the entire configuration dictionary to the OpenAI client constructor. 11 | """ 12 | # Ensure API key is provided either in config or via environment variable 13 | config.setdefault("api_key", os.getenv("DEEPSEEK_API_KEY")) 14 | if not config["api_key"]: 15 | raise ValueError( 16 | "DeepSeek API key is missing. Please provide it in the config or set the OPENAI_API_KEY environment variable." 17 | ) 18 | config["base_url"] = "https://api.deepseek.com" 19 | 20 | # NOTE: We could choose to remove above lines for api_key since OpenAI will automatically 21 | # infer certain values from the environment variables. 22 | # Eg: OPENAI_API_KEY, OPENAI_ORG_ID, OPENAI_PROJECT_ID. Except for OPEN_AI_BASE_URL which has to be the deepseek url 23 | 24 | # Pass the entire config to the OpenAI client constructor 25 | self.client = openai.OpenAI(**config) 26 | 27 | def chat_completions_create(self, model, messages, **kwargs): 28 | # Any exception raised by OpenAI will be returned to the caller. 29 | # Maybe we should catch them and raise a custom LLMError. 
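        # Illustrative sketch only (not active code): if exceptions were wrapped here,
        # it could mirror the pattern used by the other providers in this repo, e.g.:
        #   try:
        #       response = self.client.chat.completions.create(...)
        #   except Exception as e:
        #       raise LLMError(f"An error occurred: {e}")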
30 | response = self.client.chat.completions.create( 31 | model=model, 32 | messages=messages, 33 | **kwargs # Pass any additional arguments to the OpenAI API 34 | ) 35 | return response 36 | -------------------------------------------------------------------------------- /aisuite/providers/fireworks_provider.py: -------------------------------------------------------------------------------- 1 | import os 2 | import httpx 3 | import json 4 | from aisuite.provider import Provider, LLMError 5 | from aisuite.framework import ChatCompletionResponse 6 | from aisuite.framework.message import Message, ChatCompletionMessageToolCall 7 | 8 | 9 | class FireworksMessageConverter: 10 | @staticmethod 11 | def convert_request(messages): 12 | """Convert messages to Fireworks format.""" 13 | transformed_messages = [] 14 | for message in messages: 15 | if isinstance(message, Message): 16 | message_dict = message.model_dump(mode="json") 17 | message_dict.pop("refusal", None) # Remove refusal field if present 18 | transformed_messages.append(message_dict) 19 | else: 20 | transformed_messages.append(message) 21 | return transformed_messages 22 | 23 | @staticmethod 24 | def convert_response(resp_json) -> ChatCompletionResponse: 25 | """Normalize the response from the Fireworks API to match OpenAI's response format.""" 26 | completion_response = ChatCompletionResponse() 27 | choice = resp_json["choices"][0] 28 | message = choice["message"] 29 | 30 | # Set basic message content 31 | completion_response.choices[0].message.content = message.get("content") 32 | completion_response.choices[0].message.role = message.get("role", "assistant") 33 | 34 | # Handle tool calls if present 35 | if "tool_calls" in message and message["tool_calls"] is not None: 36 | tool_calls = [] 37 | for tool_call in message["tool_calls"]: 38 | new_tool_call = ChatCompletionMessageToolCall( 39 | id=tool_call["id"], 40 | type=tool_call["type"], 41 | function={ 42 | "name": tool_call["function"]["name"], 43 | "arguments": tool_call["function"]["arguments"], 44 | }, 45 | ) 46 | tool_calls.append(new_tool_call) 47 | completion_response.choices[0].message.tool_calls = tool_calls 48 | 49 | return completion_response 50 | 51 | 52 | # Models that support tool calls: 53 | # [As of 01/20/2025 from https://docs.fireworks.ai/guides/function-calling] 54 | # Llama 3.1 405B Instruct 55 | # Llama 3.1 70B Instruct 56 | # Qwen 2.5 72B Instruct 57 | # Mixtral MoE 8x22B Instruct 58 | # Firefunction-v2: Latest and most performant model, optimized for complex function calling scenarios (on-demand only) 59 | # Firefunction-v1: Previous generation, Mixtral-based function calling model optimized for fast routing and structured output (on-demand only) 60 | class FireworksProvider(Provider): 61 | """ 62 | Fireworks AI Provider using httpx for direct API calls. 63 | """ 64 | 65 | BASE_URL = "https://api.fireworks.ai/inference/v1/chat/completions" 66 | 67 | def __init__(self, **config): 68 | """ 69 | Initialize the Fireworks provider with the given configuration. 70 | The API key is fetched from the config or environment variables. 71 | """ 72 | self.api_key = config.get("api_key", os.getenv("FIREWORKS_API_KEY")) 73 | if not self.api_key: 74 | raise ValueError( 75 | "Fireworks API key is missing. Please provide it in the config or set the FIREWORKS_API_KEY environment variable." 
76 | ) 77 | 78 | # Optionally set a custom timeout (default to 30s) 79 | self.timeout = config.get("timeout", 30) 80 | self.transformer = FireworksMessageConverter() 81 | 82 | def chat_completions_create(self, model, messages, **kwargs): 83 | """ 84 | Makes a request to the Fireworks AI chat completions endpoint using httpx. 85 | """ 86 | # Remove 'stream' from kwargs if present 87 | kwargs.pop("stream", None) 88 | 89 | # Transform messages using converter 90 | transformed_messages = self.transformer.convert_request(messages) 91 | 92 | # Prepare the request payload 93 | data = { 94 | "model": model, 95 | "messages": transformed_messages, 96 | } 97 | 98 | # Add tools if provided 99 | if "tools" in kwargs: 100 | data["tools"] = kwargs["tools"] 101 | kwargs.pop("tools") 102 | 103 | # Add tool_choice if provided 104 | if "tool_choice" in kwargs: 105 | data["tool_choice"] = kwargs["tool_choice"] 106 | kwargs.pop("tool_choice") 107 | 108 | # Add remaining kwargs 109 | data.update(kwargs) 110 | 111 | headers = { 112 | "Authorization": f"Bearer {self.api_key}", 113 | "Content-Type": "application/json", 114 | } 115 | 116 | try: 117 | # Make the request to Fireworks AI endpoint. 118 | response = httpx.post( 119 | self.BASE_URL, json=data, headers=headers, timeout=self.timeout 120 | ) 121 | response.raise_for_status() 122 | return self.transformer.convert_response(response.json()) 123 | except httpx.HTTPStatusError as error: 124 | error_message = ( 125 | f"The request failed with status code: {error.status_code}\n" 126 | ) 127 | error_message += f"Headers: {error.headers}\n" 128 | error_message += error.response.text 129 | raise LLMError(error_message) 130 | except Exception as e: 131 | raise LLMError(f"An error occurred: {e}") 132 | 133 | def _normalize_response(self, response_data): 134 | """ 135 | Normalize the response to a common format (ChatCompletionResponse). 136 | """ 137 | normalized_response = ChatCompletionResponse() 138 | normalized_response.choices[0].message.content = response_data["choices"][0][ 139 | "message" 140 | ]["content"] 141 | return normalized_response 142 | -------------------------------------------------------------------------------- /aisuite/providers/groq_provider.py: -------------------------------------------------------------------------------- 1 | import os 2 | import groq 3 | from aisuite.provider import Provider, LLMError 4 | from aisuite.providers.message_converter import OpenAICompliantMessageConverter 5 | 6 | # Implementation of Groq provider. 7 | # Groq's message format is same as OpenAI's. 8 | # Tool calling specification is also exactly the same as OpenAI's. 9 | # Links: 10 | # https://console.groq.com/docs/tool-use 11 | # Groq supports tool calling for the following models, as of 16th Nov 2024: 12 | # llama3-groq-70b-8192-tool-use-preview 13 | # llama3-groq-8b-8192-tool-use-preview 14 | # llama-3.1-70b-versatile 15 | # llama-3.1-8b-instant 16 | # llama3-70b-8192 17 | # llama3-8b-8192 18 | # mixtral-8x7b-32768 (parallel tool use not supported) 19 | # gemma-7b-it (parallel tool use not supported) 20 | # gemma2-9b-it (parallel tool use not supported) 21 | 22 | 23 | class GroqMessageConverter(OpenAICompliantMessageConverter): 24 | """ 25 | Groq-specific message converter if needed 26 | """ 27 | 28 | pass 29 | 30 | 31 | class GroqProvider(Provider): 32 | def __init__(self, **config): 33 | """ 34 | Initialize the Groq provider with the given configuration. 35 | Pass the entire configuration dictionary to the Groq client constructor. 
36 | """ 37 | # Ensure API key is provided either in config or via environment variable 38 | self.api_key = config.get("api_key", os.getenv("GROQ_API_KEY")) 39 | if not self.api_key: 40 | raise ValueError( 41 | "Groq API key is missing. Please provide it in the config or set the GROQ_API_KEY environment variable." 42 | ) 43 | config["api_key"] = self.api_key 44 | self.client = groq.Groq(**config) 45 | self.transformer = GroqMessageConverter() 46 | 47 | def chat_completions_create(self, model, messages, **kwargs): 48 | """ 49 | Makes a request to the Groq chat completions endpoint using the official client. 50 | """ 51 | try: 52 | # Transform messages using converter 53 | transformed_messages = self.transformer.convert_request(messages) 54 | 55 | response = self.client.chat.completions.create( 56 | model=model, 57 | messages=transformed_messages, 58 | **kwargs, # Pass any additional arguments to the Groq API 59 | ) 60 | return self.transformer.convert_response(response.model_dump()) 61 | except Exception as e: 62 | raise LLMError(f"An error occurred: {e}") 63 | -------------------------------------------------------------------------------- /aisuite/providers/huggingface_provider.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from huggingface_hub import InferenceClient 4 | from aisuite.provider import Provider, LLMError 5 | from aisuite.framework import ChatCompletionResponse 6 | from aisuite.framework.message import Message 7 | 8 | 9 | class HuggingfaceProvider(Provider): 10 | """ 11 | HuggingFace Provider using the official InferenceClient. 12 | This provider supports calls to HF serverless Inference Endpoints 13 | which use Text Generation Inference (TGI) as the backend. 14 | TGI is OpenAI protocol compliant. 15 | https://huggingface.co/inference-endpoints/ 16 | """ 17 | 18 | def __init__(self, **config): 19 | """ 20 | Initialize the provider with the given configuration. 21 | The token is fetched from the config or environment variables. 22 | """ 23 | # Ensure API key is provided either in config or via environment variable 24 | self.token = config.get("token") or os.getenv("HF_TOKEN") 25 | if not self.token: 26 | raise ValueError( 27 | "Hugging Face token is missing. Please provide it in the config or set the HF_TOKEN environment variable." 28 | ) 29 | 30 | # Initialize the InferenceClient with the specified model and timeout if provided 31 | self.model = config.get("model") 32 | self.timeout = config.get("timeout", 30) 33 | self.client = InferenceClient( 34 | token=self.token, model=self.model, timeout=self.timeout 35 | ) 36 | 37 | def chat_completions_create(self, model, messages, **kwargs): 38 | """ 39 | Makes a request to the Inference API endpoint using InferenceClient. 
40 | """ 41 | # Validate and transform messages 42 | transformed_messages = [] 43 | for message in messages: 44 | if isinstance(message, Message): 45 | transformed_message = self.transform_from_message(message) 46 | elif isinstance(message, dict): 47 | transformed_message = message 48 | else: 49 | raise ValueError(f"Invalid message format: {message}") 50 | 51 | # Ensure 'content' is a non-empty string 52 | if ( 53 | "content" not in transformed_message 54 | or transformed_message["content"] is None 55 | ): 56 | transformed_message["content"] = "" 57 | 58 | transformed_messages.append(transformed_message) 59 | 60 | try: 61 | # Prepare the payload 62 | payload = { 63 | "messages": transformed_messages, 64 | **kwargs, # Include other parameters like temperature, max_tokens, etc. 65 | } 66 | 67 | # Make the API call using the client 68 | response = self.client.chat_completion(model=model, **payload) 69 | 70 | return self._normalize_response(response) 71 | 72 | except Exception as e: 73 | raise LLMError(f"An error occurred: {e}") 74 | 75 | def transform_from_message(self, message: Message): 76 | """Transform framework Message to a format that HuggingFace understands.""" 77 | # Ensure content is a string 78 | content = message.content if message.content is not None else "" 79 | 80 | # Transform the message 81 | transformed_message = { 82 | "role": message.role, 83 | "content": content, 84 | } 85 | 86 | # Include tool_calls if present 87 | if message.tool_calls: 88 | transformed_message["tool_calls"] = [ 89 | { 90 | "id": tool_call.id, 91 | "function": { 92 | "name": tool_call.function.name, 93 | "arguments": tool_call.function.arguments, 94 | }, 95 | "type": tool_call.type, 96 | } 97 | for tool_call in message.tool_calls 98 | ] 99 | 100 | return transformed_message 101 | 102 | def transform_to_message(self, message_dict: dict): 103 | """Transform HuggingFace message (dict) to a format that the framework Message understands.""" 104 | # Ensure required fields are present 105 | message_dict.setdefault("content", "") # Set empty string if content is missing 106 | message_dict.setdefault("refusal", None) # Set None if refusal is missing 107 | message_dict.setdefault("tool_calls", None) # Set None if tool_calls is missing 108 | 109 | # Handle tool calls if present and not None 110 | if message_dict.get("tool_calls"): 111 | for tool_call in message_dict["tool_calls"]: 112 | if "function" in tool_call: 113 | # Ensure function arguments are stringified 114 | if isinstance(tool_call["function"].get("arguments"), dict): 115 | tool_call["function"]["arguments"] = json.dumps( 116 | tool_call["function"]["arguments"] 117 | ) 118 | 119 | return Message(**message_dict) 120 | 121 | def _normalize_response(self, response_data): 122 | """ 123 | Normalize the response to a common format (ChatCompletionResponse). 
124 | """ 125 | normalized_response = ChatCompletionResponse() 126 | message_data = response_data["choices"][0]["message"] 127 | normalized_response.choices[0].message = self.transform_to_message(message_data) 128 | return normalized_response 129 | -------------------------------------------------------------------------------- /aisuite/providers/message_converter.py: -------------------------------------------------------------------------------- 1 | from aisuite.framework import ChatCompletionResponse 2 | from aisuite.framework.message import Message, ChatCompletionMessageToolCall 3 | 4 | 5 | class OpenAICompliantMessageConverter: 6 | """ 7 | Base class for message converters that are compatible with OpenAI's API. 8 | """ 9 | 10 | # Class variable that derived classes can override 11 | tool_results_as_strings = False 12 | 13 | @staticmethod 14 | def convert_request(messages): 15 | """Convert messages to OpenAI-compatible format.""" 16 | transformed_messages = [] 17 | for message in messages: 18 | tmsg = None 19 | if isinstance(message, Message): 20 | message_dict = message.model_dump(mode="json") 21 | message_dict.pop("refusal", None) # Remove refusal field if present 22 | tmsg = message_dict 23 | else: 24 | tmsg = message 25 | # Check if tmsg is a dict, otherwise get role attribute 26 | role = tmsg["role"] if isinstance(tmsg, dict) else tmsg.role 27 | if role == "tool": 28 | if OpenAICompliantMessageConverter.tool_results_as_strings: 29 | # Handle both dict and object cases for content 30 | if isinstance(tmsg, dict): 31 | tmsg["content"] = str(tmsg["content"]) 32 | else: 33 | tmsg.content = str(tmsg.content) 34 | 35 | transformed_messages.append(tmsg) 36 | return transformed_messages 37 | 38 | @staticmethod 39 | def convert_response(response_data) -> ChatCompletionResponse: 40 | """Normalize the response to match OpenAI's response format.""" 41 | completion_response = ChatCompletionResponse() 42 | choice = response_data["choices"][0] 43 | message = choice["message"] 44 | 45 | # Set basic message content 46 | completion_response.choices[0].message.content = message["content"] 47 | completion_response.choices[0].message.role = message.get("role", "assistant") 48 | 49 | # Handle tool calls if present 50 | if "tool_calls" in message and message["tool_calls"] is not None: 51 | tool_calls = [] 52 | for tool_call in message["tool_calls"]: 53 | tool_calls.append( 54 | ChatCompletionMessageToolCall( 55 | id=tool_call.get("id"), 56 | type="function", # Always set to "function" as it's the only valid value 57 | function=tool_call.get("function"), 58 | ) 59 | ) 60 | completion_response.choices[0].message.tool_calls = tool_calls 61 | 62 | return completion_response 63 | -------------------------------------------------------------------------------- /aisuite/providers/mistral_provider.py: -------------------------------------------------------------------------------- 1 | import os 2 | from mistralai import Mistral 3 | from aisuite.framework.message import Message 4 | from aisuite.framework import ChatCompletionResponse 5 | from aisuite.provider import Provider, LLMError 6 | from aisuite.providers.message_converter import OpenAICompliantMessageConverter 7 | 8 | 9 | # Implementation of Mistral provider. 10 | # Mistral's message format is same as OpenAI's. Just different class names, but fully cross-compatible. 
11 | # Links: 12 | # https://docs.mistral.ai/capabilities/function_calling/ 13 | 14 | 15 | class MistralMessageConverter(OpenAICompliantMessageConverter): 16 | """ 17 | Mistral-specific message converter 18 | """ 19 | 20 | @staticmethod 21 | def convert_response(response_data) -> ChatCompletionResponse: 22 | """Convert Mistral's response to our standard format.""" 23 | # Convert Mistral's response object to dict format 24 | response_dict = response_data.model_dump() 25 | return super(MistralMessageConverter, MistralMessageConverter).convert_response( 26 | response_dict 27 | ) 28 | 29 | 30 | # Function calling is available for the following models: 31 | # [As of 01/19/2025 from https://docs.mistral.ai/capabilities/function_calling/] 32 | # Mistral Large 33 | # Mistral Small 34 | # Codestral 22B 35 | # Ministral 8B 36 | # Ministral 3B 37 | # Pixtral 12B 38 | # Mixtral 8x22B 39 | # Mistral Nemo 40 | class MistralProvider(Provider): 41 | """ 42 | Mistral AI Provider using the official Mistral client. 43 | """ 44 | 45 | def __init__(self, **config): 46 | """ 47 | Initialize the Mistral provider with the given configuration. 48 | Pass the entire configuration dictionary to the Mistral client constructor. 49 | """ 50 | # Ensure API key is provided either in config or via environment variable 51 | config.setdefault("api_key", os.getenv("MISTRAL_API_KEY")) 52 | if not config["api_key"]: 53 | raise ValueError( 54 | "Mistral API key is missing. Please provide it in the config or set the MISTRAL_API_KEY environment variable." 55 | ) 56 | self.client = Mistral(**config) 57 | self.transformer = MistralMessageConverter() 58 | 59 | def chat_completions_create(self, model, messages, **kwargs): 60 | """ 61 | Makes a request to Mistral using the official client. 62 | """ 63 | try: 64 | # Transform messages using converter 65 | transformed_messages = self.transformer.convert_request(messages) 66 | 67 | # Make the request to Mistral 68 | response = self.client.chat.complete( 69 | model=model, messages=transformed_messages, **kwargs 70 | ) 71 | 72 | return self.transformer.convert_response(response) 73 | except Exception as e: 74 | raise LLMError(f"An error occurred: {e}") 75 | -------------------------------------------------------------------------------- /aisuite/providers/nebius_provider.py: -------------------------------------------------------------------------------- 1 | import os 2 | from aisuite.provider import Provider 3 | from openai import Client 4 | 5 | 6 | BASE_URL = "https://api.studio.nebius.ai/v1" 7 | 8 | 9 | # TODO(rohitcp): This needs to be added to our internal testbed. Tool calling not tested. 10 | class NebiusProvider(Provider): 11 | def __init__(self, **config): 12 | """ 13 | Initialize the Nebius AI Studio provider with the given configuration. 14 | Pass the entire configuration dictionary to the OpenAI client constructor. 15 | """ 16 | # Ensure API key is provided either in config or via environment variable 17 | config.setdefault("api_key", os.getenv("NEBIUS_API_KEY")) 18 | if not config["api_key"]: 19 | raise ValueError( 20 | "Nebius AI Studio API key is missing. Please provide it in the config or set the NEBIUS_API_KEY environment variable. 
You can get your API key at https://studio.nebius.ai/settings/api-keys" 21 | ) 22 | 23 | config["base_url"] = BASE_URL 24 | # Pass the entire config to the OpenAI client constructor 25 | self.client = Client(**config) 26 | 27 | def chat_completions_create(self, model, messages, **kwargs): 28 | return self.client.chat.completions.create( 29 | model=model, 30 | messages=messages, 31 | **kwargs # Pass any additional arguments to the Nebius API 32 | ) 33 | -------------------------------------------------------------------------------- /aisuite/providers/ollama_provider.py: -------------------------------------------------------------------------------- 1 | import os 2 | import httpx 3 | from aisuite.provider import Provider, LLMError 4 | from aisuite.framework import ChatCompletionResponse 5 | 6 | 7 | class OllamaProvider(Provider): 8 | """ 9 | Ollama Provider that makes HTTP calls instead of using SDK. 10 | It uses the /api/chat endpoint. 11 | Read more here - https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion 12 | If OLLAMA_API_URL is not set and not passed in config, then it will default to "http://localhost:11434" 13 | """ 14 | 15 | _CHAT_COMPLETION_ENDPOINT = "/api/chat" 16 | _CONNECT_ERROR_MESSAGE = "Ollama is likely not running. Start Ollama by running `ollama serve` on your host." 17 | 18 | def __init__(self, **config): 19 | """ 20 | Initialize the Ollama provider with the given configuration. 21 | """ 22 | self.url = config.get("api_url") or os.getenv( 23 | "OLLAMA_API_URL", "http://localhost:11434" 24 | ) 25 | 26 | # Optionally set a custom timeout (default to 30s) 27 | self.timeout = config.get("timeout", 30) 28 | 29 | def chat_completions_create(self, model, messages, **kwargs): 30 | """ 31 | Makes a request to the chat completions endpoint using httpx. 32 | """ 33 | kwargs["stream"] = False 34 | data = { 35 | "model": model, 36 | "messages": messages, 37 | **kwargs, # Pass any additional arguments to the API 38 | } 39 | 40 | try: 41 | response = httpx.post( 42 | self.url.rstrip("/") + self._CHAT_COMPLETION_ENDPOINT, 43 | json=data, 44 | timeout=self.timeout, 45 | ) 46 | response.raise_for_status() 47 | except httpx.ConnectError: # Handle connection errors 48 | raise LLMError(f"Connection failed: {self._CONNECT_ERROR_MESSAGE}") 49 | except httpx.HTTPStatusError as http_err: 50 | raise LLMError(f"Ollama request failed: {http_err}") 51 | except Exception as e: 52 | raise LLMError(f"An error occurred: {e}") 53 | 54 | # Return the normalized response 55 | return self._normalize_response(response.json()) 56 | 57 | def _normalize_response(self, response_data): 58 | """ 59 | Normalize the API response to a common format (ChatCompletionResponse). 60 | """ 61 | normalized_response = ChatCompletionResponse() 62 | normalized_response.choices[0].message.content = response_data["message"][ 63 | "content" 64 | ] 65 | return normalized_response 66 | -------------------------------------------------------------------------------- /aisuite/providers/openai_provider.py: -------------------------------------------------------------------------------- 1 | import openai 2 | import os 3 | from aisuite.provider import Provider, LLMError 4 | from aisuite.providers.message_converter import OpenAICompliantMessageConverter 5 | 6 | 7 | class OpenaiProvider(Provider): 8 | def __init__(self, **config): 9 | """ 10 | Initialize the OpenAI provider with the given configuration. 11 | Pass the entire configuration dictionary to the OpenAI client constructor. 
12 | """ 13 | # Ensure API key is provided either in config or via environment variable 14 | config.setdefault("api_key", os.getenv("OPENAI_API_KEY")) 15 | if not config["api_key"]: 16 | raise ValueError( 17 | "OpenAI API key is missing. Please provide it in the config or set the OPENAI_API_KEY environment variable." 18 | ) 19 | 20 | # NOTE: We could choose to remove above lines for api_key since OpenAI will automatically 21 | # infer certain values from the environment variables. 22 | # Eg: OPENAI_API_KEY, OPENAI_ORG_ID, OPENAI_PROJECT_ID, OPENAI_BASE_URL, etc. 23 | 24 | # Pass the entire config to the OpenAI client constructor 25 | self.client = openai.OpenAI(**config) 26 | self.transformer = OpenAICompliantMessageConverter() 27 | 28 | def chat_completions_create(self, model, messages, **kwargs): 29 | # Any exception raised by OpenAI will be returned to the caller. 30 | # Maybe we should catch them and raise a custom LLMError. 31 | try: 32 | transformed_messages = self.transformer.convert_request(messages) 33 | response = self.client.chat.completions.create( 34 | model=model, 35 | messages=transformed_messages, 36 | **kwargs, # Pass any additional arguments to the OpenAI API 37 | ) 38 | return response 39 | except Exception as e: 40 | raise LLMError(f"An error occurred: {e}") 41 | -------------------------------------------------------------------------------- /aisuite/providers/sambanova_provider.py: -------------------------------------------------------------------------------- 1 | import os 2 | from aisuite.provider import Provider, LLMError 3 | from openai import OpenAI 4 | from aisuite.providers.message_converter import OpenAICompliantMessageConverter 5 | 6 | 7 | class SambanovaMessageConverter(OpenAICompliantMessageConverter): 8 | """ 9 | SambaNova-specific message converter. 10 | """ 11 | 12 | pass 13 | 14 | 15 | class SambanovaProvider(Provider): 16 | """ 17 | SambaNova Provider using OpenAI client for API calls. 18 | """ 19 | 20 | def __init__(self, **config): 21 | """ 22 | Initialize the SambaNova provider with the given configuration. 23 | Pass the entire configuration dictionary to the OpenAI client constructor. 24 | """ 25 | # Ensure API key is provided either in config or via environment variable 26 | self.api_key = config.get("api_key", os.getenv("SAMBANOVA_API_KEY")) 27 | if not self.api_key: 28 | raise ValueError( 29 | "Sambanova API key is missing. Please provide it in the config or set the SAMBANOVA_API_KEY environment variable." 30 | ) 31 | 32 | config["api_key"] = self.api_key 33 | config["base_url"] = "https://api.sambanova.ai/v1/" 34 | # Pass the entire config to the OpenAI client constructor 35 | self.client = OpenAI(**config) 36 | self.transformer = SambanovaMessageConverter() 37 | 38 | def chat_completions_create(self, model, messages, **kwargs): 39 | """ 40 | Makes a request to the SambaNova chat completions endpoint using the OpenAI client. 
41 | """ 42 | try: 43 | # Transform messages using converter 44 | transformed_messages = self.transformer.convert_request(messages) 45 | 46 | response = self.client.chat.completions.create( 47 | model=model, 48 | messages=transformed_messages, 49 | **kwargs, # Pass any additional arguments to the Sambanova API 50 | ) 51 | return self.transformer.convert_response(response.model_dump()) 52 | except Exception as e: 53 | raise LLMError(f"An error occurred: {e}") 54 | -------------------------------------------------------------------------------- /aisuite/providers/together_provider.py: -------------------------------------------------------------------------------- 1 | import os 2 | import httpx 3 | from aisuite.provider import Provider, LLMError 4 | from aisuite.providers.message_converter import OpenAICompliantMessageConverter 5 | 6 | 7 | class TogetherMessageConverter(OpenAICompliantMessageConverter): 8 | """ 9 | Together-specific message converter if needed 10 | """ 11 | 12 | pass 13 | 14 | 15 | class TogetherProvider(Provider): 16 | """ 17 | Together AI Provider using httpx for direct API calls. 18 | """ 19 | 20 | BASE_URL = "https://api.together.xyz/v1/chat/completions" 21 | 22 | def __init__(self, **config): 23 | """ 24 | Initialize the Together provider with the given configuration. 25 | The API key is fetched from the config or environment variables. 26 | """ 27 | self.api_key = config.get("api_key", os.getenv("TOGETHER_API_KEY")) 28 | if not self.api_key: 29 | raise ValueError( 30 | "Together API key is missing. Please provide it in the config or set the TOGETHER_API_KEY environment variable." 31 | ) 32 | 33 | # Optionally set a custom timeout (default to 30s) 34 | self.timeout = config.get("timeout", 30) 35 | self.transformer = TogetherMessageConverter() 36 | 37 | def chat_completions_create(self, model, messages, **kwargs): 38 | """ 39 | Makes a request to the Together AI chat completions endpoint using httpx. 40 | """ 41 | # Transform messages using converter 42 | transformed_messages = self.transformer.convert_request(messages) 43 | 44 | headers = { 45 | "Authorization": f"Bearer {self.api_key}", 46 | "Content-Type": "application/json", 47 | } 48 | 49 | data = { 50 | "model": model, 51 | "messages": transformed_messages, 52 | **kwargs, # Pass any additional arguments to the API 53 | } 54 | 55 | try: 56 | # Make the request to Together AI endpoint. 
57 | response = httpx.post( 58 | self.BASE_URL, json=data, headers=headers, timeout=self.timeout 59 | ) 60 | response.raise_for_status() 61 | return self.transformer.convert_response(response.json()) 62 | except httpx.HTTPStatusError as http_err: 63 | raise LLMError(f"Together AI request failed: {http_err}") 64 | except Exception as e: 65 | raise LLMError(f"An error occurred: {e}") 66 | -------------------------------------------------------------------------------- /aisuite/providers/watsonx_provider.py: -------------------------------------------------------------------------------- 1 | from aisuite.provider import Provider 2 | import os 3 | from ibm_watsonx_ai import Credentials 4 | from ibm_watsonx_ai.foundation_models import ModelInference 5 | from aisuite.framework import ChatCompletionResponse 6 | 7 | 8 | class WatsonxProvider(Provider): 9 | def __init__(self, **config): 10 | self.service_url = config.get("service_url") or os.getenv("WATSONX_SERVICE_URL") 11 | self.api_key = config.get("api_key") or os.getenv("WATSONX_API_KEY") 12 | self.project_id = config.get("project_id") or os.getenv("WATSONX_PROJECT_ID") 13 | 14 | if not self.service_url or not self.api_key or not self.project_id: 15 | raise EnvironmentError( 16 | "Missing one or more required WatsonX environment variables: " 17 | "WATSONX_SERVICE_URL, WATSONX_API_KEY, WATSONX_PROJECT_ID. " 18 | "Please refer to the setup guide: /guides/watsonx.md." 19 | ) 20 | 21 | def chat_completions_create(self, model, messages, **kwargs): 22 | model = ModelInference( 23 | model_id=model, 24 | credentials=Credentials( 25 | api_key=self.api_key, 26 | url=self.service_url, 27 | ), 28 | project_id=self.project_id, 29 | ) 30 | 31 | res = model.chat(messages=messages, params=kwargs) 32 | return self.normalize_response(res) 33 | 34 | def normalize_response(self, response): 35 | openai_response = ChatCompletionResponse() 36 | openai_response.choices[0].message.content = response["choices"][0]["message"][ 37 | "content" 38 | ] 39 | return openai_response 40 | -------------------------------------------------------------------------------- /aisuite/providers/xai_provider.py: -------------------------------------------------------------------------------- 1 | import os 2 | import httpx 3 | from aisuite.provider import Provider, LLMError 4 | from aisuite.framework import ChatCompletionResponse 5 | from aisuite.providers.message_converter import OpenAICompliantMessageConverter 6 | 7 | 8 | class XaiMessageConverter(OpenAICompliantMessageConverter): 9 | """ 10 | xAI-specific message converter if needed 11 | """ 12 | 13 | pass 14 | 15 | 16 | class XaiProvider(Provider): 17 | """ 18 | xAI Provider using httpx for direct API calls. 19 | """ 20 | 21 | BASE_URL = "https://api.x.ai/v1/chat/completions" 22 | 23 | def __init__(self, **config): 24 | """ 25 | Initialize the xAI provider with the given configuration. 26 | The API key is fetched from the config or environment variables. 27 | """ 28 | self.api_key = config.get("api_key", os.getenv("XAI_API_KEY")) 29 | if not self.api_key: 30 | raise ValueError( 31 | "xAI API key is missing. Please provide it in the config or set the XAI_API_KEY environment variable." 32 | ) 33 | 34 | # Optionally set a custom timeout (default to 30s) 35 | self.timeout = config.get("timeout", 30) 36 | self.transformer = XaiMessageConverter() 37 | 38 | def chat_completions_create(self, model, messages, **kwargs): 39 | """ 40 | Makes a request to the xAI chat completions endpoint using httpx. 
41 | """ 42 | # Transform messages using converter 43 | transformed_messages = self.transformer.convert_request(messages) 44 | 45 | headers = { 46 | "Authorization": f"Bearer {self.api_key}", 47 | "Content-Type": "application/json", 48 | } 49 | 50 | data = { 51 | "model": model, 52 | "messages": transformed_messages, 53 | **kwargs, # Pass any additional arguments to the API 54 | } 55 | 56 | try: 57 | # Make the request to xAI endpoint. 58 | response = httpx.post( 59 | self.BASE_URL, json=data, headers=headers, timeout=self.timeout 60 | ) 61 | response.raise_for_status() 62 | return self.transformer.convert_response(response.json()) 63 | except httpx.HTTPStatusError as http_err: 64 | raise LLMError(f"xAI request failed: {http_err}") 65 | except Exception as e: 66 | raise LLMError(f"An error occurred: {e}") 67 | -------------------------------------------------------------------------------- /examples/QnA_with_pdf.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "#!pip install PyMuPDF requests" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 2, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import sys\n", 19 | "from dotenv import load_dotenv, find_dotenv\n", 20 | "\n", 21 | "sys.path.append('../aisuite')" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 3, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "import aisuite as ai" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 4, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "import os\n", 40 | "def configure_environment(additional_env_vars=None):\n", 41 | " \"\"\"\n", 42 | " Load environment variables from .env file and apply any additional variables.\n", 43 | " :param additional_env_vars: A dictionary of additional environment variables to apply.\n", 44 | " \"\"\"\n", 45 | " # Load from .env file if available\n", 46 | " load_dotenv(find_dotenv())\n", 47 | "\n", 48 | " # Apply additional environment variables\n", 49 | " if additional_env_vars:\n", 50 | " for key, value in additional_env_vars.items():\n", 51 | " os.environ[key] = value\n", 52 | "\n", 53 | "# Define additional API keys and credentials\n", 54 | "additional_keys = {}\n", 55 | "\n", 56 | "# Configure environment\n", 57 | "configure_environment(additional_env_vars=additional_keys)" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 5, 63 | "metadata": {}, 64 | "outputs": [ 65 | { 66 | "name": "stdout", 67 | "output_type": "stream", 68 | "text": [ 69 | "Downloaded and extracted text from pdf.\n" 70 | ] 71 | } 72 | ], 73 | "source": [ 74 | "import requests\n", 75 | "import fitz\n", 76 | "from io import BytesIO\n", 77 | "\n", 78 | "# Link to paper in pdf format on the cost of avocados.\n", 79 | "pdf_path = \"https://arxiv.org/pdf/2104.04649\"\n", 80 | "pdf_text = \"\"\n", 81 | "# Download PDF and load it into memory\n", 82 | "response = requests.get(pdf_path)\n", 83 | "if response.status_code == 200:\n", 84 | " pdf_data = BytesIO(response.content) # Load PDF data into BytesIO\n", 85 | " # Open PDF from memory using fitz\n", 86 | " with fitz.open(stream=pdf_data, filetype=\"pdf\") as pdf:\n", 87 | " text = \"\"\n", 88 | " for page_num in range(pdf.page_count):\n", 89 | " page = pdf[page_num]\n", 90 | " pdf_text += page.get_text(\"text\") # Extract text\n", 91 | " pdf_text += 
\"\\n\" + \"=\"*50 + \"\\n\" # Separator for each page\n", 92 | " print(\"Downloaded and extracted text from pdf.\")\n", 93 | "else:\n", 94 | " print(f\"Failed to download PDF: {response.status_code}\")\n", 95 | "\n", 96 | "question = \"Is the price of organic avocados higher than non-organic avocados? What has been the trend?\"" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": 6, 102 | "metadata": {}, 103 | "outputs": [], 104 | "source": [ 105 | "client = ai.Client()\n", 106 | "messages = [\n", 107 | " {\"role\": \"system\", \"content\": \"You are a helpful assistant. Answer the question only based on the below text.\"},\n", 108 | " {\"role\": \"user\", \"content\": f\"Answer the question based on the following text:\\n\\n{pdf_text}\\n\\nQuestion: {question}\\n\"},\n", 109 | "]" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": 7, 115 | "metadata": {}, 116 | "outputs": [ 117 | { 118 | "name": "stdout", 119 | "output_type": "stream", 120 | "text": [ 121 | "Based on the information provided in the text, yes, the price of organic avocados is consistently higher than conventional (non-organic) avocados. Specifically:\n", 122 | "\n", 123 | "1. Figure 2 shows a bar chart comparing average prices of conventional and organic avocados from 2015-2020. The text states that \"the average price of organic avocados is generally always higher than conventional avocados.\"\n", 124 | "\n", 125 | "2. Figure 3, a pie chart, illustrates that \"Nearly 58% of organic avocado sales averaged $1.80 per avocado and roughly 42% of conventional avocados averaged $1.30 per avocado.\"\n", 126 | "\n", 127 | "3. In the conclusion section, the text explicitly states: \"The price of organic avocados is on average 35-40% higher than conventional avocados.\"\n", 128 | "\n", 129 | "Regarding the trend, while the text doesn't provide detailed information on price trends over time, Figure 2 shows the average prices for both organic and conventional avocados from 2015-2020, indicating that this price difference has been consistent over that period.\n" 130 | ] 131 | } 132 | ], 133 | "source": [ 134 | "anthropic_claude_3_opus = \"anthropic:claude-3-5-sonnet-20240620\"\n", 135 | "response = client.chat.completions.create(model=anthropic_claude_3_opus, messages=messages)\n", 136 | "print(response.choices[0].message.content)" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": 7, 142 | "metadata": {}, 143 | "outputs": [ 144 | { 145 | "name": "stdout", 146 | "output_type": "stream", 147 | "text": [ 148 | "Yes, according to the analysis presented in the text, the price of organic avocados is higher\n" 149 | ] 150 | } 151 | ], 152 | "source": [ 153 | "\n", 154 | "hf_model = \"huggingface:mistralai/Mistral-7B-Instruct-v0.3\"\n", 155 | "response = client.chat.completions.create(model=hf_model, messages=messages)\n", 156 | "print(response.choices[0].message.content)" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": 21, 162 | "metadata": {}, 163 | "outputs": [ 164 | { 165 | "name": "stdout", 166 | "output_type": "stream", 167 | "text": [ 168 | "According to the text, yes, the price of organic avocados is on average 35-40% higher than conventional avocados.\n", 169 | "\n", 170 | "As for the trend, it can be observed that there is a steady growth in sales volume year after year for both conventional and organic avocados.\n", 171 | "\n", 172 | "However, in terms of price, the average price of organic avocados has been consistently higher 
than conventional avocados over the years. This can also be seen in Figure 2, which shows that the average price of organic avocados is generally always higher than conventional avocados.\n" 173 | ] 174 | } 175 | ], 176 | "source": [ 177 | "fireworks_model = \"fireworks:accounts/fireworks/models/llama-v3p2-3b-instruct\"\n", 178 | "response = client.chat.completions.create(model=fireworks_model, messages=messages, temperature=0.75, presence_penalty=0.5, frequency_penalty=0.5)\n", 179 | "print(response.choices[0].message.content)" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": 8, 185 | "metadata": {}, 186 | "outputs": [ 187 | { 188 | "name": "stdout", 189 | "output_type": "stream", 190 | "text": [ 191 | "Yes, the price of organic avocados is higher than non-organic avocados. According to the text, the average price of organic avocados is generally 35-40% higher than conventional avocados.\n" 192 | ] 193 | } 194 | ], 195 | "source": [ 196 | "nebius_model = \"nebius:meta-llama/Meta-Llama-3.1-8B-Instruct-fast\"\n", 197 | "response = client.chat.completions.create(model=nebius_model, messages=messages, top_p=0.01)\n", 198 | "print(response.choices[0].message.content)" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": null, 204 | "metadata": {}, 205 | "outputs": [], 206 | "source": [] 207 | } 208 | ], 209 | "metadata": { 210 | "kernelspec": { 211 | "display_name": "Python 3 (ipykernel)", 212 | "language": "python", 213 | "name": "python3" 214 | }, 215 | "language_info": { 216 | "codemirror_mode": { 217 | "name": "ipython", 218 | "version": 3 219 | }, 220 | "file_extension": ".py", 221 | "mimetype": "text/x-python", 222 | "name": "python", 223 | "nbconvert_exporter": "python", 224 | "pygments_lexer": "ipython3", 225 | "version": "3.12.6" 226 | } 227 | }, 228 | "nbformat": 4, 229 | "nbformat_minor": 4 230 | } 231 | -------------------------------------------------------------------------------- /examples/aisuite_tool_abstraction.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 8, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import json\n", 10 | "import sys\n", 11 | "from dotenv import load_dotenv, find_dotenv\n", 12 | "import os\n", 13 | "\n", 14 | "sys.path.append('../../aisuite')\n", 15 | "# Load from .env file if available\n", 16 | "load_dotenv(find_dotenv())\n", 17 | "os.environ['ALLOW_MULTI_TURN'] = 'true'" 18 | ] 19 | }, 20 | { 21 | "cell_type": "markdown", 22 | "metadata": {}, 23 | "source": [ 24 | "### Define the functions" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 9, 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "# Mock tool functions.\n", 34 | "def get_current_temperature(location: str, unit: str):\n", 35 | " \"\"\"This is a short description of what the function does.\n", 36 | "\n", 37 | " This is a longer description that can span\n", 38 | " multiple lines and provide more details.\n", 39 | "\n", 40 | " Args:\n", 41 | " param1: Description of param1\n", 42 | " param2: Description of param2\n", 43 | " \"\"\"\n", 44 | " return \"70\"\n", 45 | "\n", 46 | "def is_it_raining(location: str):\n", 47 | " # Simulate fetching rain probability\n", 48 | " return \"yes\"" 49 | ] 50 | }, 51 | { 52 | "cell_type": "markdown", 53 | "metadata": {}, 54 | "source": [ 55 | "### Call the model with tools" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 10, 
61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "from aisuite import Client\n", 65 | "\n", 66 | "client = Client()\n", 67 | "messages = [{\n", 68 | " \"role\": \"user\",\n", 69 | " \"content\": \"Can you plan a picnic for today afternoon in San Francisco? Check the temperature and if its raining.\"}]" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 11, 75 | "metadata": {}, 76 | "outputs": [ 77 | { 78 | "name": "stdout", 79 | "output_type": "stream", 80 | "text": [ 81 | "--------- response from LLM ---------\n", 82 | "ChatCompletion(id='chatcmpl-AvuR3w6M83nWHL9sIO23pgvD0E5PF', choices=[Choice(finish_reason='tool_calls', index=0, logprobs=None, message=ChatCompletionMessage(content=None, refusal=None, role='assistant', audio=None, function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_PVhrHTUQ2qY7FDyPkCT54phU', function=Function(arguments='{\\n\"location\": \"San Francisco\",\\n\"unit\": \"Fahrenheit\"\\n}', name='get_current_temperature'), type='function')]))], created=1738364997, model='gpt-4-0613', object='chat.completion', service_tier='default', system_fingerprint=None, usage=CompletionUsage(completion_tokens=24, prompt_tokens=110, total_tokens=134, completion_tokens_details=CompletionTokensDetails(accepted_prediction_tokens=0, audio_tokens=0, reasoning_tokens=0, rejected_prediction_tokens=0), prompt_tokens_details=PromptTokensDetails(audio_tokens=0, cached_tokens=0)))\n", 83 | "Executing tool: get_current_temperature\n", 84 | "--------- tool_message to send to LLM ---------\n", 85 | "[{'role': 'tool', 'name': 'get_current_temperature', 'content': '\"70\"', 'tool_call_id': 'call_PVhrHTUQ2qY7FDyPkCT54phU'}]\n", 86 | "--------- response from LLM ---------\n", 87 | "ChatCompletion(id='chatcmpl-AvuR4hpiCgXFxRMNLg1Scv1UD2JlR', choices=[Choice(finish_reason='tool_calls', index=0, logprobs=None, message=ChatCompletionMessage(content=None, refusal=None, role='assistant', audio=None, function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_Fvy3JUhIbVqzV0Nb0QfSnWQC', function=Function(arguments='{\\n\"location\": \"San Francisco\"\\n}', name='is_it_raining'), type='function')]))], created=1738364998, model='gpt-4-0613', object='chat.completion', service_tier='default', system_fingerprint=None, usage=CompletionUsage(completion_tokens=18, prompt_tokens=145, total_tokens=163, completion_tokens_details=CompletionTokensDetails(accepted_prediction_tokens=0, audio_tokens=0, reasoning_tokens=0, rejected_prediction_tokens=0), prompt_tokens_details=PromptTokensDetails(audio_tokens=0, cached_tokens=0)))\n", 88 | "Executing tool: is_it_raining\n", 89 | "--------- tool_message to send to LLM ---------\n", 90 | "[{'role': 'tool', 'name': 'is_it_raining', 'content': '\"yes\"', 'tool_call_id': 'call_Fvy3JUhIbVqzV0Nb0QfSnWQC'}]\n", 91 | "--------- response from LLM ---------\n", 92 | "ChatCompletion(id='chatcmpl-AvuR7puWSJLEpCNIZLLkF40gYJp3c', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content=\"I'm sorry, it seems like it will be raining this afternoon in San Francisco. You might want to plan your picnic for another day. 
Also the temperature is forecasted to be around 70 degrees Fahrenheit.\", refusal=None, role='assistant', audio=None, function_call=None, tool_calls=None))], created=1738365001, model='gpt-4-0613', object='chat.completion', service_tier='default', system_fingerprint=None, usage=CompletionUsage(completion_tokens=44, prompt_tokens=175, total_tokens=219, completion_tokens_details=CompletionTokensDetails(accepted_prediction_tokens=0, audio_tokens=0, reasoning_tokens=0, rejected_prediction_tokens=0), prompt_tokens_details=PromptTokensDetails(audio_tokens=0, cached_tokens=0)))\n", 93 | "I'm sorry, it seems like it will be raining this afternoon in San Francisco. You might want to plan your picnic for another day. Also the temperature is forecasted to be around 70 degrees Fahrenheit.\n" 94 | ] 95 | } 96 | ], 97 | "source": [ 98 | "response = client.chat.completions.create(\n", 99 | " model=\"openai:gpt-4\", messages=messages, tools=[get_current_temperature, is_it_raining], max_turns=4)\n", 100 | "print(response.choices[0].message.content)" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": null, 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [ 109 | "response = client.chat.completions.create(\n", 110 | " model=\"anthropic:claude-3-5-sonnet-20241022\", messages=messages, tools=[get_current_temperature, is_it_raining], max_turns=4)\n", 111 | "print(response.choices[0].message.content)" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": null, 117 | "metadata": {}, 118 | "outputs": [], 119 | "source": [ 120 | "print(response)" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": 12, 126 | "metadata": {}, 127 | "outputs": [ 128 | { 129 | "name": "stdout", 130 | "output_type": "stream", 131 | "text": [ 132 | "[ChatCompletionMessage(content=None, refusal=None, role='assistant', audio=None, function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_PVhrHTUQ2qY7FDyPkCT54phU', function=Function(arguments='{\\n\"location\": \"San Francisco\",\\n\"unit\": \"Fahrenheit\"\\n}', name='get_current_temperature'), type='function')]),\n", 133 | " {'content': '\"70\"',\n", 134 | " 'name': 'get_current_temperature',\n", 135 | " 'role': 'tool',\n", 136 | " 'tool_call_id': 'call_PVhrHTUQ2qY7FDyPkCT54phU'},\n", 137 | " ChatCompletionMessage(content=None, refusal=None, role='assistant', audio=None, function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_Fvy3JUhIbVqzV0Nb0QfSnWQC', function=Function(arguments='{\\n\"location\": \"San Francisco\"\\n}', name='is_it_raining'), type='function')]),\n", 138 | " {'content': '\"yes\"',\n", 139 | " 'name': 'is_it_raining',\n", 140 | " 'role': 'tool',\n", 141 | " 'tool_call_id': 'call_Fvy3JUhIbVqzV0Nb0QfSnWQC'},\n", 142 | " ChatCompletionMessage(content=\"I'm sorry, it seems like it will be raining this afternoon in San Francisco. You might want to plan your picnic for another day. 
Also the temperature is forecasted to be around 70 degrees Fahrenheit.\", refusal=None, role='assistant', audio=None, function_call=None, tool_calls=None)]\n" 143 | ] 144 | } 145 | ], 146 | "source": [ 147 | "from pprint import pprint \n", 148 | "pprint(response.choices[0].intermediate_messages)" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": null, 154 | "metadata": {}, 155 | "outputs": [], 156 | "source": [ 157 | "from aisuite import Tools\n", 158 | "tools = Tools(tools=[get_current_temperature, is_it_raining])\n", 159 | "tools.tools()\n", 160 | "# tools.add_description(\"is_it_raining\", \"Use this function to understand if it is going to rain or not\")" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": null, 166 | "metadata": {}, 167 | "outputs": [], 168 | "source": [ 169 | "messages = append(messages, response.choices[0].intermediate_messages)\n" 170 | ] 171 | } 172 | ], 173 | "metadata": { 174 | "kernelspec": { 175 | "display_name": "Python 3 (ipykernel)", 176 | "language": "python", 177 | "name": "python3" 178 | }, 179 | "language_info": { 180 | "codemirror_mode": { 181 | "name": "ipython", 182 | "version": 3 183 | }, 184 | "file_extension": ".py", 185 | "mimetype": "text/x-python", 186 | "name": "python", 187 | "nbconvert_exporter": "python", 188 | "pygments_lexer": "ipython3", 189 | "version": "3.12.8" 190 | } 191 | }, 192 | "nbformat": 4, 193 | "nbformat_minor": 4 194 | } 195 | -------------------------------------------------------------------------------- /examples/chat-ui/.streamlit/config.toml: -------------------------------------------------------------------------------- 1 | [theme] 2 | primaryColor = "#1E90FF" # Blue color for primary components 3 | backgroundColor = "#0e1117" # Background color 4 | secondaryBackgroundColor = "#262730" # Secondary background color 5 | textColor = "#ffffff" # Text color 6 | font = "sans serif" 7 | 8 | -------------------------------------------------------------------------------- /examples/chat-ui/README.md: -------------------------------------------------------------------------------- 1 | # Chat UI 2 | 3 | This is a simple chat UI built using Streamlit. It uses the `aisuite` library to power the chat. 4 | 5 | You will need to install streamlit to run this example. 6 | 7 | ```bash 8 | pip install streamlit 9 | ``` 10 | 11 | You will also need to create a `config.yaml` file in the same directory as the `chat.py` file. An example config file has been provided. You need to set environment variables for the API keys and other configuration for the LLMs you want to use. Place a .env file in this directory since `chat.py` will look for it. 12 | 13 | In config.yaml, you can specify the LLMs you want to use in the chat. The chat UI will then display all these LLMs and you can select the one you want to use. 14 | 15 | To run the app, simply run the following command in your terminal: 16 | 17 | ```bash 18 | streamlit run chat.py 19 | ``` 20 | 21 | You can choose different LLMs by ticking the "Comparison Mode" checkbox. Then select the two LLMs you want to compare. 22 | Here are some sample queries you can try: 23 | 24 | ``` 25 | User: "What is the weather in Tokyo?" 26 | ``` 27 | 28 | ``` 29 | User: "Write a poem about the weather in Tokyo." 30 | ``` 31 | 32 | ``` 33 | User: "Write a python program to print the fibonacci sequence." 34 | Assistant: "-- Content from LLM 1 --" 35 | User: "Write test cases for this program." 
36 | ``` 37 | -------------------------------------------------------------------------------- /examples/chat-ui/chat.py: -------------------------------------------------------------------------------- 1 | import os 2 | import requests 3 | import streamlit as st 4 | import sys 5 | import yaml 6 | from dotenv import load_dotenv, find_dotenv 7 | 8 | sys.path.append("../../../aisuite") 9 | from aisuite.client import Client 10 | 11 | # Configure Streamlit to use wide mode and hide the top streamlit menu 12 | st.set_page_config(layout="wide", menu_items={}) 13 | # Add heading with padding 14 | st.markdown( 15 | "

Chat & Compare LLM responses

", 16 | unsafe_allow_html=True, 17 | ) 18 | st.markdown( 19 | """ 20 | 32 | """, 33 | unsafe_allow_html=True, 34 | ) 35 | st.markdown( 36 | """ 37 | 81 | """, 82 | unsafe_allow_html=True, 83 | ) 84 | 85 | # Load configuration and initialize aisuite client 86 | with open("config.yaml", "r") as file: 87 | config = yaml.safe_load(file) 88 | configured_llms = config["llms"] 89 | load_dotenv(find_dotenv()) 90 | client = Client() 91 | 92 | 93 | # Function to display chat history 94 | def display_chat_history(chat_history, model_name): 95 | for message in chat_history: 96 | role_display = "User" if message["role"] == "user" else model_name 97 | role = "user" if message["role"] == "user" else "assistant" 98 | if role == "user": 99 | with st.chat_message(role, avatar="👤"): 100 | st.write(message["content"]) 101 | else: 102 | with st.chat_message(role, avatar="🤖"): 103 | st.write(message["content"]) 104 | 105 | 106 | # Helper function to query each LLM 107 | def query_llm(model_config, chat_history): 108 | print(f"Querying {model_config['name']} with {chat_history}") 109 | try: 110 | model = model_config["provider"] + ":" + model_config["model"] 111 | response = client.chat.completions.create(model=model, messages=chat_history) 112 | print( 113 | f"Response from {model_config['name']}: {response.choices[0].message.content}" 114 | ) 115 | return response.choices[0].message.content 116 | except Exception as e: 117 | st.error(f"Error querying {model_config['name']}: {e}") 118 | return "Error with LLM response." 119 | 120 | 121 | # Initialize session states 122 | if "chat_history_1" not in st.session_state: 123 | st.session_state.chat_history_1 = [] 124 | if "chat_history_2" not in st.session_state: 125 | st.session_state.chat_history_2 = [] 126 | if "is_processing" not in st.session_state: 127 | st.session_state.is_processing = False 128 | if "use_comparison_mode" not in st.session_state: 129 | st.session_state.use_comparison_mode = False 130 | 131 | # Top Section - Controls 132 | col1, col2 = st.columns([1, 2]) 133 | with col1: 134 | st.session_state.use_comparison_mode = st.checkbox("Comparison Mode", value=True) 135 | 136 | # Move LLM selection below comparison mode checkbox - now in columns 137 | llm_col1, llm_col2 = st.columns(2) 138 | with llm_col1: 139 | selected_model_1 = st.selectbox( 140 | "Choose LLM Model 1", 141 | [llm["name"] for llm in configured_llms], 142 | key="model_1", 143 | index=0 if configured_llms else 0, 144 | ) 145 | with llm_col2: 146 | if st.session_state.use_comparison_mode: 147 | selected_model_2 = st.selectbox( 148 | "Choose LLM Model 2", 149 | [llm["name"] for llm in configured_llms], 150 | key="model_2", 151 | index=1 if len(configured_llms) > 1 else 0, 152 | ) 153 | 154 | # Display Chat Histories first, always 155 | # Middle Section - Display Chat Histories 156 | if st.session_state.use_comparison_mode: 157 | col1, col2 = st.columns(2) 158 | with col1: 159 | chat_container = st.container(height=500) 160 | with chat_container: 161 | display_chat_history(st.session_state.chat_history_1, selected_model_1) 162 | with col2: 163 | chat_container = st.container(height=500) 164 | with chat_container: 165 | display_chat_history(st.session_state.chat_history_2, selected_model_2) 166 | else: 167 | chat_container = st.container(height=500) 168 | with chat_container: 169 | display_chat_history(st.session_state.chat_history_1, selected_model_1) 170 | 171 | # Bottom Section - User Input 172 | st.markdown("
", unsafe_allow_html=True) 173 | 174 | col1, col2, col3 = st.columns([6, 1, 1]) 175 | with col1: 176 | user_query = st.text_area( 177 | label="Enter your query", 178 | label_visibility="collapsed", 179 | placeholder="Enter your query...", 180 | key="query_input", 181 | height=70, 182 | ) 183 | 184 | 185 | # CSS for aligning buttons with the bottom of the text area 186 | st.markdown( 187 | """ 188 | 200 | """, 201 | unsafe_allow_html=True, 202 | ) 203 | 204 | with col2: 205 | send_button = False # Initialize send_button 206 | if st.session_state.is_processing: 207 | st.markdown( 208 | "
Processing... ⏳
", 209 | unsafe_allow_html=True, 210 | ) 211 | else: 212 | send_button = st.button("Send Query", use_container_width=True) 213 | 214 | with col3: 215 | if st.button("Reset Chat", use_container_width=True): 216 | st.session_state.chat_history_1 = [] 217 | st.session_state.chat_history_2 = [] 218 | st.rerun() 219 | 220 | # Handle send button click and processing 221 | if send_button and user_query and not st.session_state.is_processing: 222 | # Set processing state 223 | st.session_state.is_processing = True 224 | 225 | # Append user's message to chat histories first 226 | st.session_state.chat_history_1.append({"role": "user", "content": user_query}) 227 | if st.session_state.use_comparison_mode: 228 | st.session_state.chat_history_2.append({"role": "user", "content": user_query}) 229 | 230 | st.rerun() 231 | 232 | # Handle the actual processing 233 | if st.session_state.is_processing and user_query: 234 | # Query the selected LLM(s) 235 | model_config_1 = next( 236 | llm for llm in configured_llms if llm["name"] == selected_model_1 237 | ) 238 | response_1 = query_llm(model_config_1, st.session_state.chat_history_1) 239 | st.session_state.chat_history_1.append({"role": "assistant", "content": response_1}) 240 | 241 | if st.session_state.use_comparison_mode: 242 | model_config_2 = next( 243 | llm for llm in configured_llms if llm["name"] == selected_model_2 244 | ) 245 | response_2 = query_llm(model_config_2, st.session_state.chat_history_2) 246 | st.session_state.chat_history_2.append( 247 | {"role": "assistant", "content": response_2} 248 | ) 249 | 250 | # Reset processing state 251 | st.session_state.is_processing = False 252 | st.rerun() 253 | -------------------------------------------------------------------------------- /examples/chat-ui/config.yaml: -------------------------------------------------------------------------------- 1 | # config.yaml 2 | llms: 3 | - name: "OpenAI GPT-4o" 4 | provider: "openai" 5 | model: "gpt-4o" 6 | - name: "Anthropic Claude 3.5 Sonnet" 7 | provider: "anthropic" 8 | model: "claude-3-5-sonnet-20240620" 9 | - name: "Azure/OpenAI GPT-4o" 10 | provider: "azure" 11 | model: "gpt-4o" 12 | - name: "Huggingface/Mistral 7B" 13 | provider: "huggingface" 14 | model: "mistralai/Mistral-7B-Instruct" 15 | -------------------------------------------------------------------------------- /examples/client.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "d34f8c48-90fc-4981-8d2b-b47724c2a6dd", 6 | "metadata": { 7 | "vscode": { 8 | "languageId": "raw" 9 | } 10 | }, 11 | "source": [ 12 | "# Client Examples\n", 13 | "\n", 14 | "Client provides a uniform interface for interacting with LLMs from various providers. It adapts the official python libraries from providers such as Mistral, OpenAI, Groq, Anthropic, AWS, etc to conform to the OpenAI chat completion interface. It directly calls the REST endpoints in some cases.\n", 15 | "\n", 16 | "Below are some examples of how to use Client to interact with different LLMs." 
17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "id": "initial_id", 23 | "metadata": { 24 | "ExecuteTime": { 25 | "end_time": "2024-07-04T15:30:02.064319Z", 26 | "start_time": "2024-07-04T15:30:02.051986Z" 27 | } 28 | }, 29 | "outputs": [], 30 | "source": [ 31 | "import sys\n", 32 | "from dotenv import load_dotenv, find_dotenv\n", 33 | "\n", 34 | "sys.path.append('../../aisuite')" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": null, 40 | "id": "f75736ee", 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "import os\n", 45 | "def configure_environment(additional_env_vars=None):\n", 46 | " \"\"\"\n", 47 | " Load environment variables from .env file and apply any additional variables.\n", 48 | " :param additional_env_vars: A dictionary of additional environment variables to apply.\n", 49 | " \"\"\"\n", 50 | " # Load from .env file if available\n", 51 | " load_dotenv(find_dotenv())\n", 52 | "\n", 53 | " # Apply additional environment variables\n", 54 | " if additional_env_vars:\n", 55 | " for key, value in additional_env_vars.items():\n", 56 | " os.environ[key] = value\n", 57 | "\n", 58 | "# Define additional API keys and credentials\n", 59 | "additional_keys = {\n", 60 | " 'GROQ_API_KEY': 'xxx',\n", 61 | " 'AWS_ACCESS_KEY_ID': 'xxx',\n", 62 | " 'AWS_SECRET_ACCESS_KEY': 'xxx',\n", 63 | " 'ANTHROPIC_API_KEY': 'xxx',\n", 64 | " 'NEBIUS_API_KEY': 'xxx',\n", 65 | "}\n", 66 | "\n", 67 | "# Configure environment\n", 68 | "configure_environment(additional_env_vars=additional_keys)" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "id": "4de3a24f", 75 | "metadata": { 76 | "ExecuteTime": { 77 | "end_time": "2024-07-04T15:31:12.914321Z", 78 | "start_time": "2024-07-04T15:31:12.796445Z" 79 | } 80 | }, 81 | "outputs": [], 82 | "source": [ 83 | "import aisuite as ai\n", 84 | "\n", 85 | "client = ai.Client()\n", 86 | "messages = [\n", 87 | " {\"role\": \"system\", \"content\": \"Respond in Pirate English. 
Always try to include the phrase - No rum No fun.\"},\n", 88 | " {\"role\": \"user\", \"content\": \"Tell me a joke about Captain Jack Sparrow\"},\n", 89 | "]" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": null, 95 | "id": "520a6879", 96 | "metadata": {}, 97 | "outputs": [], 98 | "source": [ 99 | "# print(os.environ[\"ANTHROPIC_API_KEY\"])\n", 100 | "anthropic_claude_3_opus = \"anthropic:claude-3-5-sonnet-20240620\"\n", 101 | "response = client.chat.completions.create(model=anthropic_claude_3_opus, messages=messages)\n", 102 | "print(response.choices[0].message.content)" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": null, 108 | "id": "9893c7e4-799a-42c9-84de-f9e643044462", 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "aws_bedrock_llama3_8b = \"aws:meta.llama3-1-8b-instruct-v1:0\"\n", 113 | "response = client.chat.completions.create(model=aws_bedrock_llama3_8b, messages=messages)\n", 114 | "print(response.choices[0].message.content)" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": null, 120 | "id": "7e46c20a", 121 | "metadata": {}, 122 | "outputs": [], 123 | "source": [ 124 | "# IMP NOTE: Azure expects model endpoint to be passed in the format of \"azure:\".\n", 125 | "# The model name is the deployment name in Project/Deployments.\n", 126 | "# In the example below, the model is \"mistral-large-2407\", but the name given to the\n", 127 | "# deployment is \"aisuite-mistral-large-2407\" under the deployments section in Azure.\n", 128 | "client.configure({\"azure\" : {\n", 129 | " \"api_key\": os.environ[\"AZURE_API_KEY\"],\n", 130 | " \"base_url\": \"https://aisuite-mistral-large-2407.westus3.models.ai.azure.com/v1/\",\n", 131 | "}});\n", 132 | "azure_model = \"azure:aisuite-mistral-large-2407\"\n", 133 | "response = client.chat.completions.create(model=azure_model, messages=messages)\n", 134 | "print(response.choices[0].message.content)" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": null, 140 | "id": "f996b121", 141 | "metadata": {}, 142 | "outputs": [], 143 | "source": [ 144 | "# HuggingFace expects the model to be passed in the format of \"huggingface:\".\n", 145 | "# The model name is the full name of the model in HuggingFace.\n", 146 | "# In the example below, the model is \"mistralai/Mistral-7B-Instruct-v0.3\".\n", 147 | "# The model is deployed as serverless inference endpoint in HuggingFace.\n", 148 | "hf_model = \"huggingface:mistralai/Mistral-7B-Instruct-v0.3\"\n", 149 | "response = client.chat.completions.create(model=hf_model, messages=messages)\n", 150 | "print(response.choices[0].message.content)" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": null, 156 | "id": "c9b2aad6-8603-4227-9566-778f714eb0b5", 157 | "metadata": {}, 158 | "outputs": [], 159 | "source": [ 160 | "\n", 161 | "# Groq expects the model to be passed in the format of \"groq:\".\n", 162 | "# The model name is the full name of the model in Groq.\n", 163 | "# In the example below, the model is \"llama3-8b-8192\".\n", 164 | "groq_llama3_8b = \"groq:llama3-8b-8192\"\n", 165 | "# groq_llama3_70b = \"groq:llama3-70b-8192\"\n", 166 | "response = client.chat.completions.create(model=groq_llama3_8b, messages=messages)\n", 167 | "print(response.choices[0].message.content)" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": null, 173 | "id": "6819ac17", 174 | "metadata": {}, 175 | "outputs": [], 176 | "source": [ 177 | 
"ollama_tinyllama = \"ollama:tinyllama\"\n", 178 | "ollama_phi3mini = \"ollama:phi3:mini\"\n", 179 | "response = client.chat.completions.create(model=ollama_phi3mini, messages=messages, temperature=0.75)\n", 180 | "print(response.choices[0].message.content)" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": null, 186 | "id": "4a94961b2bddedbb", 187 | "metadata": { 188 | "ExecuteTime": { 189 | "end_time": "2024-07-04T15:31:39.472675Z", 190 | "start_time": "2024-07-04T15:31:38.283368Z" 191 | } 192 | }, 193 | "outputs": [], 194 | "source": [ 195 | "mistral_7b = \"mistral:open-mistral-7b\"\n", 196 | "response = client.chat.completions.create(model=mistral_7b, messages=messages, temperature=0.2)\n", 197 | "print(response.choices[0].message.content)" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": null, 203 | "id": "611210a4dc92845f", 204 | "metadata": {}, 205 | "outputs": [], 206 | "source": [ 207 | "openai_gpt35 = \"openai:gpt-3.5-turbo\"\n", 208 | "response = client.chat.completions.create(model=openai_gpt35, messages=messages, temperature=0.75)\n", 209 | "print(response.choices[0].message.content)" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": null, 215 | "id": "f38d033a-a580-4239-9176-27f3d53e7fe1", 216 | "metadata": {}, 217 | "outputs": [], 218 | "source": [ 219 | "nebius_model = \"nebius:Qwen/Qwen2.5-1.5B-Instruct\"\n", 220 | "response = client.chat.completions.create(model=nebius_model, messages=messages, top_p=0.01)\n", 221 | "print(response.choices[0].message.content)" 222 | ] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "execution_count": null, 227 | "id": "321783ae", 228 | "metadata": {}, 229 | "outputs": [], 230 | "source": [ 231 | "fireworks_model = \"fireworks:accounts/fireworks/models/llama-v3p2-3b-instruct\"\n", 232 | "response = client.chat.completions.create(model=fireworks_model, messages=messages, temperature=0.75, presence_penalty=0.5, frequency_penalty=0.5)\n", 233 | "print(response.choices[0].message.content)" 234 | ] 235 | }, 236 | { 237 | "cell_type": "code", 238 | "execution_count": null, 239 | "id": "e30e5ae0", 240 | "metadata": {}, 241 | "outputs": [], 242 | "source": [ 243 | "togetherai_model = \"together:meta-llama/Llama-3.2-3B-Instruct-Turbo\"\n", 244 | "response = client.chat.completions.create(model=togetherai_model, messages=messages, temperature=0.75, top_p=0.7, top_k=50)\n", 245 | "print(response.choices[0].message.content)" 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": null, 251 | "id": "dcf63a11", 252 | "metadata": {}, 253 | "outputs": [], 254 | "source": [ 255 | "gemini_15_flash = \"google:gemini-1.5-flash\"\n", 256 | "response = client.chat.completions.create(model=gemini_15_flash, messages=messages, temperature=0.75)\n", 257 | "print(response.choices[0].message.content)" 258 | ] 259 | } 260 | ], 261 | "metadata": { 262 | "kernelspec": { 263 | "display_name": "Python 3 (ipykernel)", 264 | "language": "python", 265 | "name": "python3" 266 | }, 267 | "language_info": { 268 | "codemirror_mode": { 269 | "name": "ipython", 270 | "version": 3 271 | }, 272 | "file_extension": ".py", 273 | "mimetype": "text/x-python", 274 | "name": "python", 275 | "nbconvert_exporter": "python", 276 | "pygments_lexer": "ipython3", 277 | "version": "3.12.6" 278 | } 279 | }, 280 | "nbformat": 4, 281 | "nbformat_minor": 5 282 | } -------------------------------------------------------------------------------- /examples/llm_reasoning.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "d39a806c-02a3-4a2d-8c51-f1ab1ea79d2e", 6 | "metadata": {}, 7 | "source": [ 8 | "# LLM Reasoning\n", 9 | "\n", 10 | "This notebook compares how LLMs from different Generative AI providers perform on three examples that can show issues with LLM reasoning:\n", 11 | "\n", 12 | "* [The Reversal Curse](https://github.com/lukasberglund/reversal_curse) shows that LLMs trained on \"A is B\" fail to learn \"B is A\".\n", 13 | "* [How many r's in the word strawberry?](https://x.com/karpathy/status/1816637781659254908) shows \"the weirdness of LLM Tokenization\". \n", 14 | "* [Which number is bigger, 9.11 or 9.9?](https://x.com/DrJimFan/status/1816521330298356181) shows that \"LLMs are alien beasts.\"" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "id": "d2e413bd-983c-42a0-9580-96fedc7b1275", 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "!cat ../.env.sample" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "id": "8d843e36-7de6-4726-8a39-c5dcd3c7cc11", 30 | "metadata": {}, 31 | "source": [ 32 | "Make sure your ~/.env file (copied from the .env.sample file above) has the API keys of the LLM providers to compare set before running the cell below:" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "id": "3c966895-1a63-4922-80b7-5a20e47f29de", 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "import sys\n", 43 | "sys.path.append('../../aisuite')\n", 44 | "\n", 45 | "from dotenv import load_dotenv, find_dotenv\n", 46 | "\n", 47 | "load_dotenv(find_dotenv())" 48 | ] 49 | }, 50 | { 51 | "cell_type": "markdown", 52 | "id": "09d5c5be-1085-4252-9d5e-80b50961484b", 53 | "metadata": {}, 54 | "source": [ 55 | "## Specify LLMs to Compare" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "id": "26c3d5ef-b1c9-48dd-9b89-30799fd4b698", 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "import aisuite as ai\n", 66 | "\n", 67 | "client = ai.Client()" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": null, 73 | "id": "886a904f-fef0-4f25-b3ed-41085bf0f2dd", 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "import time\n", 78 | "\n", 79 | "llms = [\n", 80 | " \"anthropic:claude-3-5-sonnet-20240620\",\n", 81 | " \"aws:meta.llama3-1-8b-instruct-v1:0\",\n", 82 | " \"groq:llama3-8b-8192\",\n", 83 | " \"groq:llama3-70b-8192\",\n", 84 | " \"huggingface:mistralai/Mistral-7B-Instruct-v0.3\",\n", 85 | " \"openai:gpt-3.5-turbo\",\n", 86 | " ]\n", 87 | "\n", 88 | "def compare_llm(messages):\n", 89 | " execution_times = []\n", 90 | " responses = []\n", 91 | " for llm in llms:\n", 92 | " start_time = time.time()\n", 93 | " response = client.chat.completions.create(model=llm, messages=messages)\n", 94 | " end_time = time.time()\n", 95 | " execution_time = end_time - start_time\n", 96 | " responses.append(response.choices[0].message.content.strip())\n", 97 | " execution_times.append(execution_time)\n", 98 | " print(f\"{llm} - {execution_time:.2f} seconds: {response.choices[0].message.content.strip()}\")\n", 99 | " return responses, execution_times" 100 | ] 101 | }, 102 | { 103 | "cell_type": "markdown", 104 | "id": "3c3e8aa2-4ff4-485b-93d9-4a6f22d62e67", 105 | "metadata": {}, 106 | "source": [ 107 | "## The Reversal Curse" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": null, 
113 | "id": "f3c4a8ef-e23b-4d4a-8561-3e5a2a866bd1", 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "messages = [\n", 118 | " {\"role\": \"user\", \"content\": \"Who is Tom Cruise's mother?\"},\n", 119 | "]\n", 120 | "\n", 121 | "responses, execution_times = compare_llm(messages)" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": null, 127 | "id": "769f7f42-2adb-4903-ab17-3143a5d950ce", 128 | "metadata": {}, 129 | "outputs": [], 130 | "source": [ 131 | "import pandas as pd\n", 132 | "\n", 133 | "def display(llms, execution_times, responses):\n", 134 | " data = {\n", 135 | " 'Provider:Model Name': llms,\n", 136 | " 'Execution Time': execution_times,\n", 137 | " 'Model Response ': responses\n", 138 | " }\n", 139 | " \n", 140 | " df = pd.DataFrame(data)\n", 141 | " df.index = df.index + 1\n", 142 | " styled_df = df.style.set_table_styles(\n", 143 | " [{'selector': 'th', 'props': [('text-align', 'center')]}, \n", 144 | " {'selector': 'td', 'props': [('text-align', 'center')]}]\n", 145 | " ).set_properties(**{'text-align': 'center'})\n", 146 | " \n", 147 | " return styled_df " 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": null, 153 | "id": "d2359ad5-9f0b-4bd6-9838-54df91de0fb3", 154 | "metadata": {}, 155 | "outputs": [], 156 | "source": [ 157 | "display(llms, execution_times, responses)" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": null, 163 | "id": "399f6cca-7f34-4a91-aab0-070560640033", 164 | "metadata": {}, 165 | "outputs": [], 166 | "source": [ 167 | "messages = [\n", 168 | " {\"role\": \"user\", \"content\": \"Who is Mary Lee Pfeiffer's son?\"},\n", 169 | "]\n", 170 | "\n", 171 | "responses, execution_times = compare_llm(messages)" 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": null, 177 | "id": "eee7704d-a187-41bc-b119-c94461d0ee74", 178 | "metadata": {}, 179 | "outputs": [], 180 | "source": [ 181 | "display(llms, execution_times, responses)" 182 | ] 183 | }, 184 | { 185 | "cell_type": "markdown", 186 | "id": "ada8e0fb-17f0-4781-bf6a-c23ac86922ad", 187 | "metadata": {}, 188 | "source": [ 189 | "## How many r's in the word strawberry?" 190 | ] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "execution_count": null, 195 | "id": "e537871e-68b6-44c3-886a-d3ebe7a692c1", 196 | "metadata": {}, 197 | "outputs": [], 198 | "source": [ 199 | "messages = [\n", 200 | " {\"role\": \"user\", \"content\": \"How many r's in the word strawberry?\"},\n", 201 | "]\n", 202 | "\n", 203 | "responses, execution_times = compare_llm(messages)" 204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "execution_count": null, 209 | "id": "5678e393-4967-49f1-9e0f-251471dc92b7", 210 | "metadata": {}, 211 | "outputs": [], 212 | "source": [ 213 | "display(llms, execution_times, responses)" 214 | ] 215 | }, 216 | { 217 | "cell_type": "markdown", 218 | "id": "cae3fb5f-a173-4a33-b843-65df6d1086f9", 219 | "metadata": {}, 220 | "source": [ 221 | "## Which number is bigger?" 
222 | ] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "execution_count": null, 227 | "id": "efdf2fd6-f63a-4f9b-af15-1df25590e4fc", 228 | "metadata": {}, 229 | "outputs": [], 230 | "source": [ 231 | "messages = [\n", 232 | " {\"role\": \"user\", \"content\": \"Which number is bigger, 9.11 or 9.9?\"},\n", 233 | "]\n", 234 | "\n", 235 | "responses, execution_times = compare_llm(messages)" 236 | ] 237 | }, 238 | { 239 | "cell_type": "code", 240 | "execution_count": null, 241 | "id": "eaa14ed1-c83b-4c8f-bb14-d318bf0c9a60", 242 | "metadata": {}, 243 | "outputs": [], 244 | "source": [ 245 | "display(llms, execution_times, responses)" 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": null, 251 | "id": "198b213a-b7bf-4cce-8c30-a8408454370b", 252 | "metadata": {}, 253 | "outputs": [], 254 | "source": [ 255 | "messages = [\n", 256 | " {\"role\": \"user\", \"content\": \"Which number is bigger, 9.11 or 9.9? Think step by step.\"},\n", 257 | "]\n", 258 | "\n", 259 | "responses, execution_times = compare_llm(messages)" 260 | ] 261 | }, 262 | { 263 | "cell_type": "code", 264 | "execution_count": null, 265 | "id": "4a3fb8fc-a7a2-47d3-9db2-792f03cc47c2", 266 | "metadata": {}, 267 | "outputs": [], 268 | "source": [ 269 | "display(llms, execution_times, responses)" 270 | ] 271 | }, 272 | { 273 | "cell_type": "markdown", 274 | "id": "66987d26-4245-4de1-816f-fa57475101f3", 275 | "metadata": {}, 276 | "source": [ 277 | "## Takeaways\n", 278 | "1. Not all LLMs are created equal - not even all Llama 3 (or 3.1) are created equal (by different providers).\n", 279 | "2. Ask LLM to think step by step may help improve its reasoning.\n", 280 | "3. The way tokenization works in LLM could lead to a lot of weirdness in LLM (see AK's awesome [video](https://www.youtube.com/watch?v=zduSFxRajkE) for a deep dive).\n", 281 | "4. A more comprehensive benchmark would be desired, but a quick LLM comparison like shown here can be the first step." 282 | ] 283 | }, 284 | { 285 | "cell_type": "code", 286 | "execution_count": null, 287 | "id": "04e13c90-3680-4f1d-8f65-768a78b7adb2", 288 | "metadata": {}, 289 | "outputs": [], 290 | "source": [] 291 | } 292 | ], 293 | "metadata": { 294 | "kernelspec": { 295 | "display_name": "Python 3 (ipykernel)", 296 | "language": "python", 297 | "name": "python3" 298 | }, 299 | "language_info": { 300 | "codemirror_mode": { 301 | "name": "ipython", 302 | "version": 3 303 | }, 304 | "file_extension": ".py", 305 | "mimetype": "text/x-python", 306 | "name": "python", 307 | "nbconvert_exporter": "python", 308 | "pygments_lexer": "ipython3", 309 | "version": "3.12.6" 310 | } 311 | }, 312 | "nbformat": 4, 313 | "nbformat_minor": 5 314 | } 315 | -------------------------------------------------------------------------------- /guides/README.md: -------------------------------------------------------------------------------- 1 | # Provider guides 2 | 3 | These guides give directions for obtaining API keys from different providers. 4 | 5 | Here are the instructions for: 6 | - [Anthropic](anthropic.md) 7 | - [AWS](aws.md) 8 | - [Azure](azure.md) 9 | - [Cohere](cohere.md) 10 | - [Google](google.md) 11 | - [Hugging Face](huggingface.md) 12 | - [Mistral](mistral.md) 13 | - [OpenAI](openai.md) 14 | - [SambaNova](sambanova.md) 15 | - [xAI](xai.md) 16 | - [DeepSeek](deepseek.md) 17 | 18 | Unless otherwise stated, these guides have not been endorsed by the providers. 19 | 20 | We also welcome additional [contributions](../CONTRIBUTING.md). 
21 | 22 | -------------------------------------------------------------------------------- /guides/anthropic.md: -------------------------------------------------------------------------------- 1 | # Anthropic 2 | 3 | To use Anthropic with `aisuite`, you will need to [create an account](https://console.anthropic.com/login). Once logged in, go to the [API Keys](https://console.anthropic.com/settings/keys) 4 | page, click the "Create Key" button, and export the key into your environment. 5 | 6 | 7 | ```shell 8 | export ANTHROPIC_API_KEY="your-anthropic-api-key" 9 | ``` 10 | 11 | ## Create a Chat Completion 12 | 13 | Install the `anthropic` Python client: 14 | 15 | Example with pip: 16 | ```shell 17 | pip install anthropic 18 | ``` 19 | 20 | Example with poetry: 21 | ```shell 22 | poetry add anthropic 23 | ``` 24 | 25 | In your code: 26 | ```python 27 | import aisuite as ai 28 | client = ai.Client() 29 | 30 | 31 | provider = "anthropic" 32 | model_id = "claude-3-5-sonnet-20241022" 33 | 34 | messages = [ 35 | {"role": "system", "content": "Respond in Pirate English."}, 36 | {"role": "user", "content": "Tell me a joke."}, 37 | ] 38 | 39 | response = client.chat.completions.create( 40 | model=f"{provider}:{model_id}", 41 | messages=messages, 42 | ) 43 | 44 | print(response.choices[0].message.content) 45 | ``` 46 | 47 | Happy coding! If you would like to contribute, please read our [Contributing Guide](../CONTRIBUTING.md). 48 | -------------------------------------------------------------------------------- /guides/aws.md: -------------------------------------------------------------------------------- 1 | # AWS 2 | 3 | To use AWS Bedrock with `aisuite`, you will need to create an AWS account and 4 | navigate to https://console.aws.amazon.com/bedrock/home. This route 5 | redirects to your default region. In this example the region is set to 6 | `us-west-2`; anywhere the region is specified, it can be replaced with your desired region. 7 | 8 | Navigate to the [overview](https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/overview) page 9 | directly or by clicking on the `Get started` link. 10 | 11 | ## Foundation Model Access 12 | 13 | You will first need to give your AWS account access to the foundation models by 14 | visiting the [model access](https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/modelaccess) 15 | page to enable the models you would like to use. 16 | 17 | After enabling the foundation models, navigate to the [providers page](https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers) 18 | and select the provider of the model you would like to use. From this page, select the specific model and 19 | make note of the `Model ID` (currently located near the bottom); it will be used in the chat completion example below.
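If you prefer to look up model IDs programmatically instead of in the console, a minimal `boto3` sketch (assuming your AWS credentials and region are already configured) might look like this:

```python
import boto3

# Bedrock control-plane client; use the region where you enabled model access.
bedrock = boto3.client("bedrock", region_name="us-west-2")

# Print the IDs of the foundation models available to your account.
for summary in bedrock.list_foundation_models()["modelSummaries"]:
    print(summary["modelId"])
```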
20 | 21 | Once that has been enabled, set your Access Key and Secret in these environment variables: 22 | 23 | ```shell 24 | export AWS_ACCESS_KEY="your-access-key" 25 | export AWS_SECRET_KEY="your-secret-key" 26 | export AWS_REGION="region-name" 27 | ``` 28 | *Note: AWS_REGION is optional; a default of `us-west-2` has been set for ease of use.* 29 | 30 | ## Create a Chat Completion 31 | 32 | Install the `boto3` client using your package installer: 33 | 34 | Example with pip: 35 | ```shell 36 | pip install boto3 37 | ``` 38 | 39 | Example with poetry: 40 | ```shell 41 | poetry add boto3 42 | ``` 43 | 44 | In your code: 45 | ```python 46 | import aisuite as ai 47 | client = ai.Client() 48 | 49 | 50 | provider = "aws" 51 | model_id = "meta.llama3-1-405b-instruct-v1:0" # Model ID from above 52 | 53 | messages = [ 54 | {"role": "system", "content": "Respond in Pirate English."}, 55 | {"role": "user", "content": "Tell me a joke."}, 56 | ] 57 | 58 | response = client.chat.completions.create( 59 | model=f"{provider}:{model_id}", 60 | messages=messages, 61 | ) 62 | 63 | print(response.choices[0].message.content) 64 | ``` 65 | 66 | Happy coding! If you would like to contribute, please read our [Contributing Guide](../CONTRIBUTING.md). 67 | 68 | 69 | 70 | 71 | 72 | -------------------------------------------------------------------------------- /guides/azure.md: -------------------------------------------------------------------------------- 1 | # Azure AI 2 | 3 | To use Azure AI with the `aisuite` library, you'll need to set up an Azure account and configure your environment for Azure AI services. 4 | 5 | ## Create an Azure Account and deploy a model from AI Studio 6 | 7 | 1. Visit [Azure Portal](https://portal.azure.com/) and sign up for an account if you don't have one. 8 | 2. Create a project and resource group. 9 | 3. Choose a model from https://ai.azure.com/explore/models and deploy it. You can choose the serverless deployment option. 10 | 4. Give the deployment a name. Let's say you choose to deploy Mistral-large-2407; you can leave the deployment name as "mistral-large-2407" or give it a custom name. 11 | 5. You can see the deployment under the project's Deployments section. Note the Target URI from the Endpoint panel. It should look something like this - "https://aisuite-Mistral-large-2407.westus3.models.ai.azure.com". 12 | 6. Also note that it provides a chat completion URL. It should look like this - https://aisuite-Mistral-large-2407.westus3.models.ai.azure.com/v1/chat/completions 13 | 14 | 15 | ## Obtain Necessary Details & set environment variables. 16 | 17 | After creating your deployment, you'll need to gather the following information: 18 | 19 | 1. API Key: Found in the "Keys and Endpoint" section of your Azure OpenAI resource. 20 | 2. Base URL: This can be obtained from your deployment details. It will look something like this - `https://aisuite-Mistral-large-2407.westus3.models.ai.azure.com/v1/` 21 | 3. API Version: Optional configuration, mainly introduced for Azure OpenAI services. Once specified, the `api-version` query parameter is appended to the API request.
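For illustration, with an API version set, requests end up going to a URL of roughly this shape (the deployment name, region, and version below are placeholders):

```
https://aisuite-Mistral-large-2407.westus3.models.ai.azure.com/v1/chat/completions?api-version=2024-08-01-preview
```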
22 | 23 | 24 | Set the following environment variables: 25 | 26 | ```shell 27 | export AZURE_API_KEY="your-api-key" 28 | export AZURE_BASE_URL="https://deployment-name.region-name.models.ai.azure.com/v1" 29 | export AZURE_API_VERSION="2024-08-01-preview" 30 | ``` 31 | 32 | ## Create a Chat Completion 33 | 34 | With your account set up and environment configured, you can send a chat completion request: 35 | 36 | ```python 37 | import os 38 | import aisuite as ai 39 | # Either set the environment variables or pass the parameters below. 40 | # Setting the params in ai.Client() will override the values from environment vars. 41 | client = ai.Client( 42 | base_url=os.environ["AZURE_BASE_URL"], 43 | api_key=os.environ["AZURE_API_KEY"], 44 | api_version=os.environ["AZURE_API_VERSION"] 45 | ) 46 | 47 | model = "azure:aisuite-Mistral-large-2407" # Replace with your deployment name. 48 | # The model name must match the deployment name in the base-url. 49 | 50 | messages = [ 51 | {"role": "system", "content": "You are a helpful assistant."}, 52 | {"role": "user", "content": "What's the weather like today?"}, 53 | ] 54 | 55 | response = client.chat.completions.create( 56 | model=model, 57 | messages=messages, 58 | ) 59 | 60 | print(response.choices[0].message.content) 61 | ``` 62 | 63 | Happy coding! If you would like to contribute, please read our [Contributing Guide](../CONTRIBUTING.md). -------------------------------------------------------------------------------- /guides/cerebras.md: -------------------------------------------------------------------------------- 1 | # Cerebras AI Suite Provider Guide 2 | 3 | ## About Cerebras 4 | 5 | At Cerebras, we've developed the world's largest and fastest AI processor, the Wafer-Scale Engine-3 (WSE-3). The Cerebras CS-3 system, powered by the WSE-3, represents a new class of AI supercomputer that sets the standard for generative AI training and inference with unparalleled performance and scalability. 6 | 7 | With Cerebras as your inference provider, you can: 8 | - Achieve unprecedented speed for AI inference workloads 9 | - Build commercially with high throughput 10 | - Effortlessly scale your AI workloads with our seamless clustering technology 11 | 12 | Our CS-3 systems can be quickly and easily clustered to create the largest AI supercomputers in the world, making it simple to place and run the largest models. Leading corporations, research institutions, and governments are already using Cerebras solutions to develop proprietary models and train popular open-source models. 13 | 14 | Want to experience the power of Cerebras? Check out our [website](https://cerebras.net) for more resources and explore options for accessing our technology through the Cerebras Cloud or on-premise deployments! 15 | 16 | > [!NOTE] 17 | > This SDK has a mechanism that sends a few requests to `/v1/tcp_warming` upon construction to reduce the TTFT. If this behaviour is not desired, set `warm_tcp_connection=False` in the constructor. 18 | > 19 | > If you are repeatedly reconstructing the SDK instance it will lead to poor performance. It is recommended that you construct the SDK once and reuse the instance if possible. 20 | 21 | ## Documentation 22 | 23 | For the most comprehensive and up-to-date Cerebras Inference docs, please visit [inference-docs.cerebras.ai](https://inference-docs.cerebras.ai).
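If you are constructing the Cerebras SDK client yourself (outside of `aisuite`), a minimal sketch of disabling the TCP-warming behaviour mentioned in the note above could look like this (assuming the `cerebras_cloud_sdk` package is installed and `CEREBRAS_API_KEY` is set):

```python
import os
from cerebras.cloud.sdk import Cerebras

# Construct the client once and reuse it; this skips the /v1/tcp_warming requests.
client = Cerebras(
    api_key=os.environ["CEREBRAS_API_KEY"],
    warm_tcp_connection=False,
)
```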
24 | 25 | ## Usage 26 | Get an API Key from [cloud.cerebras.ai](https://cloud.cerebras.ai/) and add it to your environment variables: 27 | 28 | ```shell 29 | export CEREBRAS_API_KEY="your-cerebras-api-key" 30 | ``` 31 | 32 | Use the python client. 33 | 34 | ```python 35 | import aisuite as ai 36 | client = ai.Client() 37 | 38 | messages = [ 39 | {"role": "system", "content": "Respond in Pirate English."}, 40 | {"role": "user", "content": "Tell me a joke."}, 41 | ] 42 | 43 | response = client.chat.completions.create( 44 | model="cerebras:llama3.1-8b", 45 | messages=messages, 46 | temperature=0.75 47 | ) 48 | print(response.choices[0].message.content) 49 | 50 | ``` 51 | 52 | ## Requirements 53 | 54 | Python 3.8 or higher. 55 | -------------------------------------------------------------------------------- /guides/cohere.md: -------------------------------------------------------------------------------- 1 | # Cohere 2 | 3 | To use Cohere with `aisuite`, you’ll need an [Cohere account](https://cohere.com/). After logging in, go to the [API Keys](https://dashboard.cohere.com/api-keys) section in your account settings, agree to the terms of service, connect your card, and generate a new key. Once you have your key, add it to your environment as follows: 4 | 5 | ```shell 6 | export CO_API_KEY="your-cohere-api-key" 7 | ``` 8 | 9 | ## Create a Chat Completion 10 | 11 | Install the `cohere` Python client: 12 | 13 | Example with pip: 14 | ```shell 15 | pip install cohere 16 | ``` 17 | 18 | Example with poetry: 19 | ```shell 20 | poetry add cohere 21 | ``` 22 | 23 | In your code: 24 | ```python 25 | import aisuite as ai 26 | 27 | client = ai.Client() 28 | 29 | provider = "cohere" 30 | model_id = "command-r-plus-08-2024" 31 | 32 | messages = [ 33 | {"role": "user", "content": "Hi, how are you?"} 34 | ] 35 | 36 | response = client.chat.completions.create( 37 | model=f"{provider}:{model_id}", 38 | messages=messages, 39 | ) 40 | 41 | print(response.choices[0].message.content) 42 | ``` 43 | 44 | Happy coding! If you’d like to contribute, please read our [Contributing Guide](CONTRIBUTING.md). 45 | -------------------------------------------------------------------------------- /guides/deepseek.md: -------------------------------------------------------------------------------- 1 | # DeepSeek 2 | 3 | To use DeepSeek with `aisuite`, you’ll need an [DeepSeek account](https://platform.deepseek.com). After logging in, go to the [API Keys](https://platform.deepseek.com/api_keys) section in your account settings and generate a new key. 
Once you have your key, add it to your environment as follows: 4 | 5 | ```shell 6 | export DEEPSEEK_API_KEY="your-deepseek-api-key" 7 | ``` 8 | 9 | ## Create a Chat Completion 10 | 11 | (Note: The DeepSeek uses an API format consistent with OpenAI, hence why we need to install OpenAI, there is no DeepSeek Library at least not for now) 12 | 13 | Install the `openai` Python client: 14 | 15 | Example with pip: 16 | ```shell 17 | pip install openai 18 | ``` 19 | 20 | Example with poetry: 21 | ```shell 22 | poetry add openai 23 | ``` 24 | 25 | In your code: 26 | ```python 27 | import aisuite as ai 28 | client = ai.Client() 29 | 30 | provider = "deepseek" 31 | model_id = "deepseek-chat" 32 | 33 | messages = [ 34 | {"role": "system", "content": "You are a helpful assistant."}, 35 | {"role": "user", "content": "What’s the weather like in San Francisco?"}, 36 | ] 37 | 38 | response = client.chat.completions.create( 39 | model=f"{provider}:{model_id}", 40 | messages=messages, 41 | ) 42 | 43 | print(response.choices[0].message.content) 44 | ``` 45 | 46 | Happy coding! If you’d like to contribute, please read our [Contributing Guide](../CONTRIBUTING.md). 47 | -------------------------------------------------------------------------------- /guides/google.md: -------------------------------------------------------------------------------- 1 | # Google (Vertex) AI 2 | 3 | To use Google (Vertex) AI with the `aisuite` library, you'll first need to create a Google Cloud account and set up your environment to work with Google Cloud. 4 | 5 | ## Create a Google Cloud Account and Project 6 | 7 | Google Cloud provides in-depth [documentation](https://cloud.google.com/vertex-ai/docs/start/cloud-environment) on getting started with their platform, but here are the basic steps: 8 | 9 | ### Create your account. 10 | 11 | Visit [Google Cloud](https://cloud.google.com/free) and follow the instructions for registering a new account. If you already have an account with Google Cloud, sign in and skip to the next step. 12 | 13 | ### Create a new project and enable billing. 14 | 15 | Once you have an account, you can create a new project. Visit the [project selector page](https://console.cloud.google.com/projectselector2/home/dashboard) and click the "New Project" button. Give your project a name and click "Create Project." Your project will be created and you will be redirected to the project dashboard. 16 | 17 | Now that you have a project, you'll need to enable billing. Visit the [how-to page](https://cloud.google.com/billing/docs/how-to/verify-billing-enabled#confirm_billing_is_enabled_on_a_project) for billing enablement instructions. 18 | 19 | ### Set your project ID in an environment variable. 20 | 21 | Set the `GOOGLE_PROJECT_ID` environment variable to the ID of your project. You can find the Project ID by visiting the project dashboard in the "Project Info" section toward the top of the page. 22 | 23 | ### Set your preferred region in an environment variable. 24 | 25 | Set the `GOOGLE_REGION` environment variable. You can find the region by going to Project Dashboard under VertexAI side navigation menu, and then scrolling to the bottom of the page. 
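If you have the `gcloud` CLI installed, a quick (optional) way to confirm the project ID before exporting these values might look like this (the region shown is only an example):

```shell
# Confirm the project ID you created above.
gcloud projects list

# Export the values aisuite will read.
export GOOGLE_PROJECT_ID="your-project-id"
export GOOGLE_REGION="us-central1"  # example; use the region shown in your Vertex AI dashboard
```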
26 | 27 | ## Create a Service Account For API Access 28 | 29 | Because `aisuite` needs to authenticate with Google Cloud to access the Vertex AI API, you'll need to create a service account and set the `GOOGLE_APPLICATION_CREDENTIALS` environment variable to the path of a JSON file containing the service account's credentials, which you can download from the Google Cloud Console. 30 | 31 | This is documented [here](https://cloud.google.com/docs/authentication/provide-credentials-adc#how-to), and the basic steps are as follows: 32 | 33 | 1. Visit the [service accounts page](https://console.cloud.google.com/iam-admin/serviceaccounts) in the Google Cloud Console. 34 | 2. Click the "+ Create Service Account" button toward the top of the page. 35 | 3. Follow the steps for naming your service account and granting access to the project. 36 | 4. Click "Done" to create the service account. 37 | 5. Now, click the "Keys" tab towards the top of the page. 38 | 6. Click the "Add Key" menu, then select "Create New Key." 39 | 6. Choose "JSON" as the key type, and click "Create." 40 | 7. Move this file to a location on your file system like your home directory. 41 | 8. Set the `GOOGLE_APPLICATION_CREDENTIALS` environment variable to the path of the JSON file. 42 | 43 | ## Double check your environment is configured correctly. 44 | 45 | At this point, you should have three environment variables set to ensure your environment is set up correctly: 46 | 47 | - `GOOGLE_PROJECT_ID` 48 | - `GOOGLE_REGION` 49 | - `GOOGLE_APPLICATION_CREDENTIALS` 50 | 51 | Once these are set, you are ready to write some code and send a chat completion request. 52 | 53 | ## Create a chat completion. 54 | 55 | With your account and service account set up, you can send a chat completion request. 56 | 57 | Export the environment variables: 58 | 59 | ```shell 60 | export GOOGLE_PROJECT_ID="your-project-id" 61 | export GOOGLE_REGION="your-region" 62 | export GOOGLE_APPLICATION_CREDENTIALS="path/to/your/service-account-file.json" 63 | ``` 64 | 65 | Install the Vertex AI SDK: 66 | 67 | ```shell 68 | pip install vertexai 69 | ``` 70 | 71 | In your code: 72 | 73 | ```python 74 | import aisuite as ai 75 | client = ai.Client() 76 | 77 | model="google:gemini-1.5-pro-001" 78 | 79 | messages = [ 80 | {"role": "system", "content": "Respond in Pirate English."}, 81 | {"role": "user", "content": "Tell me a joke."}, 82 | ] 83 | 84 | response = client.chat.completions.create( 85 | model=model, 86 | messages=messages, 87 | ) 88 | 89 | print(response.choices[0].message.content) 90 | ``` 91 | 92 | Happy coding! If you would like to contribute, please read our [Contributing Guide](../CONTRIBUTING.md). 93 | -------------------------------------------------------------------------------- /guides/groq.md: -------------------------------------------------------------------------------- 1 | # Groq 2 | 3 | To use Groq with `aisuite`, you’ll need a free [Groq account](https://console.groq.com/). After logging in, go to the [API Keys](https://console.groq.com/keys) section in your account settings and generate a new Groq API key. Once you have your key, add it to your environment as follows: 4 | 5 | ```shell 6 | export GROQ_API_KEY="your-groq-api-key" 7 | ``` 8 | 9 | ## Create a Python Chat Completion 10 | 11 | 1. First, install the `groq` Python client library: 12 | 13 | ```shell 14 | pip install groq 15 | ``` 16 | 17 | 2. 
Now you can simply create your first chat completion with the following example code or customize by swapoping out the `model_id` with any of the other available [models powered by Groq](https://console.groq.com/docs/models) and `messages` array with whatever you'd like: 18 | ```python 19 | import aisuite as ai 20 | client = ai.Client() 21 | 22 | provider = "groq" 23 | model_id = "llama-3.2-3b-preview" 24 | 25 | messages = [ 26 | {"role": "system", "content": "You are a helpful assistant."}, 27 | {"role": "user", "content": "What’s the weather like in San Francisco?"}, 28 | ] 29 | 30 | response = client.chat.completions.create( 31 | model=f"{provider}:{model_id}", 32 | messages=messages, 33 | ) 34 | 35 | print(response.choices[0].message.content) 36 | ``` 37 | 38 | 39 | Happy coding! If you’d like to contribute, please read our [Contributing Guide](CONTRIBUTING.md). 40 | -------------------------------------------------------------------------------- /guides/huggingface.md: -------------------------------------------------------------------------------- 1 | # Hugging Face AI 2 | 3 | To use Hugging Face with the `aisuite` library, you'll need to set up a Hugging Face account, obtain the necessary API credentials, and configure your environment for Hugging Face's API. 4 | 5 | ## Create a Hugging Face Account and Deploy a Model 6 | 7 | 1. Visit [Hugging Face](https://huggingface.co/) and sign up for an account if you don't already have one. 8 | 2. Explore conversational models on the [Hugging Face Model Hub](https://huggingface.co/models?inference=warm&other=conversational&sort=trending) and select a model you want to use. Popular models include conversational AI models like `gpt2`, `gpt3`, and `mistral`. 9 | 3. Deploy or host your chosen model if needed; Hugging Face provides various hosting options, including free, individual, and organizational hosting. Using Serverless Inference API is a fast way to get started. 10 | 5. Once the model is deployed (or if using a public model directly), note the model's unique identifier (e.g., `mistralai/Mistral-7B-Instruct-v0.3`), which you'll use for making requests. 11 | 12 | ## Obtain Necessary Details & Set Environment Variables 13 | 14 | After setting up your model, you'll need to gather the following information: 15 | 16 | - **API Token**: You can generate an API token in your [Hugging Face account settings](https://huggingface.co/settings/tokens). 17 | 18 | Set the following environment variables to make authentication and requests easy: 19 | 20 | ```shell 21 | export HF_TOKEN="your-api-token" 22 | ``` 23 | 24 | ## Create a Chat Completion 25 | 26 | With your account set up and environment variables configured, you can send a chat completion request as follows: 27 | 28 | ```python 29 | import os 30 | import aisuite as ai 31 | 32 | # Either set the environment variables or define the parameters below. 33 | # Setting the parameters in ai.Client() will override the environment variable values. 34 | client = ai.Client() 35 | 36 | model = "huggingface:your-model-name" # Replace with your model's identifier. 
37 | 38 | messages = [ 39 | {"role": "system", "content": "You are a helpful assistant."}, 40 | {"role": "user", "content": "What's the weather like today?"}, 41 | ] 42 | 43 | response = client.chat.completions.create( 44 | model=model, 45 | messages=messages, 46 | ) 47 | 48 | print(response.choices[0].message.content) 49 | ``` 50 | 51 | ### Notes 52 | 53 | - Ensure that the `model` variable matches the identifier of your model as seen in the Hugging Face Model Hub. 54 | - If you encounter any rate limits or API access restrictions, you may have to upgrade your Hugging Face plan to enable higher usage limits. 55 | 56 | 57 | Happy coding! If you would like to contribute, please read our [Contributing Guide](../CONTRIBUTING.md). -------------------------------------------------------------------------------- /guides/mistral.md: -------------------------------------------------------------------------------- 1 | # Mistral 2 | 3 | To use Mistral with `aisuite`, you’ll need a [Mistral account](https://console.mistral.ai/). 4 | 5 | After logging in, go to [Workspace billing](https://console.mistral.ai/billing) and choose a plan: 6 | - **Experiment** *(free, 1 request per second)*; or 7 | - **Scale** *(pay per use)*. 8 | 9 | Visit the [API Keys](https://console.mistral.ai/api-keys/) section in your account settings and generate a new key. Once you have your key, add it to your environment as follows: 10 | 11 | ```shell 12 | export MISTRAL_API_KEY="your-mistralai-api-key" 13 | ``` 14 | ## Create a Chat Completion 15 | 16 | Install the `mistralai` Python client: 17 | 18 | Example with pip: 19 | ```shell 20 | pip install mistralai 21 | ``` 22 | 23 | Example with poetry: 24 | ```shell 25 | poetry add mistralai 26 | ``` 27 | 28 | In your code: 29 | ```python 30 | import aisuite as ai 31 | client = ai.Client() 32 | 33 | provider = "mistral" 34 | model_id = "mistral-large-latest" 35 | 36 | messages = [ 37 | {"role": "system", "content": "You are a helpful assistant."}, 38 | {"role": "user", "content": "What’s the weather like in Montréal?"}, 39 | ] 40 | 41 | response = client.chat.completions.create( 42 | model=f"{provider}:{model_id}", 43 | messages=messages, 44 | ) 45 | 46 | print(response.choices[0].message.content) 47 | ``` 48 | 49 | Happy coding! If you’d like to contribute, please read our [Contributing Guide](../CONTRIBUTING.md). 50 | -------------------------------------------------------------------------------- /guides/nebius.md: -------------------------------------------------------------------------------- 1 | # Nebius AI Studio 2 | 3 | To use Nebius AI Studio with `aisuite`, you need an AI Studio account. Go to [AI Studio](https://studio.nebius.ai/) and press "Log in to AI Studio" in the top right corner. After logging in, go to the [API Keys](https://studio.nebius.ai/settings/api-keys) section and generate a new key.
Once you have a key, add it to your environment as follows: 4 | 5 | ```shell 6 | export NEBIUS_API_KEY="your-nebius-api-key" 7 | ``` 8 | 9 | ## Create a Chat Completion 10 | 11 | Install the `openai` Python client: 12 | 13 | Example with pip: 14 | ```shell 15 | pip install openai 16 | ``` 17 | 18 | Example with poetry: 19 | ```shell 20 | poetry add openai 21 | ``` 22 | 23 | In your code: 24 | ```python 25 | import aisuite as ai 26 | client = ai.Client() 27 | 28 | provider = "nebius" 29 | model_id = "meta-llama/Llama-3.3-70B-Instruct" 30 | 31 | messages = [ 32 | {"role": "system", "content": "You are a helpful assistant."}, 33 | {"role": "user", "content": "How many times has Jurgen Klopp won the Champions League?"}, 34 | ] 35 | 36 | response = client.chat.completions.create( 37 | model=f"{provider}:{model_id}", 38 | messages=messages, 39 | ) 40 | 41 | print(response.choices[0].message.content) 42 | ``` 43 | 44 | Happy coding! If you’d like to contribute, please read our [Contributing Guide](CONTRIBUTING.md). 45 | -------------------------------------------------------------------------------- /guides/openai.md: -------------------------------------------------------------------------------- 1 | # OpenAI 2 | 3 | To use OpenAI with `aisuite`, you’ll need an [OpenAI account](https://platform.openai.com/). After logging in, go to the [API Keys](https://platform.openai.com/account/api-keys) section in your account settings and generate a new key. Once you have your key, add it to your environment as follows: 4 | 5 | ```shell 6 | export OPENAI_API_KEY="your-openai-api-key" 7 | ``` 8 | 9 | ## Create a Chat Completion 10 | 11 | Install the `openai` Python client: 12 | 13 | Example with pip: 14 | ```shell 15 | pip install openai 16 | ``` 17 | 18 | Example with poetry: 19 | ```shell 20 | poetry add openai 21 | ``` 22 | 23 | In your code: 24 | ```python 25 | import aisuite as ai 26 | client = ai.Client() 27 | 28 | provider = "openai" 29 | model_id = "gpt-4-turbo" 30 | 31 | messages = [ 32 | {"role": "system", "content": "You are a helpful assistant."}, 33 | {"role": "user", "content": "What’s the weather like in San Francisco?"}, 34 | ] 35 | 36 | response = client.chat.completions.create( 37 | model=f"{provider}:{model_id}", 38 | messages=messages, 39 | ) 40 | 41 | print(response.choices[0].message.content) 42 | ``` 43 | 44 | Happy coding! If you’d like to contribute, please read our [Contributing Guide](../CONTRIBUTING.md). 45 | -------------------------------------------------------------------------------- /guides/sambanova.md: -------------------------------------------------------------------------------- 1 | # Sambanova 2 | 3 | To use Sambanova with `aisuite`, you’ll need a [Sambanova Cloud](https://cloud.sambanova.ai/) account. After logging in, go to the [API](https://cloud.sambanova.ai/apis) section and generate a new key. 
Once you have your key, add it to your environment as follows: 4 | 5 | ```shell 6 | export SAMBANOVA_API_KEY="your-sambanova-api-key" 7 | ``` 8 | 9 | ## Create a Chat Completion 10 | 11 | Install the `openai` Python client: 12 | 13 | Example with pip: 14 | ```shell 15 | pip install openai 16 | ``` 17 | 18 | Example with poetry: 19 | ```shell 20 | poetry add openai 21 | ``` 22 | 23 | In your code: 24 | ```python 25 | import aisuite as ai 26 | client = ai.Client() 27 | 28 | provider = "sambanova" 29 | model_id = "Meta-Llama-3.1-405B-Instruct" 30 | 31 | messages = [ 32 | {"role": "system", "content": "You are a helpful assistant."}, 33 | {"role": "user", "content": "What’s the weather like in San Francisco?"}, 34 | ] 35 | 36 | response = client.chat.completions.create( 37 | model=f"{provider}:{model_id}", 38 | messages=messages, 39 | ) 40 | 41 | print(response.choices[0].message.content) 42 | ``` 43 | 44 | Happy coding! If you’d like to contribute, please read our [Contributing Guide](CONTRIBUTING.md). 45 | -------------------------------------------------------------------------------- /guides/watsonx.md: -------------------------------------------------------------------------------- 1 | # Watsonx with `aisuite` 2 | 3 | A a step-by-step guide to set up Watsonx with the `aisuite` library, enabling you to use IBM Watsonx's powerful AI models for various tasks. 4 | 5 | ## Setup Instructions 6 | 7 | ### Step 1: Create a Watsonx Account 8 | 9 | 1. Visit [IBM Watsonx](https://www.ibm.com/watsonx). 10 | 2. Sign up for a new account or log in with your existing IBM credentials. 11 | 3. Once logged in, navigate to the **Watsonx Dashboard** () 12 | 13 | --- 14 | 15 | ### Step 2: Obtain API Credentials 16 | 17 | 1. **Generate an API Key**: 18 | - Go to IAM > API keys and create a new API key () 19 | - Copy the API key. This is your `WATSONX_API_KEY`. 20 | 21 | 2. **Locate the Service URL**: 22 | - Your service URL is based on the region where your service is hosted. 23 | - Pick one from the list here 24 | - Copy the service URL. This is your `WATSONX_SERVICE_URL`. 25 | 26 | 3. **Get the Project ID**: 27 | - Go to the **Watsonx Dashboard** () 28 | - Under the **Projects** section, If you don't have a sandbox project, create a new project. 29 | - Navigate to the **Manage** tab and find the **Project ID**. 30 | - Copy the **Project ID**. This will serve as your `WATSONX_PROJECT_ID`. 
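If you want to sanity-check these credentials before wiring them into `aisuite`, a minimal sketch using the `ibm-watsonx-ai` client might look like the following (run it after setting the environment variables described in the next step; the exact client calls are assumptions based on the `ibm_watsonx_ai` package and may differ between versions):

```python
import os
from ibm_watsonx_ai import APIClient, Credentials

# Build credentials from the values gathered above.
credentials = Credentials(
    url=os.environ["WATSONX_SERVICE_URL"],
    api_key=os.environ["WATSONX_API_KEY"],
)

# Creating the client and selecting the project should fail fast if anything is wrong.
client = APIClient(credentials)
client.set.default_project(os.environ["WATSONX_PROJECT_ID"])
print("Watsonx credentials look good.")
```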
31 | 32 | --- 33 | 34 | ### Step 3: Set Environment Variables 35 | 36 | To simplify authentication, set the following environment variables: 37 | 38 | Run the following commands in your terminal: 39 | 40 | ```bash 41 | export WATSONX_API_KEY="your-watsonx-api-key" 42 | export WATSONX_SERVICE_URL="your-watsonx-service-url" 43 | export WATSONX_PROJECT_ID="your-watsonx-project-id" 44 | ``` 45 | 46 | 47 | ## Create a Chat Completion 48 | 49 | Install the `ibm-watsonx-ai` Python client: 50 | 51 | Example with pip: 52 | 53 | ```shell 54 | pip install ibm-watsonx-ai 55 | ``` 56 | 57 | Example with poetry: 58 | 59 | ```shell 60 | poetry add ibm-watsonx-ai 61 | ``` 62 | 63 | In your code: 64 | 65 | ```python 66 | import aisuite as ai 67 | client = ai.Client() 68 | 69 | provider = "watsonx" 70 | model_id = "meta-llama/llama-3-70b-instruct" 71 | 72 | messages = [ 73 | {"role": "system", "content": "You are a helpful assistant."}, 74 | {"role": "user", "content": "Tell me a joke."}, 75 | ] 76 | 77 | response = client.chat.completions.create( 78 | model=f"{provider}:{model_id}", 79 | messages=messages, 80 | ) 81 | 82 | print(response.choices[0].message.content) 83 | ``` -------------------------------------------------------------------------------- /guides/xai.md: -------------------------------------------------------------------------------- 1 | # xAI 2 | 3 | To use xAI with `aisuite`, you’ll need an [API key](https://console.x.ai/). Generate a new key and once you have your key, add it to your environment as follows: 4 | 5 | ```shell 6 | export XAI_API_KEY="your-xai-api-key" 7 | ``` 8 | 9 | ## Create a Chat Completion 10 | 11 | Sample code: 12 | ```python 13 | import aisuite as ai 14 | client = ai.Client() 15 | 16 | models = ["xai:grok-beta"] 17 | 18 | messages = [ 19 | {"role": "system", "content": "Respond in Pirate English."}, 20 | {"role": "user", "content": "Tell me a joke."}, 21 | ] 22 | 23 | for model in models: 24 | response = client.chat.completions.create( 25 | model=model, 26 | messages=messages, 27 | temperature=0.75 28 | ) 29 | print(response.choices[0].message.content) 30 | 31 | ``` 32 | 33 | Happy coding! If you’d like to contribute, please read our [Contributing Guide](CONTRIBUTING.md). 
34 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "aisuite" 3 | version = "0.1.11" 4 | description = "Uniform access layer for LLMs" 5 | authors = ["Andrew Ng, Rohit P"] 6 | readme = "README.md" 7 | 8 | [tool.poetry.dependencies] 9 | python = "^3.10" 10 | anthropic = { version = "^0.30.1", optional = true } 11 | boto3 = { version = "^1.34.144", optional = true } 12 | cohere = { version = "^5.12.0", optional = true } 13 | vertexai = { version = "^1.63.0", optional = true } 14 | groq = { version = "^0.9.0", optional = true } 15 | mistralai = { version = "^1.0.3", optional = true } 16 | openai = { version = "^1.35.8", optional = true } 17 | ibm-watsonx-ai = { version = "^1.1.16", optional = true } 18 | docstring-parser = { version = "^0.14.0", optional = true } 19 | cerebras_cloud_sdk = { version = "^1.19.0", optional = true } 20 | 21 | # Optional dependencies for different providers 22 | httpx = "~0.27.0" 23 | [tool.poetry.extras] 24 | anthropic = ["anthropic"] 25 | aws = ["boto3"] 26 | azure = [] 27 | cerebras = ["cerebras_cloud_sdk"] 28 | cohere = ["cohere"] 29 | deepseek = ["openai"] 30 | google = ["vertexai"] 31 | groq = ["groq"] 32 | huggingface = [] 33 | mistral = ["mistralai"] 34 | ollama = [] 35 | openai = ["openai"] 36 | watsonx = ["ibm-watsonx-ai"] 37 | all = ["anthropic", "aws", "cerebras_cloud_sdk", "google", "groq", "mistral", "openai", "cohere", "watsonx"] # To install all providers 38 | 39 | [tool.poetry.group.dev.dependencies] 40 | pre-commit = "^3.7.1" 41 | black = "^24.4.2" 42 | python-dotenv = "^1.0.1" 43 | openai = "^1.35.8" 44 | groq = "^0.9.0" 45 | anthropic = "^0.30.1" 46 | notebook = "^7.2.1" 47 | ollama = "^0.2.1" 48 | mistralai = "^1.0.3" 49 | boto3 = "^1.34.144" 50 | fireworks-ai = "^0.14.0" 51 | chromadb = "^0.5.4" 52 | sentence-transformers = "^3.0.1" 53 | datasets = "^2.20.0" 54 | vertexai = "^1.63.0" 55 | ibm-watsonx-ai = "^1.1.16" 56 | cerebras_cloud_sdk = "^1.19.0" 57 | 58 | [tool.poetry.group.test] 59 | optional = true 60 | 61 | [tool.poetry.group.test.dependencies] 62 | pytest = "^8.2.2" 63 | pytest-cov = "^6.0.0" 64 | 65 | [build-system] 66 | requires = ["poetry-core"] 67 | build-backend = "poetry.core.masonry.api" 68 | 69 | [tool.pytest.ini_options] 70 | testpaths="tests" 71 | markers = [ 72 | "integration: marks tests as integration tests that interact with external services", 73 | ] 74 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrewyng/aisuite/ddba58493e6e43b2ff7d507779e3c7735c09c33e/tests/__init__.py -------------------------------------------------------------------------------- /tests/client/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrewyng/aisuite/ddba58493e6e43b2ff7d507779e3c7735c09c33e/tests/client/__init__.py -------------------------------------------------------------------------------- /tests/client/test_client.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import Mock, patch 2 | 3 | import pytest 4 | 5 | from aisuite import Client 6 | 7 | 8 | @pytest.fixture(scope="module") 9 | def provider_configs(): 10 | return { 11 | "openai": {"api_key": 
"test_openai_api_key"}, 12 | "aws": { 13 | "aws_access_key": "test_aws_access_key", 14 | "aws_secret_key": "test_aws_secret_key", 15 | "aws_session_token": "test_aws_session_token", 16 | "aws_region": "us-west-2", 17 | }, 18 | "azure": { 19 | "api_key": "azure-api-key", 20 | "base_url": "https://model.ai.azure.com", 21 | }, 22 | "groq": { 23 | "api_key": "groq-api-key", 24 | }, 25 | "mistral": { 26 | "api_key": "mistral-api-key", 27 | }, 28 | "google": { 29 | "project_id": "test_google_project_id", 30 | "region": "us-west4", 31 | "application_credentials": "test_google_application_credentials", 32 | }, 33 | "fireworks": { 34 | "api_key": "fireworks-api-key", 35 | }, 36 | "nebius": { 37 | "api_key": "nebius-api-key", 38 | }, 39 | } 40 | 41 | 42 | @pytest.mark.parametrize( 43 | argnames=("patch_target", "provider", "model"), 44 | argvalues=[ 45 | ( 46 | "aisuite.providers.openai_provider.OpenaiProvider.chat_completions_create", 47 | "openai", 48 | "gpt-4o", 49 | ), 50 | ( 51 | "aisuite.providers.mistral_provider.MistralProvider.chat_completions_create", 52 | "mistral", 53 | "mistral-model", 54 | ), 55 | ( 56 | "aisuite.providers.groq_provider.GroqProvider.chat_completions_create", 57 | "groq", 58 | "groq-model", 59 | ), 60 | ( 61 | "aisuite.providers.aws_provider.AwsProvider.chat_completions_create", 62 | "aws", 63 | "claude-v3", 64 | ), 65 | ( 66 | "aisuite.providers.azure_provider.AzureProvider.chat_completions_create", 67 | "azure", 68 | "azure-model", 69 | ), 70 | ( 71 | "aisuite.providers.anthropic_provider.AnthropicProvider.chat_completions_create", 72 | "anthropic", 73 | "anthropic-model", 74 | ), 75 | ( 76 | "aisuite.providers.google_provider.GoogleProvider.chat_completions_create", 77 | "google", 78 | "google-model", 79 | ), 80 | ( 81 | "aisuite.providers.fireworks_provider.FireworksProvider.chat_completions_create", 82 | "fireworks", 83 | "fireworks-model", 84 | ), 85 | ( 86 | "aisuite.providers.nebius_provider.NebiusProvider.chat_completions_create", 87 | "nebius", 88 | "nebius-model", 89 | ), 90 | ], 91 | ) 92 | def test_client_chat_completions( 93 | provider_configs: dict, patch_target: str, provider: str, model: str 94 | ): 95 | expected_response = f"{patch_target}_{provider}_{model}" 96 | with patch(patch_target) as mock_provider: 97 | mock_provider.return_value = expected_response 98 | client = Client() 99 | client.configure(provider_configs) 100 | messages = [ 101 | {"role": "system", "content": "You are a helpful assistant."}, 102 | {"role": "user", "content": "Who won the world series in 2020?"}, 103 | ] 104 | 105 | model_str = f"{provider}:{model}" 106 | model_response = client.chat.completions.create(model_str, messages=messages) 107 | assert model_response == expected_response 108 | 109 | 110 | def test_invalid_provider_in_client_config(): 111 | # Testing an invalid provider name in the configuration 112 | invalid_provider_configs = { 113 | "invalid_provider": {"api_key": "invalid_api_key"}, 114 | } 115 | 116 | # Expect ValueError when initializing Client with invalid provider and verify message 117 | with pytest.raises( 118 | ValueError, 119 | match=r"Invalid provider key 'invalid_provider'. 
Supported providers: ", 120 | ): 121 | _ = Client(invalid_provider_configs) 122 | 123 | 124 | def test_invalid_model_format_in_create(monkeypatch): 125 | from aisuite.providers.openai_provider import OpenaiProvider 126 | 127 | monkeypatch.setattr( 128 | target=OpenaiProvider, 129 | name="chat_completions_create", 130 | value=Mock(), 131 | ) 132 | 133 | # Valid provider configurations 134 | provider_configs = { 135 | "openai": {"api_key": "test_openai_api_key"}, 136 | } 137 | 138 | # Initialize the client with valid provider 139 | client = Client() 140 | client.configure(provider_configs) 141 | 142 | messages = [ 143 | {"role": "system", "content": "You are a helpful assistant."}, 144 | {"role": "user", "content": "Tell me a joke."}, 145 | ] 146 | 147 | # Invalid model format 148 | invalid_model = "invalidmodel" 149 | 150 | # Expect ValueError when calling create with invalid model format and verify message 151 | with pytest.raises( 152 | ValueError, match=r"Invalid model format. Expected 'provider:model'" 153 | ): 154 | client.chat.completions.create(invalid_model, messages=messages) 155 | -------------------------------------------------------------------------------- /tests/client/test_prerelease.py: -------------------------------------------------------------------------------- 1 | # Run this test before releasing a new version. 2 | # It will test all the models in the client. 3 | 4 | import pytest 5 | import aisuite as ai 6 | from typing import List, Dict 7 | from dotenv import load_dotenv, find_dotenv 8 | 9 | 10 | def setup_client() -> ai.Client: 11 | """Initialize the AI client with environment variables.""" 12 | load_dotenv(find_dotenv()) 13 | return ai.Client() 14 | 15 | 16 | def get_test_models() -> List[str]: 17 | """Return a list of model identifiers to test.""" 18 | return [ 19 | "anthropic:claude-3-5-sonnet-20240620", 20 | "aws:meta.llama3-1-8b-instruct-v1:0", 21 | "huggingface:mistralai/Mistral-7B-Instruct-v0.3", 22 | "groq:llama3-8b-8192", 23 | "mistral:open-mistral-7b", 24 | "openai:gpt-3.5-turbo", 25 | "cohere:command-r-plus-08-2024", 26 | ] 27 | 28 | 29 | def get_test_messages() -> List[Dict[str, str]]: 30 | """Return the test messages to send to each model.""" 31 | return [ 32 | { 33 | "role": "system", 34 | "content": "Respond in Pirate English. Always try to include the phrase - No rum No fun.", 35 | }, 36 | {"role": "user", "content": "Tell me a joke about Captain Jack Sparrow"}, 37 | ] 38 | 39 | 40 | @pytest.mark.integration 41 | @pytest.mark.parametrize("model_id", get_test_models()) 42 | def test_model_pirate_response(model_id: str): 43 | """ 44 | Test that each model responds appropriately to the pirate prompt. 
45 | 46 | Args: 47 | model_id: The provider:model identifier to test 48 | """ 49 | client = setup_client() 50 | messages = get_test_messages() 51 | 52 | try: 53 | response = client.chat.completions.create( 54 | model=model_id, messages=messages, temperature=0.75 55 | ) 56 | 57 | content = response.choices[0].message.content.lower() 58 | 59 | # Check if either version of the required phrase is present 60 | assert any( 61 | phrase in content for phrase in ["no rum no fun", "no rum, no fun"] 62 | ), f"Model {model_id} did not include required phrase 'No rum No fun'" 63 | 64 | assert len(content) > 0, f"Model {model_id} returned empty response" 65 | assert isinstance( 66 | content, str 67 | ), f"Model {model_id} returned non-string response" 68 | 69 | except Exception as e: 70 | pytest.fail(f"Error testing model {model_id}: {str(e)}") 71 | 72 | 73 | if __name__ == "__main__": 74 | pytest.main([__file__, "-v"]) 75 | -------------------------------------------------------------------------------- /tests/providers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrewyng/aisuite/ddba58493e6e43b2ff7d507779e3c7735c09c33e/tests/providers/__init__.py -------------------------------------------------------------------------------- /tests/providers/test_anthropic_converter.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from unittest.mock import MagicMock 3 | from aisuite.providers.anthropic_provider import AnthropicMessageConverter 4 | from aisuite.framework.message import Message, ChatCompletionMessageToolCall, Function 5 | from aisuite.framework import ChatCompletionResponse 6 | 7 | 8 | class TestAnthropicMessageConverter(unittest.TestCase): 9 | 10 | def setUp(self): 11 | self.converter = AnthropicMessageConverter() 12 | 13 | def test_convert_request_single_user_message(self): 14 | messages = [{"role": "user", "content": "Hello, how are you?"}] 15 | system_message, converted_messages = self.converter.convert_request(messages) 16 | 17 | self.assertEqual(system_message, []) 18 | self.assertEqual( 19 | converted_messages, [{"role": "user", "content": "Hello, how are you?"}] 20 | ) 21 | 22 | def test_convert_request_with_system_message(self): 23 | messages = [ 24 | {"role": "system", "content": "You are a helpful assistant."}, 25 | {"role": "user", "content": "What is the weather?"}, 26 | ] 27 | system_message, converted_messages = self.converter.convert_request(messages) 28 | 29 | self.assertEqual(system_message, "You are a helpful assistant.") 30 | self.assertEqual( 31 | converted_messages, [{"role": "user", "content": "What is the weather?"}] 32 | ) 33 | 34 | def test_convert_request_with_tool_use_message(self): 35 | messages = [ 36 | {"role": "tool", "tool_call_id": "tool123", "content": "Weather data here."} 37 | ] 38 | system_message, converted_messages = self.converter.convert_request(messages) 39 | 40 | self.assertEqual(system_message, []) 41 | self.assertEqual( 42 | converted_messages, 43 | [ 44 | { 45 | "role": "user", 46 | "content": [ 47 | { 48 | "type": "tool_result", 49 | "tool_use_id": "tool123", 50 | "content": "Weather data here.", 51 | } 52 | ], 53 | } 54 | ], 55 | ) 56 | 57 | def test_convert_response_normal_message(self): 58 | response = MagicMock() 59 | response.stop_reason = "end_turn" 60 | response.usage.input_tokens = 10 61 | response.usage.output_tokens = 5 62 | content_mock = MagicMock() 63 | content_mock.type = "text" 64 | 
content_mock.text = "The weather is sunny." 65 | response.content = [content_mock] 66 | 67 | normalized_response = self.converter.convert_response(response) 68 | 69 | self.assertIsInstance(normalized_response, ChatCompletionResponse) 70 | self.assertEqual(normalized_response.choices[0].finish_reason, "stop") 71 | self.assertEqual( 72 | normalized_response.usage, 73 | {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}, 74 | ) 75 | self.assertEqual( 76 | normalized_response.choices[0].message.content, "The weather is sunny." 77 | ) 78 | 79 | # Test that - when Anthropic returns a tool use message, it is correctly converted. 80 | def test_convert_response_with_tool_use(self): 81 | response = MagicMock() 82 | response.id = "msg_01Aq9w938a90dw8q" 83 | response.model = "claude-3-5-sonnet-20241022" 84 | response.role = "assistant" 85 | response.stop_reason = "tool_use" 86 | response.usage.input_tokens = 20 87 | response.usage.output_tokens = 10 88 | tool_use_mock = MagicMock() 89 | tool_use_mock.type = "tool_use" 90 | tool_use_mock.id = "tool123" 91 | tool_use_mock.name = "get_weather" 92 | tool_use_mock.input = {"location": "Paris"} 93 | 94 | text_mock = MagicMock() 95 | text_mock.type = "text" 96 | text_mock.text = "I need to call the get_weather function" 97 | 98 | response.content = [tool_use_mock, text_mock] 99 | 100 | normalized_response = self.converter.convert_response(response) 101 | 102 | self.assertIsInstance(normalized_response, ChatCompletionResponse) 103 | self.assertEqual(normalized_response.choices[0].finish_reason, "tool_calls") 104 | self.assertEqual( 105 | normalized_response.usage, 106 | {"prompt_tokens": 20, "completion_tokens": 10, "total_tokens": 30}, 107 | ) 108 | self.assertEqual( 109 | normalized_response.choices[0].message.content, 110 | "I need to call the get_weather function", 111 | ) 112 | self.assertEqual(len(normalized_response.choices[0].message.tool_calls), 1) 113 | self.assertEqual( 114 | normalized_response.choices[0].message.tool_calls[0].id, "tool123" 115 | ) 116 | self.assertEqual( 117 | normalized_response.choices[0].message.tool_calls[0].function.name, 118 | "get_weather", 119 | ) 120 | 121 | def test_convert_tool_spec(self): 122 | openai_tools = [ 123 | { 124 | "type": "function", 125 | "function": { 126 | "name": "get_weather", 127 | "description": "Get the weather.", 128 | "parameters": { 129 | "type": "object", 130 | "properties": { 131 | "location": {"type": "string", "description": "City name."} 132 | }, 133 | "required": ["location"], 134 | }, 135 | }, 136 | } 137 | ] 138 | 139 | anthropic_tools = self.converter.convert_tool_spec(openai_tools) 140 | 141 | self.assertEqual(len(anthropic_tools), 1) 142 | self.assertEqual(anthropic_tools[0]["name"], "get_weather") 143 | self.assertEqual(anthropic_tools[0]["description"], "Get the weather.") 144 | self.assertEqual( 145 | anthropic_tools[0]["input_schema"], 146 | { 147 | "type": "object", 148 | "properties": { 149 | "location": {"type": "string", "description": "City name."} 150 | }, 151 | "required": ["location"], 152 | }, 153 | ) 154 | 155 | def test_convert_request_with_tool_call_and_result(self): 156 | messages = [ 157 | { 158 | "role": "assistant", 159 | "content": "Let me check the weather.", 160 | "tool_calls": [ 161 | { 162 | "id": "tool123", 163 | "function": { 164 | "name": "get_weather", 165 | "arguments": '{"location": "San Francisco"}', 166 | }, 167 | } 168 | ], 169 | }, 170 | {"role": "tool", "tool_call_id": "tool123", "content": "65 degrees"}, 171 | ] 172 | system_message, 
converted_messages = self.converter.convert_request(messages) 173 | 174 | self.assertEqual(system_message, []) 175 | self.assertEqual( 176 | converted_messages, 177 | [ 178 | { 179 | "role": "assistant", 180 | "content": [ 181 | {"type": "text", "text": "Let me check the weather."}, 182 | { 183 | "type": "tool_use", 184 | "id": "tool123", 185 | "name": "get_weather", 186 | "input": {"location": "San Francisco"}, 187 | }, 188 | ], 189 | }, 190 | { 191 | "role": "user", 192 | "content": [ 193 | { 194 | "type": "tool_result", 195 | "tool_use_id": "tool123", 196 | "content": "65 degrees", 197 | } 198 | ], 199 | }, 200 | ], 201 | ) 202 | 203 | 204 | if __name__ == "__main__": 205 | unittest.main() 206 | -------------------------------------------------------------------------------- /tests/providers/test_aws_converter.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from unittest.mock import MagicMock 3 | from aisuite.providers.aws_provider import BedrockMessageConverter 4 | from aisuite.framework.message import Message, ChatCompletionMessageToolCall 5 | from aisuite.framework import ChatCompletionResponse 6 | 7 | 8 | class TestBedrockMessageConverter(unittest.TestCase): 9 | 10 | def setUp(self): 11 | self.converter = BedrockMessageConverter() 12 | 13 | def test_convert_request_user_message(self): 14 | messages = [ 15 | {"role": "user", "content": "What is the most popular song on WZPZ?"} 16 | ] 17 | system_message, formatted_messages = self.converter.convert_request(messages) 18 | 19 | self.assertEqual(system_message, []) 20 | self.assertEqual(len(formatted_messages), 1) 21 | self.assertEqual(formatted_messages[0]["role"], "user") 22 | self.assertEqual( 23 | formatted_messages[0]["content"], 24 | [{"text": "What is the most popular song on WZPZ?"}], 25 | ) 26 | 27 | def test_convert_request_tool_result(self): 28 | messages = [ 29 | { 30 | "role": "tool", 31 | "tool_call_id": "tool123", 32 | "content": '{"song": "Elemental Hotel", "artist": "8 Storey Hike"}', 33 | } 34 | ] 35 | system_message, formatted_messages = self.converter.convert_request(messages) 36 | 37 | self.assertEqual(system_message, []) 38 | self.assertEqual(len(formatted_messages), 1) 39 | self.assertEqual(formatted_messages[0]["role"], "user") 40 | self.assertEqual( 41 | formatted_messages[0]["content"], 42 | [ 43 | { 44 | "toolResult": { 45 | "toolUseId": "tool123", 46 | "content": [ 47 | { 48 | "json": { 49 | "song": "Elemental Hotel", 50 | "artist": "8 Storey Hike", 51 | } 52 | } 53 | ], 54 | } 55 | } 56 | ], 57 | ) 58 | 59 | def test_convert_response_tool_call(self): 60 | response = { 61 | "output": { 62 | "message": { 63 | "role": "assistant", 64 | "content": [ 65 | { 66 | "toolUse": { 67 | "toolUseId": "tool123", 68 | "name": "top_song", 69 | "input": {"sign": "WZPZ"}, 70 | } 71 | } 72 | ], 73 | } 74 | }, 75 | "stopReason": "tool_use", 76 | } 77 | 78 | normalized_response = self.converter.convert_response(response) 79 | 80 | self.assertIsInstance(normalized_response, ChatCompletionResponse) 81 | self.assertEqual(normalized_response.choices[0].finish_reason, "tool_calls") 82 | tool_call = normalized_response.choices[0].message.tool_calls[0] 83 | self.assertEqual(tool_call.function.name, "top_song") 84 | self.assertEqual(tool_call.function.arguments, '{"sign": "WZPZ"}') 85 | 86 | def test_convert_response_text(self): 87 | response = { 88 | "output": { 89 | "message": { 90 | "role": "assistant", 91 | "content": [ 92 | { 93 | "text": "The most popular song on WZPZ is 
Elemental Hotel by 8 Storey Hike." 94 | } 95 | ], 96 | } 97 | }, 98 | "stopReason": "complete", 99 | } 100 | 101 | normalized_response = self.converter.convert_response(response) 102 | 103 | self.assertIsInstance(normalized_response, ChatCompletionResponse) 104 | self.assertEqual(normalized_response.choices[0].finish_reason, "stop") 105 | self.assertEqual( 106 | normalized_response.choices[0].message.content, 107 | "The most popular song on WZPZ is Elemental Hotel by 8 Storey Hike.", 108 | ) 109 | 110 | 111 | if __name__ == "__main__": 112 | unittest.main() 113 | -------------------------------------------------------------------------------- /tests/providers/test_azure_provider.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from aisuite.providers.azure_provider import AzureMessageConverter 3 | from aisuite.framework.message import Message, ChatCompletionMessageToolCall 4 | from aisuite.framework import ChatCompletionResponse 5 | 6 | 7 | class TestAzureMessageConverter(unittest.TestCase): 8 | def setUp(self): 9 | self.converter = AzureMessageConverter() 10 | 11 | def test_convert_request_dict_message(self): 12 | messages = [{"role": "user", "content": "Hello, how are you?"}] 13 | converted_messages = self.converter.convert_request(messages) 14 | 15 | self.assertEqual( 16 | converted_messages, [{"role": "user", "content": "Hello, how are you?"}] 17 | ) 18 | 19 | def test_convert_request_message_object(self): 20 | message = Message(role="user", content="Hello", tool_calls=None, refusal=None) 21 | messages = [message] 22 | converted_messages = self.converter.convert_request(messages) 23 | 24 | expected_message = { 25 | "role": "user", 26 | "content": "Hello", 27 | "reasoning_content": None, 28 | "tool_calls": None, 29 | "refusal": None, 30 | } 31 | self.assertEqual(converted_messages, [expected_message]) 32 | 33 | def test_convert_response_basic(self): 34 | azure_response = { 35 | "choices": [ 36 | { 37 | "message": { 38 | "role": "assistant", 39 | "content": "Hello! How can I help you?", 40 | } 41 | } 42 | ] 43 | } 44 | 45 | response = self.converter.convert_response(azure_response) 46 | 47 | self.assertIsInstance(response, ChatCompletionResponse) 48 | self.assertEqual( 49 | response.choices[0].message.content, "Hello! How can I help you?" 50 | ) 51 | self.assertEqual(response.choices[0].message.role, "assistant") 52 | self.assertIsNone(response.choices[0].message.tool_calls) 53 | 54 | def test_convert_response_with_tool_calls(self): 55 | azure_response = { 56 | "choices": [ 57 | { 58 | "message": { 59 | "role": "assistant", 60 | "content": "Let me check the weather.", 61 | "tool_calls": [ 62 | { 63 | "id": "tool123", 64 | "type": "function", 65 | "function": { 66 | "name": "get_weather", 67 | "arguments": '{"location": "London"}', 68 | }, 69 | } 70 | ], 71 | } 72 | } 73 | ] 74 | } 75 | 76 | response = self.converter.convert_response(azure_response) 77 | 78 | self.assertIsInstance(response, ChatCompletionResponse) 79 | self.assertEqual( 80 | response.choices[0].message.content, "Let me check the weather." 
81 | ) 82 | self.assertEqual(len(response.choices[0].message.tool_calls), 1) 83 | 84 | tool_call = response.choices[0].message.tool_calls[0] 85 | self.assertEqual(tool_call.id, "tool123") 86 | self.assertEqual(tool_call.type, "function") 87 | self.assertEqual(tool_call.function.name, "get_weather") 88 | self.assertEqual(tool_call.function.arguments, '{"location": "London"}') 89 | 90 | 91 | if __name__ == "__main__": 92 | unittest.main() 93 | -------------------------------------------------------------------------------- /tests/providers/test_cerebras_provider.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import MagicMock, patch 2 | 3 | import pytest 4 | 5 | from aisuite.providers.cerebras_provider import CerebrasProvider 6 | 7 | 8 | @pytest.fixture(autouse=True) 9 | def set_api_key_env_var(monkeypatch): 10 | """Fixture to set environment variables for tests.""" 11 | monkeypatch.setenv("CEREBRAS_API_KEY", "test-api-key") 12 | 13 | 14 | def test_cerebras_provider(): 15 | """High-level test that the provider is initialized and chat completions are requested successfully.""" 16 | 17 | user_greeting = "Hello!" 18 | message_history = [{"role": "user", "content": user_greeting}] 19 | selected_model = "our-favorite-model" 20 | chosen_temperature = 0.75 21 | response_text_content = "mocked-text-response-from-model" 22 | 23 | provider = CerebrasProvider() 24 | mock_response = MagicMock() 25 | mock_response.model_dump.return_value = { 26 | "choices": [{"message": {"content": response_text_content}}] 27 | } 28 | 29 | with patch.object( 30 | provider.client.chat.completions, 31 | "create", 32 | return_value=mock_response, 33 | ) as mock_create: 34 | response = provider.chat_completions_create( 35 | messages=message_history, 36 | model=selected_model, 37 | temperature=chosen_temperature, 38 | ) 39 | 40 | mock_create.assert_called_with( 41 | messages=message_history, 42 | model=selected_model, 43 | temperature=chosen_temperature, 44 | ) 45 | 46 | assert response.choices[0].message.content == response_text_content 47 | -------------------------------------------------------------------------------- /tests/providers/test_cohere_provider.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import MagicMock, patch 2 | 3 | import pytest 4 | 5 | from aisuite.providers.cohere_provider import CohereProvider 6 | 7 | 8 | @pytest.fixture(autouse=True) 9 | def set_api_key_env_var(monkeypatch): 10 | """Fixture to set environment variables for tests.""" 11 | monkeypatch.setenv("CO_API_KEY", "test-api-key") 12 | 13 | 14 | def test_cohere_provider(): 15 | """High-level test that the provider is initialized and chat completions are requested successfully.""" 16 | 17 | user_greeting = "Hello!" 
18 | message_history = [{"role": "user", "content": user_greeting}] 19 | selected_model = "our-favorite-model" 20 | chosen_temperature = 0.75 21 | response_text_content = "mocked-text-response-from-model" 22 | 23 | provider = CohereProvider() 24 | mock_response = MagicMock() 25 | mock_response.message = MagicMock() 26 | mock_response.message.content = [MagicMock()] 27 | mock_response.message.content[0].text = response_text_content 28 | 29 | with patch.object( 30 | provider.client, 31 | "chat", 32 | return_value=mock_response, 33 | ) as mock_create: 34 | response = provider.chat_completions_create( 35 | messages=message_history, 36 | model=selected_model, 37 | temperature=chosen_temperature, 38 | ) 39 | 40 | mock_create.assert_called_with( 41 | messages=message_history, 42 | model=selected_model, 43 | temperature=chosen_temperature, 44 | ) 45 | 46 | assert response.choices[0].message.content == response_text_content 47 | -------------------------------------------------------------------------------- /tests/providers/test_deepseek_provider.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import MagicMock, patch 2 | 3 | import pytest 4 | 5 | from aisuite.providers.deepseek_provider import DeepseekProvider 6 | 7 | 8 | @pytest.fixture(autouse=True) 9 | def set_api_key_env_var(monkeypatch): 10 | """Fixture to set environment variables for tests.""" 11 | monkeypatch.setenv("DEEPSEEK_API_KEY", "test-api-key") 12 | 13 | 14 | def test_groq_provider(): 15 | """High-level test that the provider is initialized and chat completions are requested successfully.""" 16 | 17 | user_greeting = "Hello!" 18 | message_history = [{"role": "user", "content": user_greeting}] 19 | selected_model = "our-favorite-model" 20 | chosen_temperature = 0.75 21 | response_text_content = "mocked-text-response-from-model" 22 | 23 | provider = DeepseekProvider() 24 | mock_response = MagicMock() 25 | mock_response.choices = [MagicMock()] 26 | mock_response.choices[0].message = MagicMock() 27 | mock_response.choices[0].message.content = response_text_content 28 | 29 | with patch.object( 30 | provider.client.chat.completions, 31 | "create", 32 | return_value=mock_response, 33 | ) as mock_create: 34 | response = provider.chat_completions_create( 35 | messages=message_history, 36 | model=selected_model, 37 | temperature=chosen_temperature, 38 | ) 39 | 40 | mock_create.assert_called_with( 41 | messages=message_history, 42 | model=selected_model, 43 | temperature=chosen_temperature, 44 | ) 45 | 46 | assert response.choices[0].message.content == response_text_content 47 | -------------------------------------------------------------------------------- /tests/providers/test_google_converter.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from unittest.mock import MagicMock 3 | from aisuite.providers.google_provider import GoogleMessageConverter 4 | from aisuite.framework.message import Message, ChatCompletionMessageToolCall, Function 5 | from aisuite.framework import ChatCompletionResponse 6 | 7 | 8 | class TestGoogleMessageConverter(unittest.TestCase): 9 | 10 | def setUp(self): 11 | self.converter = GoogleMessageConverter() 12 | 13 | def test_convert_request_user_message(self): 14 | messages = [{"role": "user", "content": "What is the weather today?"}] 15 | converted_messages = self.converter.convert_request(messages) 16 | 17 | self.assertEqual(len(converted_messages), 1) 18 | 
self.assertEqual(converted_messages[0].role, "user") 19 | self.assertEqual( 20 | converted_messages[0].parts[0].text, "What is the weather today?" 21 | ) 22 | 23 | def test_convert_request_tool_result_message(self): 24 | messages = [ 25 | { 26 | "role": "tool", 27 | "name": "get_weather", 28 | "content": '{"temperature": "15", "unit": "Celsius"}', 29 | } 30 | ] 31 | converted_messages = self.converter.convert_request(messages) 32 | 33 | self.assertEqual(len(converted_messages), 1) 34 | self.assertEqual(converted_messages[0].function_response.name, "get_weather") 35 | self.assertEqual( 36 | converted_messages[0].function_response.response, 37 | {"temperature": "15", "unit": "Celsius"}, 38 | ) 39 | 40 | def test_convert_request_assistant_message(self): 41 | messages = [ 42 | { 43 | "role": "assistant", 44 | "content": "The weather is sunny with a temperature of 25 degrees Celsius.", 45 | } 46 | ] 47 | converted_messages = self.converter.convert_request(messages) 48 | 49 | self.assertEqual(len(converted_messages), 1) 50 | self.assertEqual(converted_messages[0].role, "model") 51 | self.assertEqual( 52 | converted_messages[0].parts[0].text, 53 | "The weather is sunny with a temperature of 25 degrees Celsius.", 54 | ) 55 | 56 | def test_convert_response_with_function_call(self): 57 | function_call_mock = MagicMock() 58 | function_call_mock.name = "get_exchange_rate" 59 | function_call_mock.args = { 60 | "currency_from": "AUD", 61 | "currency_to": "SEK", 62 | "currency_date": "latest", 63 | } 64 | 65 | response = MagicMock() 66 | response.candidates = [ 67 | MagicMock( 68 | content=MagicMock(parts=[MagicMock(function_call=function_call_mock)]), 69 | finish_reason="function_call", 70 | ) 71 | ] 72 | 73 | normalized_response = self.converter.convert_response(response) 74 | 75 | self.assertIsInstance(normalized_response, ChatCompletionResponse) 76 | self.assertEqual(normalized_response.choices[0].finish_reason, "tool_calls") 77 | self.assertEqual( 78 | normalized_response.choices[0].message.tool_calls[0].function.name, 79 | "get_exchange_rate", 80 | ) 81 | self.assertEqual( 82 | normalized_response.choices[0].message.tool_calls[0].function.arguments, 83 | '{"currency_from": "AUD", "currency_to": "SEK", "currency_date": "latest"}', 84 | ) 85 | 86 | def test_convert_response_with_text(self): 87 | response = MagicMock() 88 | text_content = "The current exchange rate is 7.50 SEK per AUD." 
89 | 90 | mock_part = MagicMock() 91 | mock_part.text = text_content 92 | mock_part.function_call = None 93 | 94 | mock_content = MagicMock() 95 | mock_content.parts = [mock_part] 96 | 97 | mock_candidate = MagicMock() 98 | mock_candidate.content = mock_content 99 | mock_candidate.finish_reason = "stop" 100 | 101 | response.candidates = [mock_candidate] 102 | 103 | normalized_response = self.converter.convert_response(response) 104 | 105 | self.assertIsInstance(normalized_response, ChatCompletionResponse) 106 | self.assertEqual(normalized_response.choices[0].finish_reason, "stop") 107 | self.assertEqual(normalized_response.choices[0].message.content, text_content) 108 | 109 | 110 | if __name__ == "__main__": 111 | unittest.main() 112 | -------------------------------------------------------------------------------- /tests/providers/test_google_provider.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from unittest.mock import patch, MagicMock 3 | from aisuite.providers.google_provider import GoogleProvider 4 | from vertexai.generative_models import Content, Part 5 | import json 6 | 7 | 8 | @pytest.fixture(autouse=True) 9 | def set_api_key_env_var(monkeypatch): 10 | """Fixture to set environment variables for tests.""" 11 | monkeypatch.setenv("GOOGLE_APPLICATION_CREDENTIALS", "path-to-service-account-json") 12 | monkeypatch.setenv("GOOGLE_PROJECT_ID", "vertex-project-id") 13 | monkeypatch.setenv("GOOGLE_REGION", "us-central1") 14 | 15 | 16 | def test_missing_env_vars(): 17 | """Test that an error is raised if required environment variables are missing.""" 18 | with patch.dict("os.environ", {}, clear=True): 19 | with pytest.raises(EnvironmentError) as exc_info: 20 | GoogleProvider() 21 | assert "Missing one or more required Google environment variables" in str( 22 | exc_info.value 23 | ) 24 | 25 | 26 | def test_vertex_interface(): 27 | """High-level test that the interface is initialized and chat completions are requested successfully.""" 28 | 29 | # Test case 1: Regular text response 30 | def test_text_response(): 31 | user_greeting = "Hello!" 32 | message_history = [{"role": "user", "content": user_greeting}] 33 | selected_model = "our-favorite-model" 34 | response_text_content = "mocked-text-response-from-model" 35 | 36 | interface = GoogleProvider() 37 | mock_response = MagicMock() 38 | mock_response.candidates = [MagicMock()] 39 | mock_response.candidates[0].content.parts = [MagicMock()] 40 | mock_response.candidates[0].content.parts[0].text = response_text_content 41 | # Ensure function_call attribute doesn't exist 42 | del mock_response.candidates[0].content.parts[0].function_call 43 | 44 | with patch( 45 | "aisuite.providers.google_provider.GenerativeModel" 46 | ) as mock_generative_model: 47 | mock_model = MagicMock() 48 | mock_generative_model.return_value = mock_model 49 | mock_chat = MagicMock() 50 | mock_model.start_chat.return_value = mock_chat 51 | mock_chat.send_message.return_value = mock_response 52 | 53 | response = interface.chat_completions_create( 54 | messages=message_history, 55 | model=selected_model, 56 | temperature=0.7, 57 | ) 58 | 59 | # Assert the response is in the correct format 60 | assert response.choices[0].message.content == response_text_content 61 | assert response.choices[0].finish_reason == "stop" 62 | 63 | # Test case 2: Function call response 64 | def test_function_call(): 65 | user_greeting = "What's the weather?" 
66 | message_history = [{"role": "user", "content": user_greeting}] 67 | selected_model = "our-favorite-model" 68 | 69 | interface = GoogleProvider() 70 | mock_response = MagicMock() 71 | mock_response.candidates = [MagicMock()] 72 | mock_response.candidates[0].content.parts = [MagicMock()] 73 | 74 | # Mock the function call response 75 | function_call_mock = MagicMock() 76 | function_call_mock.name = "get_weather" 77 | function_call_mock.args = {"location": "San Francisco"} 78 | mock_response.candidates[0].content.parts[0].function_call = function_call_mock 79 | mock_response.candidates[0].content.parts[0].text = None 80 | 81 | with patch( 82 | "aisuite.providers.google_provider.GenerativeModel" 83 | ) as mock_generative_model: 84 | mock_model = MagicMock() 85 | mock_generative_model.return_value = mock_model 86 | mock_chat = MagicMock() 87 | mock_model.start_chat.return_value = mock_chat 88 | mock_chat.send_message.return_value = mock_response 89 | 90 | response = interface.chat_completions_create( 91 | messages=message_history, 92 | model=selected_model, 93 | temperature=0.7, 94 | ) 95 | 96 | # Assert the response contains the function call 97 | assert response.choices[0].message.content is None 98 | assert response.choices[0].message.tool_calls[0].type == "function" 99 | assert ( 100 | response.choices[0].message.tool_calls[0].function.name == "get_weather" 101 | ) 102 | assert json.loads( 103 | response.choices[0].message.tool_calls[0].function.arguments 104 | ) == {"location": "San Francisco"} 105 | assert response.choices[0].finish_reason == "tool_calls" 106 | 107 | # Run both test cases 108 | test_text_response() 109 | test_function_call() 110 | 111 | 112 | def test_convert_openai_to_vertex_ai(): 113 | """Test the message conversion from OpenAI format to Vertex AI format.""" 114 | interface = GoogleProvider() 115 | message = {"role": "user", "content": "Hello!"} 116 | 117 | # Use the transformer to convert the message 118 | result = interface.transformer.convert_request([message]) 119 | 120 | # Verify the conversion result 121 | assert len(result) == 1 122 | assert isinstance(result[0], Content) 123 | assert result[0].role == "user" 124 | assert len(result[0].parts) == 1 125 | assert isinstance(result[0].parts[0], Part) 126 | assert result[0].parts[0].text == "Hello!" 
127 | 128 | 129 | def test_role_conversions(): 130 | """Test that different message roles are converted correctly.""" 131 | interface = GoogleProvider() 132 | 133 | messages = [ 134 | {"role": "system", "content": "System message"}, 135 | {"role": "user", "content": "User message"}, 136 | {"role": "assistant", "content": "Assistant message"}, 137 | ] 138 | 139 | result = interface.transformer.convert_request(messages) 140 | 141 | # System and user messages should both be converted to "user" role in Vertex AI 142 | assert len(result) == 3 143 | assert result[0].role == "user" # system converted to user 144 | assert result[0].parts[0].text == "System message" 145 | 146 | assert result[1].role == "user" 147 | assert result[1].parts[0].text == "User message" 148 | 149 | assert result[2].role == "model" # assistant converted to model 150 | assert result[2].parts[0].text == "Assistant message" 151 | -------------------------------------------------------------------------------- /tests/providers/test_groq_provider.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import MagicMock, patch 2 | 3 | import pytest 4 | 5 | from aisuite.providers.groq_provider import GroqProvider 6 | 7 | 8 | @pytest.fixture(autouse=True) 9 | def set_api_key_env_var(monkeypatch): 10 | """Fixture to set environment variables for tests.""" 11 | monkeypatch.setenv("GROQ_API_KEY", "test-api-key") 12 | 13 | 14 | def test_groq_provider(): 15 | """High-level test that the provider is initialized and chat completions are requested successfully.""" 16 | 17 | user_greeting = "Hello!" 18 | message_history = [{"role": "user", "content": user_greeting}] 19 | selected_model = "our-favorite-model" 20 | chosen_temperature = 0.75 21 | response_text_content = "mocked-text-response-from-model" 22 | 23 | provider = GroqProvider() 24 | mock_response = MagicMock() 25 | mock_response.model_dump.return_value = { 26 | "choices": [{"message": {"content": response_text_content}}] 27 | } 28 | 29 | with patch.object( 30 | provider.client.chat.completions, 31 | "create", 32 | return_value=mock_response, 33 | ) as mock_create: 34 | response = provider.chat_completions_create( 35 | messages=message_history, 36 | model=selected_model, 37 | temperature=chosen_temperature, 38 | ) 39 | 40 | mock_create.assert_called_with( 41 | messages=message_history, 42 | model=selected_model, 43 | temperature=chosen_temperature, 44 | ) 45 | 46 | assert response.choices[0].message.content == response_text_content 47 | -------------------------------------------------------------------------------- /tests/providers/test_mistral_provider.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from unittest.mock import patch, MagicMock 3 | 4 | from aisuite.providers.mistral_provider import MistralProvider 5 | 6 | 7 | @pytest.fixture(autouse=True) 8 | def set_api_key_env_var(monkeypatch): 9 | """Fixture to set environment variables for tests.""" 10 | monkeypatch.setenv("MISTRAL_API_KEY", "test-api-key") 11 | 12 | 13 | def test_mistral_provider(): 14 | """High-level test that the provider is initialized and chat completions are requested successfully.""" 15 | 16 | user_greeting = "Hello!" 
17 | message_history = [{"role": "user", "content": user_greeting}] 18 | selected_model = "our-favorite-model" 19 | chosen_temperature = 0.75 20 | response_text_content = "mocked-text-response-from-model" 21 | 22 | provider = MistralProvider() 23 | mock_response = MagicMock() 24 | mock_response.model_dump.return_value = { 25 | "choices": [{"message": {"content": response_text_content}}] 26 | } 27 | 28 | with patch.object( 29 | provider.client.chat, "complete", return_value=mock_response 30 | ) as mock_create: 31 | response = provider.chat_completions_create( 32 | messages=message_history, 33 | model=selected_model, 34 | temperature=chosen_temperature, 35 | ) 36 | 37 | mock_create.assert_called_with( 38 | messages=message_history, 39 | model=selected_model, 40 | temperature=chosen_temperature, 41 | ) 42 | 43 | assert response.choices[0].message.content == response_text_content 44 | -------------------------------------------------------------------------------- /tests/providers/test_nebius_provider.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from unittest.mock import patch, MagicMock 3 | 4 | from aisuite.providers.nebius_provider import NebiusProvider 5 | 6 | 7 | @pytest.fixture(autouse=True) 8 | def set_api_key_env_var(monkeypatch): 9 | """Fixture to set environment variables for tests.""" 10 | monkeypatch.setenv("NEBIUS_API_KEY", "test-api-key") 11 | 12 | 13 | def test_nebius_provider(): 14 | """High-level test that the provider is initialized and chat completions are requested successfully.""" 15 | 16 | user_greeting = "Hello!" 17 | message_history = [{"role": "user", "content": user_greeting}] 18 | selected_model = "our-favorite-model" 19 | chosen_temperature = 0.75 20 | response_text_content = "mocked-text-response-from-model" 21 | 22 | provider = NebiusProvider() 23 | mock_response = MagicMock() 24 | mock_response.choices = [MagicMock()] 25 | mock_response.choices[0].message = MagicMock() 26 | mock_response.choices[0].message.content = response_text_content 27 | 28 | with patch.object( 29 | provider.client.chat.completions, 30 | "create", 31 | return_value=mock_response, 32 | ) as mock_create: 33 | response = provider.chat_completions_create( 34 | messages=message_history, 35 | model=selected_model, 36 | temperature=chosen_temperature, 37 | ) 38 | 39 | mock_create.assert_called_with( 40 | messages=message_history, 41 | model=selected_model, 42 | temperature=chosen_temperature, 43 | ) 44 | 45 | assert response.choices[0].message.content == response_text_content 46 | -------------------------------------------------------------------------------- /tests/providers/test_ollama_provider.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from unittest.mock import patch, MagicMock 3 | from aisuite.providers.ollama_provider import OllamaProvider 4 | 5 | 6 | @pytest.fixture(autouse=True) 7 | def set_api_url_var(monkeypatch): 8 | """Fixture to set environment variables for tests.""" 9 | monkeypatch.setenv("OLLAMA_API_URL", "http://localhost:11434") 10 | 11 | 12 | def test_completion(): 13 | """Test that completions request successfully.""" 14 | 15 | user_greeting = "Howdy!" 
16 | message_history = [{"role": "user", "content": user_greeting}] 17 | selected_model = "best-model-ever" 18 | chosen_temperature = 0.77 19 | response_text_content = "mocked-text-response-from-ollama-model" 20 | 21 | ollama = OllamaProvider() 22 | mock_response = {"message": {"content": response_text_content}} 23 | 24 | with patch( 25 | "httpx.post", 26 | return_value=MagicMock(status_code=200, json=lambda: mock_response), 27 | ) as mock_post: 28 | response = ollama.chat_completions_create( 29 | messages=message_history, 30 | model=selected_model, 31 | temperature=chosen_temperature, 32 | ) 33 | 34 | mock_post.assert_called_once_with( 35 | "http://localhost:11434/api/chat", 36 | json={ 37 | "model": selected_model, 38 | "messages": message_history, 39 | "stream": False, 40 | "temperature": chosen_temperature, 41 | }, 42 | timeout=30, 43 | ) 44 | 45 | assert response.choices[0].message.content == response_text_content 46 | -------------------------------------------------------------------------------- /tests/providers/test_sambanova_provider.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import MagicMock, patch 2 | 3 | import pytest 4 | 5 | from aisuite.providers.sambanova_provider import SambanovaProvider 6 | 7 | 8 | @pytest.fixture(autouse=True) 9 | def set_api_key_env_var(monkeypatch): 10 | """Fixture to set environment variables for tests.""" 11 | monkeypatch.setenv("SAMBANOVA_API_KEY", "test-api-key") 12 | 13 | 14 | def test_sambanova_provider(): 15 | """High-level test that the provider is initialized and chat completions are requested successfully.""" 16 | 17 | user_greeting = "Hello!" 18 | message_history = [{"role": "user", "content": user_greeting}] 19 | selected_model = "our-favorite-model" 20 | chosen_temperature = 0.75 21 | response_text_content = "mocked-text-response-from-model" 22 | 23 | provider = SambanovaProvider() 24 | mock_response = MagicMock() 25 | mock_response.model_dump.return_value = { 26 | "choices": [ 27 | {"message": {"content": response_text_content, "role": "assistant"}} 28 | ] 29 | } 30 | 31 | with patch.object( 32 | provider.client.chat.completions, 33 | "create", 34 | return_value=mock_response, 35 | ) as mock_create: 36 | response = provider.chat_completions_create( 37 | messages=message_history, 38 | model=selected_model, 39 | temperature=chosen_temperature, 40 | ) 41 | 42 | mock_create.assert_called_with( 43 | messages=message_history, 44 | model=selected_model, 45 | temperature=chosen_temperature, 46 | ) 47 | 48 | assert response.choices[0].message.content == response_text_content 49 | -------------------------------------------------------------------------------- /tests/providers/test_watsonx_provider.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import MagicMock, patch 2 | 3 | import pytest 4 | 5 | try: 6 | from ibm_watsonx_ai.metanames import GenChatParamsMetaNames as GenChatParams 7 | except Exception as e: 8 | pytest.skip(f"Skipping test due to import error: {e}", allow_module_level=True) 9 | 10 | from aisuite.providers.watsonx_provider import WatsonxProvider 11 | 12 | 13 | @pytest.fixture(autouse=True) 14 | def set_api_key_env_var(monkeypatch): 15 | """Fixture to set environment variables for tests.""" 16 | monkeypatch.setenv("WATSONX_SERVICE_URL", "https://watsonx-service-url.com") 17 | monkeypatch.setenv("WATSONX_API_KEY", "test-api-key") 18 | monkeypatch.setenv("WATSONX_PROJECT_ID", "test-project-id") 19 | 20 | 21 | 
@pytest.mark.skip(reason="Skipping due to version compatibility issue on python 3.11") 22 | def test_watsonx_provider(): 23 | """High-level test that the provider is initialized and chat completions are requested successfully.""" 24 | 25 | user_greeting = "Hello!" 26 | message_history = [{"role": "user", "content": user_greeting}] 27 | selected_model = "our-favorite-model" 28 | chosen_temperature = 0.7 29 | response_text_content = "mocked-text-response-from-model" 30 | 31 | provider = WatsonxProvider() 32 | mock_response = {"choices": [{"message": {"content": response_text_content}}]} 33 | 34 | with patch( 35 | "aisuite.providers.watsonx_provider.ModelInference" 36 | ) as mock_model_inference: 37 | mock_model = MagicMock() 38 | mock_model_inference.return_value = mock_model 39 | mock_model.chat.return_value = mock_response 40 | 41 | response = provider.chat_completions_create( 42 | messages=message_history, 43 | model=selected_model, 44 | temperature=chosen_temperature, 45 | ) 46 | 47 | # Assert that ModelInference was called with correct arguments. 48 | mock_model_inference.assert_called_once() 49 | args, kwargs = mock_model_inference.call_args 50 | assert kwargs["model_id"] == selected_model 51 | assert kwargs["project_id"] == provider.project_id 52 | 53 | # Assert that the credentials have the correct API key and service URL. 54 | credentials = kwargs["credentials"] 55 | assert credentials.api_key == provider.api_key 56 | assert credentials.url == provider.service_url 57 | 58 | # Assert that chat was called with correct history and params 59 | mock_model.chat.assert_called_once_with( 60 | messages=message_history, 61 | params={GenChatParams.TEMPERATURE: chosen_temperature}, 62 | ) 63 | 64 | assert response.choices[0].message.content == response_text_content 65 | -------------------------------------------------------------------------------- /tests/utils/test_tool_manager.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from pydantic import BaseModel 3 | from typing import Dict 4 | from aisuite.utils.tools import Tools # Import your ToolManager class 5 | from enum import Enum 6 | 7 | 8 | # Define a sample tool function and Pydantic model for testing 9 | class TemperatureUnit(str, Enum): 10 | CELSIUS = "Celsius" 11 | FAHRENHEIT = "Fahrenheit" 12 | 13 | 14 | class TemperatureParamsV2(BaseModel): 15 | location: str 16 | unit: TemperatureUnit = TemperatureUnit.CELSIUS 17 | 18 | 19 | class TemperatureParams(BaseModel): 20 | location: str 21 | unit: str = "Celsius" 22 | 23 | 24 | def get_current_temperature(location: str, unit: str = "Celsius") -> Dict[str, str]: 25 | """Gets the current temperature for a specific location and unit.""" 26 | return {"location": location, "unit": unit, "temperature": "72"} 27 | 28 | 29 | def missing_annotation_tool(location, unit="Celsius"): 30 | """Tool function without type annotations.""" 31 | return {"location": location, "unit": unit, "temperature": "72"} 32 | 33 | 34 | def get_current_temperature_v2( 35 | location: str, unit: TemperatureUnit = TemperatureUnit.CELSIUS 36 | ) -> Dict[str, str]: 37 | """Gets the current temperature for a specific location and unit (with enum support).""" 38 | return {"location": location, "unit": unit, "temperature": "72"} 39 | 40 | 41 | class TestToolManager(unittest.TestCase): 42 | def setUp(self): 43 | self.tool_manager = Tools() 44 | 45 | def test_add_tool_with_pydantic_model(self): 46 | """Test adding a tool with an explicit Pydantic model.""" 47 | 
self.tool_manager._add_tool(get_current_temperature, TemperatureParams) 48 | 49 | expected_tool_spec = [ 50 | { 51 | "type": "function", 52 | "function": { 53 | "name": "get_current_temperature", 54 | "description": "Gets the current temperature for a specific location and unit.", 55 | "parameters": { 56 | "type": "object", 57 | "properties": { 58 | "location": { 59 | "type": "string", 60 | "description": "", 61 | }, 62 | "unit": { 63 | "type": "string", 64 | "description": "", 65 | "default": "Celsius", 66 | }, 67 | }, 68 | "required": ["location"], 69 | }, 70 | }, 71 | } 72 | ] 73 | 74 | tools = self.tool_manager.tools() 75 | self.assertIn( 76 | "get_current_temperature", [tool["function"]["name"] for tool in tools] 77 | ) 78 | assert ( 79 | tools == expected_tool_spec 80 | ), f"Expected {expected_tool_spec}, but got {tools}" 81 | 82 | def test_add_tool_with_signature_inference(self): 83 | """Test adding a tool and inferring parameters from the function signature.""" 84 | self.tool_manager._add_tool(get_current_temperature) 85 | # Expected output from tool_manager.tools() when called with OpenAI format 86 | expected_tool_spec = [ 87 | { 88 | "type": "function", 89 | "function": { 90 | "name": "get_current_temperature", 91 | "description": "Gets the current temperature for a specific location and unit.", 92 | "parameters": { 93 | "type": "object", 94 | "properties": { 95 | "location": { 96 | "type": "string", 97 | "description": "", # No description provided in function signature 98 | }, 99 | "unit": { 100 | "type": "string", 101 | "description": "", 102 | "default": "Celsius", 103 | }, 104 | }, 105 | "required": ["location"], 106 | }, 107 | }, 108 | } 109 | ] 110 | tools = self.tool_manager.tools() 111 | print(tools) 112 | self.assertIn( 113 | "get_current_temperature", [tool["function"]["name"] for tool in tools] 114 | ) 115 | assert ( 116 | tools == expected_tool_spec 117 | ), f"Expected {expected_tool_spec}, but got {tools}" 118 | 119 | def test_add_tool_missing_annotation_raises_exception(self): 120 | """Test that adding a tool with missing type annotations raises a TypeError.""" 121 | with self.assertRaises(TypeError): 122 | self.tool_manager._add_tool(missing_annotation_tool) 123 | 124 | def test_execute_tool_valid_parameters(self): 125 | """Test executing a registered tool with valid parameters.""" 126 | self.tool_manager._add_tool(get_current_temperature, TemperatureParams) 127 | tool_call = { 128 | "id": "call_1", 129 | "function": { 130 | "name": "get_current_temperature", 131 | "arguments": {"location": "San Francisco", "unit": "Celsius"}, 132 | }, 133 | } 134 | result, result_message = self.tool_manager.execute_tool(tool_call) 135 | 136 | # Assuming result is returned as a list with a single dictionary 137 | result_dict = result[0] if isinstance(result, list) else result 138 | 139 | # Check that the result matches expected output 140 | self.assertEqual(result_dict["location"], "San Francisco") 141 | self.assertEqual(result_dict["unit"], "Celsius") 142 | self.assertEqual(result_dict["temperature"], "72") 143 | 144 | def test_execute_tool_invalid_parameters(self): 145 | """Test that executing a tool with invalid parameters raises a ValueError.""" 146 | self.tool_manager._add_tool(get_current_temperature, TemperatureParams) 147 | tool_call = { 148 | "id": "call_1", 149 | "function": { 150 | "name": "get_current_temperature", 151 | "arguments": {"location": 123}, # Invalid type for location 152 | }, 153 | } 154 | 155 | with self.assertRaises(ValueError) as context: 156 | 
self.tool_manager.execute_tool(tool_call) 157 | 158 | # Verify the error message contains information about the validation error 159 | self.assertIn( 160 | "Error in tool 'get_current_temperature' parameters", str(context.exception) 161 | ) 162 | 163 | def test_add_tool_with_enum(self): 164 | """Test adding a tool with an enum parameter.""" 165 | self.tool_manager._add_tool(get_current_temperature_v2, TemperatureParamsV2) 166 | 167 | expected_tool_spec = [ 168 | { 169 | "type": "function", 170 | "function": { 171 | "name": "get_current_temperature_v2", 172 | "description": "Gets the current temperature for a specific location and unit (with enum support).", 173 | "parameters": { 174 | "type": "object", 175 | "properties": { 176 | "location": { 177 | "type": "string", 178 | "description": "", 179 | }, 180 | "unit": { 181 | "type": "string", 182 | "enum": ["Celsius", "Fahrenheit"], 183 | "description": "", 184 | "default": "Celsius", 185 | }, 186 | }, 187 | "required": ["location"], 188 | }, 189 | }, 190 | } 191 | ] 192 | 193 | tools = self.tool_manager.tools() 194 | assert ( 195 | tools == expected_tool_spec 196 | ), f"Expected {expected_tool_spec}, but got {tools}" 197 | 198 | 199 | if __name__ == "__main__": 200 | unittest.main() 201 | --------------------------------------------------------------------------------