├── src └── gigasmol │ ├── gigachat_api │ ├── __init__.py │ ├── auth.py │ └── api_model.py │ ├── __init__.py │ └── models.py ├── .gitignore ├── assets └── logo.png ├── pyproject.toml ├── README.md ├── LICENSE └── examples └── structured_output.ipynb /src/gigasmol/gigachat_api/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .DS_Store 3 | credentials.json 4 | *.egg-info/ 5 | dist/ -------------------------------------------------------------------------------- /assets/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/poteminr/gigasmol/HEAD/assets/logo.png -------------------------------------------------------------------------------- /src/gigasmol/__init__.py: -------------------------------------------------------------------------------- 1 | """GigaSmol - Lightweight GigaChat API wrapper for smolagents""" 2 | 3 | from gigasmol.gigachat_api.api_model import GigaChat 4 | 5 | __version__ = "0.0.8" 6 | 7 | __all__ = ["GigaChat"] 8 | 9 | try: 10 | from gigasmol.models import GigaChatSmolModel 11 | __all__.append("GigaChatSmolModel") 12 | except ImportError: 13 | # For API-only installations where smolagents is not installed 14 | class GigaChatSmolModel: 15 | """This class requires the smolagents package. 16 | 17 | Install with: pip install gigasmol 18 | """ 19 | def __init__(self, *args, **kwargs): 20 | raise ImportError( 21 | 'The smolagents package is required to use GigaChatSmolModel. ' 22 | 'Please install the full package with `pip install "gigasmol[agent]"`.' 
23 | ) 24 | 25 | __all__.append("GigaChatSmolModel") 26 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "gigasmol" 7 | version = "0.0.8" 8 | authors = [ 9 | {name = "poteminr", email = "poteminr@gmail.com"}, 10 | ] 11 | description = "A lightweight wrapper for gigachat api model for seamless use with hf smolagents" 12 | readme = "README.md" 13 | requires-python = ">=3.8" 14 | license = {text = "LICENSE"} 15 | # Base dependencies - ONLY the API parts 16 | dependencies = [ 17 | "requests>=2.31.0", 18 | "sseclient>=0.0.27", 19 | "urllib3>=2.0.0", 20 | ] 21 | 22 | [project.urls] 23 | "Homepage" = "https://github.com/poteminr/gigasmol" 24 | "Bug Tracker" = "https://github.com/poteminr/gigasmol/issues" 25 | 26 | [tool.setuptools] 27 | package-dir = {"" = "src"} 28 | 29 | [project.optional-dependencies] 30 | # Agent functionality requires Python 3.10+ 31 | agent = [ 32 | "smolagents==1.22.0; python_version >= '3.10'", 33 | "huggingface-hub>=0.19.0; python_version >= '3.10'", 34 | "gigasmol-agent-requires-python-3.10-or-newer==0.0.0; python_version < '3.10'", 35 | ] -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | 7 | 8 |
9 | 10 |
11 | GigaSmol Logo 12 |

lightweight gigachat api wrapper for smolagents

13 |
14 | 15 | ## Overview 16 | 17 | gigasmol serves two primary purposes: 18 | 19 | 1. Provides **direct, lightweight access** to GigaChat models through GigaChat API without unnecessary abstractions 20 | 2. Creates a **smolagents-compatible wrapper** that lets you use GigaChat within agent systems 21 | 22 | No complex abstractions — just clean, straightforward access to GigaChat's capabilities through smolagents. 23 | 24 | ``` 25 | GigaChat API + smolagents = gigasmol 💀 26 | ``` 27 | 28 | ## Why gigasmol 💀? 29 | 30 | - **Tiny Footprint**: Less than 1K lines of code total 31 | - **Simple Structure**: Just 4 core files 32 | - **Zero Bloat**: Only essential dependencies 33 | - **Easy to Understand**: Read and comprehend the entire codebase in minutes 34 | - **Maintainable**: Small, focused codebase means fewer bugs and easier updates 35 | ## Installation 36 | ### API-Only Installation (default) 37 | `python>=3.8` 38 | ```bash 39 | pip install gigasmol 40 | ``` 41 | 42 | ### Full Installation with Agent Support 43 | `python>=3.10` 44 | ```bash 45 | pip install "gigasmol[agent]" 46 | ``` 47 | 48 | 49 | ## Quick Start 50 | ### Raw GigaChat API 51 | `gigasmol` 52 | 53 | 54 | ```python 55 | import json 56 | from gigasmol import GigaChat 57 | 58 | # Direct access to GigaChat API 59 | gigachat = GigaChat( 60 | auth_data="YOUR_AUTH_TOKEN", 61 | model_name="GigaChat-Max", 62 | ) 63 | 64 | # Generate a response 65 | response = gigachat.chat([ 66 | {"role": "user", "content": "What is the capital of Russia?"} 67 | ]) 68 | print(response['answer']) # or print(response['response']['choices'][0]['message']['content']) 69 | ``` 70 | ### Usage with smolagents 71 | `gigasmol[agent]` 72 | 73 | ```python 74 | from gigasmol import GigaChatSmolModel 75 | from smolagents import CodeAgent, ToolCallingAgent, DuckDuckGoSearchTool 76 | 77 | # Initialize the GigaChat model with your credentials 78 | model = GigaChatSmolModel( 79 | auth_data="YOUR_AUTH_TOKEN", 80 | model_name="GigaChat-Max" 81 | ) 
82 | 83 | # Create a CodeAgent with the model 84 | code_agent = CodeAgent( 85 | tools=[DuckDuckGoSearchTool()], 86 | model=model 87 | ) 88 | 89 | # Run the code_agent 90 | code_agent.run("What are the main tourist attractions in Moscow?") 91 | 92 | # Create a ToolCallingAgent with the model 93 | tool_calling_agent = ToolCallingAgent( 94 | tools=[DuckDuckGoSearchTool()], 95 | model=model 96 | ) 97 | 98 | # Run the tool_calling_agent 99 | tool_calling_agent.run("What are the main tourist attractions in Moscow?") 100 | ``` 101 | 102 | 103 | 104 | ## How It Works 105 | 106 | GigaSmol provides two layers of functionality: 107 | 108 | ``` 109 | ┌───────────────────────────────────────────────────┐ 110 | │ gigasmol │ 111 | ├───────────────────────────────────────────────────┤ 112 | │ ┌───────────────┐ ┌───────────────────┐ │ 113 | │ │ Direct │ │ smolagents │ │ 114 | │ │ GigaChat API │ │ compatibility │ │ 115 | │ │ access │ │ layer │ │ 116 | │ └───────────────┘ └───────────────────┘ │ 117 | └───────────────────────────────────────────────────┘ 118 | │ │ 119 | ▼ ▼ 120 | ┌─────────────┐ ┌────────────────┐ 121 | │ GigaChat API│ │ Agent systems │ 122 | └─────────────┘ └────────────────┘ 123 | ``` 124 | 125 | 1. **Direct API Access**: Use `GigaChat` for clean, direct access to the API 126 | 2. 
# **smolagents Integration**: Use `GigaChatSmolModel` to plug GigaChat into smolagents
#
#
# ## Examples
#
# Check the `examples` directory:
# - `structured_output.ipynb`: Using GigaChat API and function_calling for structured output
# - `agents.ipynb`: Building code and tool agents with GigaChat and smolagents
#
# ## Acknowledgements
#
# - [SberDevices](https://gigachat.ru/) for creating the GigaChat API
# - [Hugging Face](https://huggingface.co/) for the smolagents framework
#
# --------------------------------------------------------------------------------
# /src/gigasmol/gigachat_api/auth.py:
# --------------------------------------------------------------------------------
"""Authorization helpers that obtain and refresh GigaChat API access tokens."""

from abc import ABC, abstractmethod
import datetime
import logging
from typing import Literal, Optional, Union
import uuid

import requests
import urllib3
from urllib3.exceptions import InsecureRequestWarning

_logger = logging.getLogger(__name__)

# Upper bound (seconds) for a single token request. Without a timeout
# `requests` waits indefinitely on a stalled endpoint.
_TOKEN_REQUEST_TIMEOUT = 30

_AuthScope = Literal["GIGACHAT_API_PERS", "GIGACHAT_API_CORP", "GIGACHAT_API_B2B"]


class APIAuthorize(ABC):
    """Interface for authorization and obtaining access to various APIs.

    Each concrete authorization class receives user data and issues an API token.
    """

    @property
    @abstractmethod
    def token(self) -> str:
        """Returns the token for API access.

        Returns:
            str: Valid authentication token.
        """

    @property
    @abstractmethod
    def cert_path(self) -> Union[str, bool]:
        """Returns the path to the certificate for API access.

        Returns:
            Path to the certificate file. ``SberDSAuthorize`` returns ``False``
            instead of a path, hence the widened annotation.
        """


class LLMAuthorizeAdvanced(APIAuthorize):
    """Access to GigaChat via username/password authentication."""

    def __init__(
        self,
        username: str,
        password: str,
        auth_endpoint: str = "https://beta.saluteai.sberdevices.ru/v1/"
    ) -> None:
        """Initialize with GigaChat API access parameters and fetch a first token.

        Args:
            username: GigaChat API username
            password: GigaChat API password
            auth_endpoint: Base URL of the API; ``"token"`` is appended to it
                verbatim, so it is expected to end with ``/``.

        Raises:
            requests.RequestException: The token request failed at network level.
            ValueError: The endpoint answered with a non-200 status or a body
                without the ``tok`` / ``exp`` fields.
        """
        self.auth_endpoint = auth_endpoint
        self.__username = username
        self.__password = password
        # datetime.min guarantees the very first check treats the token as expired.
        self.__token_expiration_time = datetime.datetime.min
        self._token = ""
        self._set_token()

        urllib3.disable_warnings(InsecureRequestWarning)

    @property
    def token(self) -> str:
        """Get the current valid token, refreshing if necessary.

        Returns:
            str: Valid authentication token.
        """
        self._check_token()
        return self._token

    @token.setter
    def token(self, value: str) -> None:
        """Set the token value.

        Args:
            value: Token value to set.
        """
        self._token = value

    @property
    def cert_path(self) -> str:
        """Get the certificate path (not used in this implementation).

        Returns:
            str: Empty string as this implementation doesn't use certificates.
        """
        return ""

    @cert_path.setter
    def cert_path(self, value: str) -> None:
        """Certificate path setter (not used in this implementation).

        Args:
            value: Certificate path (ignored).
        """

    def _check_token(self) -> None:
        """Check if the current token is expired and refresh if needed."""
        if datetime.datetime.now() > self.__token_expiration_time:
            self._set_token()

    def _set_token(self) -> None:
        """Obtain a fresh access token from the API.

        Raises:
            requests.RequestException: Network-level failure (including timeout).
            ValueError: Non-200 response or unexpected response body.
        """
        # The endpoint configured in __init__ is the only one this class has;
        # the previous `self.API_ENDPOINT` attribute was never defined.
        token_url = f"{self.auth_endpoint}token"
        try:
            _logger.info(
                "Getting token from GigaChat API",
                extra={"url": token_url, "username": self.__username}
            )
            response = requests.post(
                token_url,
                auth=(self.__username, self.__password),
                timeout=_TOKEN_REQUEST_TIMEOUT,
            )
            if response.status_code != 200:
                raise ValueError(
                    f"Failed to get token: HTTP {response.status_code} - {response.text}"
                )

            data = response.json()
            if "tok" not in data or "exp" not in data:
                raise ValueError("Incorrect response format from GigaChat API")

            # `_check_token` compares against naive *local* `datetime.now()`, so
            # the expiry must be naive local time too (not UTC).
            expires_at = datetime.datetime.fromtimestamp(data["exp"])
            # Assign both fields only after both values were parsed successfully.
            self._token = data["tok"]
            self.__token_expiration_time = expires_at
            _logger.info("Successfully obtained authentication token")
        except requests.RequestException as e:
            _logger.error("Network error during token retrieval: %s", e)
            raise
        except ValueError as e:
            _logger.error("Error processing token response: %s", e)
            raise
        except Exception as e:
            _logger.error("Unexpected error during token retrieval: %s", e)
            raise


class LLMAuthorizeEnablers(APIAuthorize):
    """Access to GigaChat via client_id/client_secret with certificate."""

    def __init__(
        self,
        auth_data: str,
        client_id: Optional[str] = None,
        auth_endpoint: str = "https://ngw.devices.sberbank.ru:9443/api/v2/oauth",
        auth_scope: _AuthScope = "GIGACHAT_API_PERS",
        cert_path: Optional[str] = None
    ) -> None:
        """Initialize with GigaChat API access parameters and fetch a first token.

        Args:
            auth_data: Authorization key for exchanging messages with GigaChat API
            client_id: GigaChat API client ID (used as RqUID); a random UUID4 is
                generated when omitted
            auth_endpoint: The authentication endpoint URL
            auth_scope: The authentication scope. Contains information about the API version
                being accessed. If you are using the API version for individual
                entrepreneurs or legal entities, specify this explicitly in the scope
                parameter (GIGACHAT_API_PERS, GIGACHAT_API_CORP, or GIGACHAT_API_B2B)
            cert_path: Path to the certificate for GigaChat API access

        Raises:
            requests.RequestException: The token request failed at network level.
            ValueError: The endpoint answered with a non-200 status or a body
                without the ``access_token`` / ``expires_at`` fields.
        """
        self.auth_endpoint = auth_endpoint
        self.auth_scope = auth_scope
        self.__client_id = client_id if client_id is not None else str(uuid.uuid4())
        self.__auth_data = auth_data
        # datetime.min guarantees the very first check treats the token as expired.
        self.__token_expiration_time = datetime.datetime.min
        self._cert_path = cert_path if cert_path is not None else ""
        self._token = ""
        self._set_token()

    @property
    def token(self) -> str:
        """Get the current valid token, refreshing if necessary.

        Returns:
            str: Valid authentication token.
        """
        self._check_token()
        return self._token

    @token.setter
    def token(self, value: str) -> None:
        """Set the token value.

        Args:
            value: Token value to set.
        """
        self._token = value

    @property
    def cert_path(self) -> str:
        """Get the path to the certificate.

        Returns:
            str: Path to the certificate file ("" when none was configured).
        """
        return self._cert_path

    @cert_path.setter
    def cert_path(self, value: str) -> None:
        """Set the certificate path.

        Args:
            value: Certificate path to set.
        """
        self._cert_path = value

    def _check_token(self) -> None:
        """Check if the current token is expired and refresh if needed."""
        if datetime.datetime.now() > self.__token_expiration_time:
            self._set_token()

    def _set_token(self) -> None:
        """Obtain a fresh access token from the API.

        Raises:
            requests.RequestException: Network-level failure (including timeout).
            ValueError: Non-200 response or unexpected response body.
        """
        # NOTE(review): the public GigaChat docs appear to show the `Basic`
        # scheme for this header — confirm that `Bearer` is accepted.
        headers = {
            "Authorization": f"Bearer {self.__auth_data}",
            "RqUID": self.__client_id,
            "Content-Type": "application/x-www-form-urlencoded",
        }
        try:
            _logger.info(
                "Getting token from GigaChat Auth API",
                extra={"url": self.auth_endpoint, "client_id": self.__client_id}
            )
            # NOTE(review): when no certificate was configured this passes
            # verify="" — confirm how the installed `requests` version treats it.
            response = requests.post(
                self.auth_endpoint,
                data={"scope": self.auth_scope},
                headers=headers,
                verify=self._cert_path,
                timeout=_TOKEN_REQUEST_TIMEOUT,
            )
            if response.status_code != 200:
                raise ValueError(
                    f"Failed to get token: HTTP {response.status_code} - {response.text}"
                )

            data = response.json()
            if "access_token" not in data or "expires_at" not in data:
                raise ValueError("Incorrect response format from GigaChat Auth API")

            # The value is divided by 1000 before use, i.e. it is treated as
            # milliseconds since the epoch.
            expiry_timestamp = int(data["expires_at"]) / 1000
            expires_at = datetime.datetime.fromtimestamp(expiry_timestamp)
            # Assign both fields only after both values were parsed successfully.
            self._token = data["access_token"]
            self.__token_expiration_time = expires_at

            _logger.info("Successfully obtained authentication token")
        except requests.RequestException as e:
            _logger.error("Network error during token retrieval: %s", e)
            raise
        except ValueError as e:
            _logger.error("Error processing token response: %s", e)
            raise
        except Exception as e:
            _logger.error("Unexpected error during token retrieval: %s", e)
            raise


class SberDSAuthorize(APIAuthorize):
    """Authorization stub that supplies no token and no certificate.

    NOTE(review): presumably meant for deployments where the endpoint needs
    neither a bearer token nor TLS verification — confirm against the caller.
    """

    @property
    def cert_path(self) -> bool:
        """Return ``False`` (no certificate path)."""
        return False

    @property
    def token(self) -> str:
        """Return an empty token string."""
        return ""
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /examples/structured_output.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# !pip install gigasmol -q" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 2, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "%load_ext autoreload\n", 19 | "%autoreload 2\n", 20 | " \n", 21 | "import warnings\n", 22 | "warnings.filterwarnings('ignore')\n", 23 | "\n", 24 | "import json\n", 25 | "from gigasmol import GigaChat" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "metadata": {}, 31 | "source": [ 32 | "### Initialize API model for structured outputs via function calling (without agents)" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 3, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "credentials = json.load(open('credentials.json'))\n", 42 | "\n", 43 | "model = GigaChat(\n", 44 | " 
auth_data=credentials['auth_data'],\n", 45 | " model_name=\"GigaChat-Max\",\n", 46 | " api_endpoint=\"https://gigachat.devices.sberbank.ru/api/v1/\", # \"https://gigachat-preview.devices.sberbank.ru/api/v1/\" \n", 47 | " temperature=0.0000001,\n", 48 | " top_p=0.1,\n", 49 | " repetition_penalty=1.1,\n", 50 | " max_tokens=1024,\n", 51 | " profanity_check=False,\n", 52 | " auth_scope=\"GIGACHAT_API_CORP\",\n", 53 | ")" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "metadata": {}, 59 | "source": [ 60 | "### Define schema for structured entity extraction" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 4, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "schema = {\n", 70 | " \"name\": \"extract_complex_entities\",\n", 71 | " \"description\": \"Extracts complex, nested entity information from text.\",\n", 72 | " \"parameters\": {\n", 73 | " \"type\": \"object\",\n", 74 | " \"properties\": {\n", 75 | " \"persons\": {\n", 76 | " \"type\": \"array\",\n", 77 | " \"description\": \"List of persons found in the text with nested properties.\",\n", 78 | " \"items\": {\n", 79 | " \"type\": \"object\",\n", 80 | " \"properties\": {\n", 81 | " \"name\": {\n", 82 | " \"type\": \"string\",\n", 83 | " \"description\": \"Full name of the person\"\n", 84 | " },\n", 85 | " \"birthPlace\": {\n", 86 | " \"type\": \"string\",\n", 87 | " \"description\": \"Birthplace of the person\"\n", 88 | " },\n", 89 | " \"roles\": {\n", 90 | " \"type\": \"array\",\n", 91 | " \"description\": \"Roles or titles the person has held\",\n", 92 | " \"items\": {\n", 93 | " \"type\": \"string\"\n", 94 | " }\n", 95 | " },\n", 96 | " \"education\": {\n", 97 | " \"type\": \"array\",\n", 98 | " \"description\": \"List of institutions where the person studied\",\n", 99 | " \"items\": {\n", 100 | " \"type\": \"string\"\n", 101 | " }\n", 102 | " }\n", 103 | " },\n", 104 | " \"required\": [\"name\"]\n", 105 | " }\n", 106 | " },\n", 107 | " \"organizations\": {\n", 108 | 
" \"type\": \"array\",\n", 109 | " \"description\": \"List of organizations with additional details.\",\n", 110 | " \"items\": {\n", 111 | " \"type\": \"object\",\n", 112 | " \"properties\": {\n", 113 | " \"orgName\": {\n", 114 | " \"type\": \"string\",\n", 115 | " \"description\": \"Name of the organization\"\n", 116 | " },\n", 117 | " \"orgType\": {\n", 118 | " \"type\": \"string\",\n", 119 | " \"description\": \"Type of organization (e.g., university, government, company)\"\n", 120 | " },\n", 121 | " \"location\": {\n", 122 | " \"type\": \"string\",\n", 123 | " \"description\": \"Location of the organization\"\n", 124 | " }\n", 125 | " },\n", 126 | " \"required\": [\"orgName\"]\n", 127 | " }\n", 128 | " },\n", 129 | " \"locations\": {\n", 130 | " \"type\": \"array\",\n", 131 | " \"description\": \"List of location entities with details.\",\n", 132 | " \"items\": {\n", 133 | " \"type\": \"object\",\n", 134 | " \"properties\": {\n", 135 | " \"placeName\": {\n", 136 | " \"type\": \"string\",\n", 137 | " \"description\": \"Name of the place\"\n", 138 | " },\n", 139 | " \"country\": {\n", 140 | " \"type\": \"string\",\n", 141 | " \"description\": \"Country of the place\"\n", 142 | " }\n", 143 | " },\n", 144 | " \"required\": [\"placeName\"]\n", 145 | " }\n", 146 | " }\n", 147 | " },\n", 148 | " \"required\": [\"persons\", \"organizations\", \"locations\"]\n", 149 | " }\n", 150 | "}" 151 | ] 152 | }, 153 | { 154 | "cell_type": "markdown", 155 | "metadata": {}, 156 | "source": [ 157 | "### Extracting Structured Entities from Text\n" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": 5, 163 | "metadata": {}, 164 | "outputs": [], 165 | "source": [ 166 | "system = (\n", 167 | " \"You are a helpful assistant that extracts complex entity information from text \"\n", 168 | " \"and returns ONLY JSON.\"\n", 169 | ")\n", 170 | "\n", 171 | "text = \"\"\"\n", 172 | "“Bill Gates, the co-founder of Microsoft, was born in Seattle, Washington, in 1955. 
He initially enrolled at Harvard University to study pre-law, but he shifted his focus to mathematics and computer science before dropping out. Along with Paul Allen, he founded Microsoft in 1975, which rapidly grew into one of the world’s largest software companies. Over the years, Bill Gates became a notable philanthropist through the Bill & Melinda Gates Foundation, focusing on global health and education initiatives.\n", 173 | "\n", 174 | "\tMeanwhile, Mark Zuckerberg, born in 1984, is the co-founder and CEO of Facebook, now known as Meta. He developed the initial version of the platform while studying computer science at Harvard University, though he never completed his degree. Under his leadership, Facebook rapidly expanded into a worldwide social media giant. Zuckerberg also launched the Chan Zuckerberg Initiative in collaboration with his wife, Dr. Priscilla Chan, to tackle issues related to education, healthcare, and scientific research.\n", 175 | "\n", 176 | "\tOn the other hand, Elon Musk, born in Pretoria, South Africa, in 1971, is known for founding SpaceX and co-founding Tesla. After moving to Canada, he later studied at the University of Pennsylvania in the United States, where he earned degrees in physics and economics. Musk’s ventures have often focused on emerging technologies and innovation—from electric vehicles and solar energy solutions at Tesla to space exploration and rocket technology at SpaceX. 
In recent years, he has also been actively involved in artificial intelligence research, among other futuristic endeavors.”\n", 177 | "\"\"\"\n", 178 | "\n", 179 | "\n", 180 | "prompt = (\n", 181 | " \"The user provides the following text:\\n\\n\"\n", 182 | " f\"{text}\\n\"\n", 183 | " \"Extract the following complex information:\\n\"\n", 184 | " \" - persons: with name, birthPlace, roles (array of strings), and education (array of schools),\\n\"\n", 185 | " \" - organizations: with orgName, orgType, and location,\\n\"\n", 186 | " \" - locations: with placeName, country.\\n\\n\"\n", 187 | " \"Return ONLY JSON containing arrays for 'persons', 'organizations', and 'locations'. \"\n", 188 | " \"Each 'person' should have nested fields such as 'name', 'birthPlace', 'roles', \"\n", 189 | " \"and an array of 'education' items. Each 'organization' has 'orgName', 'orgType', 'location'. \"\n", 190 | " \"Each 'location' has 'placeName' and 'country'.\"\n", 191 | ")" 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": 6, 197 | "metadata": {}, 198 | "outputs": [], 199 | "source": [ 200 | "messages = [\n", 201 | " {\"role\": \"system\", \"content\": system},\n", 202 | " {\"role\": \"user\", \"content\": prompt}\n", 203 | "]" 204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "execution_count": 7, 209 | "metadata": {}, 210 | "outputs": [], 211 | "source": [ 212 | "response_complete = model.chat(\n", 213 | " messages=messages,\n", 214 | " functions=[schema],\n", 215 | " function_call={\"name\": \"extract_complex_entities\"}\n", 216 | ")" 217 | ] 218 | }, 219 | { 220 | "cell_type": "code", 221 | "execution_count": 8, 222 | "metadata": {}, 223 | "outputs": [ 224 | { 225 | "data": { 226 | "text/plain": [ 227 | "{'locations': [{'country': 'United States',\n", 228 | " 'placeName': 'Seattle, Washington'},\n", 229 | " {'country': 'United States', 'placeName': 'White Plains, New York'},\n", 230 | " {'country': 'South Africa', 'placeName': 'Pretoria, South 
Africa'},\n", 231 | " {'country': 'United States', 'placeName': 'Menlo Park, California'},\n", 232 | " {'country': 'United States', 'placeName': 'Hawthorne, California'},\n", 233 | " {'country': 'United States', 'placeName': 'Palo Alto, California'}],\n", 234 | " 'organizations': [{'location': 'Seattle, Washington',\n", 235 | " 'orgName': 'Microsoft',\n", 236 | " 'orgType': 'software company'},\n", 237 | " {'location': 'Menlo Park, California',\n", 238 | " 'orgName': 'Facebook (Meta)',\n", 239 | " 'orgType': 'social media company'},\n", 240 | " {'location': 'Hawthorne, California',\n", 241 | " 'orgName': 'SpaceX',\n", 242 | " 'orgType': 'aerospace manufacturer and space transportation services company'},\n", 243 | " {'location': 'Palo Alto, California',\n", 244 | " 'orgName': 'Tesla',\n", 245 | " 'orgType': 'electric vehicle and clean energy company'}],\n", 246 | " 'persons': [{'birthPlace': 'Seattle, Washington',\n", 247 | " 'education': ['Harvard University'],\n", 248 | " 'name': 'Bill Gates',\n", 249 | " 'roles': ['co-founder of Microsoft', 'philanthropist']},\n", 250 | " {'birthPlace': 'White Plains, New York',\n", 251 | " 'education': ['Harvard University'],\n", 252 | " 'name': 'Mark Zuckerberg',\n", 253 | " 'roles': ['co-founder and CEO of Facebook (Meta)']},\n", 254 | " {'birthPlace': 'Pretoria, South Africa',\n", 255 | " 'education': ['University of Pennsylvania'],\n", 256 | " 'name': 'Elon Musk',\n", 257 | " 'roles': ['founder of SpaceX', 'co-founder of Tesla']}]}" 258 | ] 259 | }, 260 | "execution_count": 8, 261 | "metadata": {}, 262 | "output_type": "execute_result" 263 | } 264 | ], 265 | "source": [ 266 | "response_complete['response']['choices'][0]['message']['function_call']['arguments']" 267 | ] 268 | } 269 | ], 270 | "metadata": { 271 | "kernelspec": { 272 | "display_name": "Python 3", 273 | "language": "python", 274 | "name": "python3" 275 | }, 276 | "language_info": { 277 | "codemirror_mode": { 278 | "name": "ipython", 279 | "version": 3 280 | }, 
def remove_stop_sequences(content: str, stop_sequences: List[str]) -> str:
    """Strip trailing stop sequences from generated text.

    The sequences are tried once each, in the order given, against the
    progressively trimmed text; a sequence that is not a suffix at the moment
    it is tried is left alone.

    Args:
        content: The generated text to clean up.
        stop_sequences: Candidate suffixes to remove.

    Returns:
        str: ``content`` without the matching trailing stop sequences.
    """
    trimmed = content
    for marker in stop_sequences:
        # An empty marker never removes anything; skipping it also avoids the
        # ``[:-0]`` slice, which would wipe the whole string.
        if marker and trimmed.endswith(marker):
            trimmed = trimmed[: len(trimmed) - len(marker)]
    return trimmed
def parse_tool_args_if_needed(message: ChatMessage) -> ChatMessage:
    """Decode JSON-encoded tool call arguments on a message, in place.

    Every entry of ``message.tool_calls`` has its ``function.arguments`` passed
    through ``parse_json_if_needed``; a message whose ``tool_calls`` is ``None``
    is returned untouched.

    Args:
        message: The message whose tool calls should be normalised.

    Returns:
        ChatMessage: The very same ``message`` object that was passed in.
    """
    calls = message.tool_calls
    if calls is None:
        return message
    for call in calls:
        fn = call.function
        fn.arguments = parse_json_if_needed(fn.arguments)
    return message
def map_message_roles_to_api_format(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Convert smolagents messages into GigaChat API message dictionaries.

    Each message's role is translated through ``TOOL_ROLE_CONVERSIONS`` and its
    text is extracted from ``content``.

    Args:
        messages: Message objects exposing ``role`` and ``content`` attributes.
            ``content`` may be a plain string, which is forwarded as is, or a
            list of content parts, in which case the ``'text'`` entry of the
            first part is used.

    Returns:
        List[Dict[str, Any]]: One ``{"role": ..., "content": ...}`` dictionary
        per input message, in the same order.

    Raises:
        KeyError: If a message role has no entry in ``TOOL_ROLE_CONVERSIONS``.
    """
    converted_messages = []
    for message in messages:
        message_role = TOOL_ROLE_CONVERSIONS[message.role]
        message_content = message.content
        # Plain-string content used to crash on ``content[0]['text']``.
        # NOTE(review): for list content only the first part is forwarded and
        # it is assumed to be a dict with a 'text' key - confirm with callers.
        if not isinstance(message_content, str):
            message_content = message_content[0]['text']
        converted_messages.append({"role": message_role, "content": message_content})
    return converted_messages
def create_final_answer_tool_call(answer: str) -> ChatCompletionOutputToolCall:
    """Wrap a plain text answer into a ``final_answer`` tool call.

    The call gets an id of the form ``call_<8 chars of a uuid4>`` and carries
    the answer JSON-encoded as ``{"answer": answer}``.

    Args:
        answer: The text to deliver through the ``final_answer`` tool.

    Returns:
        ChatCompletionOutputToolCall: The result of
        ``ChatCompletionOutputToolCall.parse_obj`` applied to a one-element
        list holding the tool call description.
    """
    short_uuid = str(uuid.uuid4())[:8]
    function_payload = {
        "name": "final_answer",
        "arguments": json.dumps({"answer": answer}),
    }
    tool_call = {
        "id": f"call_{short_uuid}",
        "type": "function",
        "function": function_payload,
    }
    return ChatCompletionOutputToolCall.parse_obj([tool_call])
202 | temperature: Controls randomness in generation (0.0-1.0). 203 | top_p: Controls diversity via nucleus sampling (0.0-1.0). 204 | repetition_penalty: Penalizes repetition in generated text (>= 1.0). 205 | max_tokens: Maximum number of tokens to generate. 206 | profanity_check: Whether to enable profanity filtering. 207 | auth: Authentication handler for the GigaChat API. 208 | gigachat_instance: The underlying GigaChat client. 209 | """ 210 | 211 | def __init__( 212 | self, 213 | auth_data: str, 214 | model_name: str = "GigaChat", 215 | api_endpoint: str = "https://gigachat.devices.sberbank.ru/api/v1/", 216 | temperature: float = 0.1, 217 | top_p: float = 0.1, 218 | repetition_penalty: float = 1.0, 219 | max_tokens: int = 1500, 220 | profanity_check: bool = True, 221 | client_id: Optional[str] = None, 222 | auth_endpoint: str = "https://ngw.devices.sberbank.ru:9443/api/v2/oauth", 223 | auth_scope: Literal["GIGACHAT_API_PERS", "GIGACHAT_API_CORP", "GIGACHAT_API_B2B"] = "GIGACHAT_API_CORP", 224 | cert_path: Optional[str] = None, 225 | ) -> None: 226 | """Initialize a new GigaChatModel instance. 227 | 228 | Args: 229 | auth_data: Authorization key for exchanging messages with GigaChat API 230 | model_name: The name of the GigaChat model to use. 231 | api_endpoint: The GigaChat API endpoint URL. 232 | temperature: Controls randomness in generation (0.0-1.0). 233 | top_p: Controls diversity via nucleus sampling (0.0-1.0). 234 | repetition_penalty: Penalizes repetition in generated text (>= 1.0). 235 | max_tokens: Maximum number of tokens to generate. 236 | profanity_check: Whether to enable profanity filtering. 237 | client_id: The client ID for API authentication. 238 | auth_endpoint: The authentication endpoint URL. 239 | auth_scope: The authentication scope. 240 | cert_path: Path to the certificate file for API authentication. 
def generate(
    self,
    messages: List[Dict[str, Any]],
    stop_sequences: Optional[List[str]] = None,
    grammar: Optional[str] = None,
    tools_to_call_from: Optional[List[Tool]] = None,
) -> ChatMessage:
    """Run one GigaChat completion and wrap it as a smolagents ``ChatMessage``.

    The messages are converted with ``map_message_roles_to_api_format``, the
    tools (if any) are turned into GigaChat function schemas, and the request
    is sent through ``self.chat``. When tools were supplied but the model did
    not call any of them, the text answer is wrapped into a ``final_answer``
    tool call.

    Args:
        messages: Conversation history in smolagents format.
        stop_sequences: Suffixes stripped from the end of the answer. They are
            applied locally after generation, not sent to the API.
        grammar: Accepted for interface compatibility; not used here.
        tools_to_call_from: Tools the model may call.

    Returns:
        ChatMessage: The assistant message with content, tool calls, the raw
        API response and token usage. If anything fails, the error is logged
        and an assistant message describing the error is returned instead of
        raising.
    """
    try:
        api_messages = map_message_roles_to_api_format(messages)
        functions = (
            [get_tool_json_schema_gigachat(tool) for tool in tools_to_call_from]
            if tools_to_call_from
            else None
        )
        response = self.chat(messages=api_messages, functions=functions)
        answer = response.get('answer', '')

        # Strip stop sequences before building the fallback tool call so that
        # the final_answer argument and the message content carry the same text.
        if stop_sequences and isinstance(stop_sequences, list):
            answer = remove_stop_sequences(answer, stop_sequences)

        tool_calls = extract_tool_calls(response)
        if tool_calls is None and tools_to_call_from is not None:
            tool_calls = create_final_answer_tool_call(answer)

        raw_response = response['response']
        usage = raw_response['usage']
        return parse_tool_args_if_needed(
            ChatMessage(
                role="assistant",
                content=answer,
                tool_calls=tool_calls,
                raw=raw_response,
                token_usage=TokenUsage(
                    input_tokens=usage['prompt_tokens'],
                    output_tokens=usage['completion_tokens']
                )
            )
        )
    except Exception as e:
        # Deliberate best-effort boundary: the failure is reported to the
        # caller as message content rather than as an exception.
        logging.error(f"Critical error in generate: {str(e)}", exc_info=True)
        return ChatMessage(
            role="assistant",
            content=f"Error in model execution: {str(e)}"
        )
client_id: Optional[str] = None, 48 | auth_endpoint: Optional[str] = "https://ngw.devices.sberbank.ru:9443/api/v2/oauth", 49 | auth_scope: Literal["GIGACHAT_API_PERS", "GIGACHAT_API_CORP", "GIGACHAT_API_B2B"] = "GIGACHAT_API_CORP", 50 | cert_path: Optional[str] = None, 51 | ) -> None: 52 | """Initialize with GigaChat API access parameters and response generation settings. 53 | 54 | Args: 55 | auth_data: Authorization key for exchanging messages with GigaChat API 56 | model_name: Model name of the GigaChat to use. 57 | api_endpoint: GigaChat API URL 58 | sber_ds: Using on sber-ds platform 59 | authorize: Authorization method for GigaChat API. If not provided, LLMAuthorizeEnablers will be used. 60 | temperature: Temperature parameter; higher values produce more diverse outputs (typical value 0.7). 61 | top_p: Another parameter for output diversity (typical value 0.1). 62 | repetition_penalty: Controls word repetition. Value 1.0 is neutral, 0-1 increases repetition, >1 decreases repetition. 63 | max_tokens: Maximum number of tokens to generate in the response. 64 | n: Number of completions to generate (non-streaming mode). 65 | n_stream: Number of completions to generate (streaming mode). 66 | profanity_check: Whether to enable profanity checking. 67 | client_id: GigaChat API client ID (used as RqUID). 68 | auth_endpoint: The authentication endpoint URL. 69 | auth_scope: The authentication scope. Contains information about the API version being accessed. 70 | cert_path: Path to the certificate for GigaChat API access. 
def _get_list_model(self) -> List[str]:
    """Fetch the models exposed by the configured endpoint.

    Sends a GET request to ``<api_endpoint>models``. Outside sber_ds mode the
    request carries a bearer token; in sber_ds mode only an ``Accept`` header
    is sent.

    Returns:
        List[str]: The content of the ``data`` field of the response.

    Raises:
        RuntimeError: If the status code is not 200 or the response has no
            ``data`` field. Any error is logged before being re-raised.
    """
    try:
        endpoint = f"{self.api_endpoint}models"
        if self.sber_ds:
            request_headers = {"Accept": "application/json"}
        else:
            request_headers = {"Authorization": f"Bearer {self.__authorize.token}"}

        response = requests.get(endpoint, headers=request_headers, verify=self.__authorize.cert_path)
        if response.status_code != 200:
            raise RuntimeError(f"Error in getting list models: {response.status_code} - {response.content}")

        payload = json.loads(response.content)
        if "data" not in payload:
            raise RuntimeError("Incorrect response from GigaChat API")
        available = payload["data"]
        logging.info(f"Available models: {available}")
        return available
    except Exception as e:
        logging.error(f"{str(e)}")
        raise e
def complete(
    self,
    prompt: str,
    system_prompt: Optional[str] = None,
    params: Optional[Dict[str, Any]] = None,
    functions: Optional[List[Dict[str, Any]]] = None,
    function_call: Optional[Union[str, Dict[str, str]]] = None,
) -> Dict[str, Any]:
    """Ask the model a single question and return its reply.

    Builds a one-turn dialog (optionally preceded by a system message) and
    delegates to ``chat``.

    Args:
        prompt: Text of the user turn.
        system_prompt: System instructions; omitted from the dialog when None.
        params: Per-call generation overrides, forwarded to ``chat``.
        functions: Function definitions for function calling, forwarded to ``chat``.
        function_call: Function call specification, forwarded to ``chat``.

    Returns:
        dict: Whatever ``chat`` returns for the assembled dialog.
    """
    user_turn = {"role": "user", "content": prompt}
    if system_prompt is None:
        dialog = [user_turn]
    else:
        dialog = [{"role": "system", "content": system_prompt}, user_turn]
    return self.chat(messages=dialog, params=params, functions=functions, function_call=function_call)
def chat(
    self,
    messages: MessageList,
    params: Optional[Dict[str, Any]] = None,
    functions: Optional[List[Dict[str, Any]]] = None,
    function_call: Optional[Union[str, Dict[str, str]]] = None,
) -> Dict[str, Any]:
    """Send a list of messages to the model and return the model's response.

    Args:
        messages: Messages as a list of dictionaries with 'role' and
            'content' keys.
        params: Additional parameters for controlling the response generation.
        functions: Optional list of function definitions for function calling.
        function_call: Optional function call specification.

    Returns:
        Dict: A dictionary with the keys ``answer`` (text of the first choice),
        ``response`` (the full decoded API response), ``prompt_tokens``,
        ``answer_tokens``, ``finish_reason`` and ``info`` (a JSON string with
        the model name and the request/session ids from the response headers).

    Raises:
        RuntimeError: If the status code is not 200, or the decoded response
            is not an object with a non-empty ``choices`` list and a ``usage``
            field. Any error is logged before being re-raised.
    """
    try:
        params = params if params is not None else {}
        url, headers, query, model_name = self._prepare_request(
            params=params,
            messages=messages,
            stream=False,
            functions=functions,
            function_call=function_call,
        )
        logging.debug(f"Run chat on LLM '{model_name}' and chat: {messages}...")
        response = requests.post(url, data=query, headers=headers, verify=self.__authorize.cert_path)
        if response.status_code != 200:
            raise RuntimeError(f"Error in running complete method: {response.status_code} - {response.content}")

        answer = json.loads(response.content)
        # Reject non-object bodies and empty "choices" here; otherwise they
        # surface below as an opaque TypeError/IndexError instead of the
        # documented RuntimeError.
        if not isinstance(answer, dict) or not answer.get("choices") or "usage" not in answer:
            raise RuntimeError("Incorrect response from GigaChat API")

        first_choice = answer["choices"][0]
        complete = first_choice["message"]["content"]
        prompt_tokens = answer["usage"]["prompt_tokens"]
        completion_tokens = answer["usage"]["completion_tokens"]
        finish_reason = first_choice["finish_reason"]
        logging.debug(f"answer: {complete}")
        logging.debug(f"prompt tokens number: {prompt_tokens} complete tokens number: {completion_tokens}")
        return {
            "answer": complete,
            "response": answer,
            "prompt_tokens": prompt_tokens,
            "answer_tokens": completion_tokens,
            "finish_reason": finish_reason,
            # Reuse the already decoded body instead of parsing it a second
            # time; a missing model name must not discard a valid answer.
            "info": json.dumps({'model': answer.get('model'),
                                'x-request-id': response.headers.get('x-request-id'),
                                'x-session-id': response.headers.get('x-session-id')})
        }
    except Exception as e:
        logging.error(f"{str(e)}")
        raise
Optional[List[Dict[str, Any]]] = None, 353 | function_call: Optional[Union[str, Dict[str, str]]] = None, 354 | ) -> Iterator[str]: 355 | """Send a list of messages to the model and return the model's response as a stream. 356 | 357 | Args: 358 | messages: Messages either as a list of (role, message) tuples or 359 | a list of dictionaries with 'role' and 'content' keys. 360 | params: Additional parameters for controlling the response generation. 361 | functions: Optional list of function definitions for function calling. 362 | function_call: Optional function call specification. 363 | 364 | Returns: 365 | Iterator[str]: An iterator that yields response tokens as they're generated. 366 | 367 | Raises: 368 | RuntimeError: If the API response is invalid or an error occurs. 369 | """ 370 | try: 371 | params = params if params is not None else {} 372 | url, headers, query, model_name = self._prepare_request( 373 | params=params, 374 | messages=messages, 375 | stream=True, 376 | functions=functions, 377 | function_call=function_call, 378 | ) 379 | logging.debug(f"Run chat on LLM '{model_name}' and chat: {messages}...") 380 | response = requests.post(url, data=query, headers=headers, verify=self.__authorize.cert_path, stream=True) 381 | if response.status_code == 200: 382 | finish_reason = str() 383 | client = SSEClient(response) 384 | for event in client.events(): 385 | if event.data == "[DONE]": 386 | return {"finish_reason": finish_reason} 387 | event_data = json.loads(event.data) 388 | if "choices" not in event_data: 389 | raise RuntimeError("No choices field in response from GigaChat API") 390 | if "finish_reason" in event_data["choices"][0]: 391 | finish_reason = event_data["choices"][0]["finish_reason"] 392 | token = event_data["choices"][0]["delta"]["content"] 393 | logging.debug(f"token: {token}") 394 | yield token 395 | else: 396 | raise RuntimeError(f"Error in running complete method: {response.status_code} - {response.content}") 397 | except Exception as e: 398 
class GigaFilter:
    """Client for the GigaFilter profanity-check endpoint.

    Requests are sent to ``<api_endpoint>/filter/check`` and authorized with
    the bearer token exposed by the ``authorize`` object.
    """

    def __init__(self, api_endpoint: str, authorize: "APIAuthorize") -> None:
        """Initialize GigaFilter.

        Args:
            api_endpoint: The API endpoint URL, with or without a trailing slash.
            authorize: Authorization method for the API. Its ``token`` and
                ``cert_path`` attributes are read on every request.
        """
        self.api_endpoint = api_endpoint
        self.__authorize = authorize

    def _request_url(self) -> str:
        """Return the filter URL, adding the path separator only if it is missing."""
        base = self.api_endpoint
        if not base.endswith("/"):
            # Without this an endpoint like ".../api/v1" would silently become
            # ".../api/v1filter/check".
            base += "/"
        return f"{base}filter/check"

    @staticmethod
    def _request_body(question: str) -> str:
        """Serialize the single-message GigaFilter query to a JSON string."""
        return json.dumps({
            "model": "GigaFilter",
            "messages": [{
                "content": question,
                "role": "user",
            }],
        })

    @staticmethod
    def _parse_response(response: Any, return_json: bool) -> Union[bool, Dict[str, Any]]:
        """Turn an HTTP response into the value ``check_profanity`` returns.

        Args:
            response: Object exposing ``status_code``, ``content`` and ``json()``.
            return_json: If True, return the decoded payload untouched.

        Raises:
            RuntimeError: If ``return_json`` is False and the status is not 200.
        """
        if return_json:
            # Raw mode: the caller asked for whatever the service answered,
            # so error payloads are handed back as well.
            return response.json()
        if response.status_code != 200:
            # Same failure style as the GigaChat methods in this module, instead
            # of an opaque KeyError on the missing 'is_profane' field.
            raise RuntimeError(
                f"Error in running check_profanity method: {response.status_code} - {response.content}"
            )
        return response.json()['is_profane']

    def check_profanity(self, question: str, return_json: bool = False) -> Union[bool, Dict[str, Any]]:
        """Check if the given text contains profanity.

        Args:
            question: The text to check.
            return_json: If True, return the full JSON response; if False, return only a boolean result.

        Returns:
            Union[bool, dict]: Either a boolean indicating whether profanity was detected,
            or the complete JSON response if return_json is True.

        Raises:
            RuntimeError: If return_json is False and the API answers with a non-200 status.
        """
        headers = {
            "Authorization": f"Bearer {self.__authorize.token}",
            "Content-Type": "application/json",
        }
        response = requests.post(
            self._request_url(),
            data=self._request_body(question),
            headers=headers,
            verify=self.__authorize.cert_path,
        )
        return self._parse_response(response, return_json)