├── src
└── gigasmol
│ ├── gigachat_api
│ ├── __init__.py
│ ├── auth.py
│ └── api_model.py
│ ├── __init__.py
│ └── models.py
├── .gitignore
├── assets
└── logo.png
├── pyproject.toml
├── README.md
├── LICENSE
└── examples
└── structured_output.ipynb
/src/gigasmol/gigachat_api/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | .DS_Store
3 | credentials.json
4 | *.egg-info/
5 | dist/
--------------------------------------------------------------------------------
/assets/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/poteminr/gigasmol/HEAD/assets/logo.png
--------------------------------------------------------------------------------
/src/gigasmol/__init__.py:
--------------------------------------------------------------------------------
1 | """GigaSmol - Lightweight GigaChat API wrapper for smolagents"""
2 |
3 | from gigasmol.gigachat_api.api_model import GigaChat
4 |
5 | __version__ = "0.0.8"
6 |
7 | __all__ = ["GigaChat"]
8 |
try:
    from gigasmol.models import GigaChatSmolModel
except ImportError:
    # API-only installation: the optional smolagents dependency is not
    # installed. Expose a placeholder so `import gigasmol` keeps working and
    # the failure only surfaces, with an actionable message, when the agent
    # model is actually instantiated.
    class GigaChatSmolModel:
        """Placeholder used when the smolagents package is not installed.

        Install agent support with: pip install "gigasmol[agent]"
        """

        def __init__(self, *args, **kwargs):
            raise ImportError(
                'The smolagents package is required to use GigaChatSmolModel. '
                'Please install the full package with `pip install "gigasmol[agent]"`.'
            )

# The name is exported in both cases (real class or placeholder).
__all__.append("GigaChatSmolModel")
26 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=61.0"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "gigasmol"
7 | version = "0.0.8"
8 | authors = [
9 | {name = "poteminr", email = "poteminr@gmail.com"},
10 | ]
11 | description = "A lightweight wrapper around the GigaChat API for seamless use with Hugging Face smolagents"
12 | readme = "README.md"
13 | requires-python = ">=3.8"
14 | license = {text = "Apache-2.0"}
15 | # Base dependencies - ONLY the API parts
16 | dependencies = [
17 | "requests>=2.31.0",
18 | "sseclient>=0.0.27",
19 | "urllib3>=2.0.0",
20 | ]
21 |
22 | [project.urls]
23 | "Homepage" = "https://github.com/poteminr/gigasmol"
24 | "Bug Tracker" = "https://github.com/poteminr/gigasmol/issues"
25 |
26 | [tool.setuptools]
27 | package-dir = {"" = "src"}
28 |
29 | [project.optional-dependencies]
30 | # Agent functionality requires Python 3.10+
31 | agent = [
32 | "smolagents==1.22.0; python_version >= '3.10'",
33 | "huggingface-hub>=0.19.0; python_version >= '3.10'",
34 | "gigasmol-agent-requires-python-3.10-or-newer==0.0.0; python_version < '3.10'",
35 | ]
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
7 |
8 |
9 |
10 |
11 |

12 |
lightweight gigachat api wrapper for smolagents
13 |
14 |
15 | ## Overview
16 |
17 | gigasmol serves two primary purposes:
18 |
19 | 1. Provides **direct, lightweight access** to GigaChat models through GigaChat API without unnecessary abstractions
20 | 2. Creates a **smolagents-compatible wrapper** that lets you use GigaChat within agent systems
21 |
22 | No complex abstractions — just clean, straightforward access to GigaChat's capabilities, either directly or through smolagents.
23 |
24 | ```
25 | GigaChat API + smolagents = gigasmol 💀
26 | ```
27 |
28 | ## Why gigasmol 💀?
29 |
30 | - **Tiny Footprint**: Less than 1K lines of code total
31 | - **Simple Structure**: Just 4 core files
32 | - **Zero Bloat**: Only essential dependencies
33 | - **Easy to Understand**: Read and comprehend the entire codebase in minutes
34 | - **Maintainable**: Small, focused codebase means fewer bugs and easier updates
35 | ## Installation
36 | ### API-Only Installation (default)
37 | `python>=3.8`
38 | ```bash
39 | pip install gigasmol
40 | ```
41 |
42 | ### Full Installation with Agent Support
43 | `python>=3.10`
44 | ```bash
45 | pip install "gigasmol[agent]"
46 | ```
47 |
48 |
49 | ## Quick Start
50 | ### Raw GigaChat API
51 | `gigasmol`
52 |
53 |
54 | ```python
55 | import json
56 | from gigasmol import GigaChat
57 |
58 | # Direct access to GigaChat API
59 | gigachat = GigaChat(
60 | auth_data="YOUR_AUTH_TOKEN",
61 | model_name="GigaChat-Max",
62 | )
63 |
64 | # Generate a response
65 | response = gigachat.chat([
66 | {"role": "user", "content": "What is the capital of Russia?"}
67 | ])
68 | print(response['answer']) # or print(response['response']['choices'][0]['message']['content'])
69 | ```
70 | ### Usage with smolagents
71 | `gigasmol[agent]`
72 |
73 | ```python
74 | from gigasmol import GigaChatSmolModel
75 | from smolagents import CodeAgent, ToolCallingAgent, DuckDuckGoSearchTool
76 |
77 | # Initialize the GigaChat model with your credentials
78 | model = GigaChatSmolModel(
79 | auth_data="YOUR_AUTH_TOKEN",
80 | model_name="GigaChat-Max"
81 | )
82 |
83 | # Create a CodeAgent with the model
84 | code_agent = CodeAgent(
85 | tools=[DuckDuckGoSearchTool()],
86 | model=model
87 | )
88 |
89 | # Run the code_agent
90 | code_agent.run("What are the main tourist attractions in Moscow?")
91 |
92 | # Create a ToolCallingAgent with the model
93 | tool_calling_agent = ToolCallingAgent(
94 | tools=[DuckDuckGoSearchTool()],
95 | model=model
96 | )
97 |
98 | # Run the tool_calling_agent
99 | tool_calling_agent.run("What are the main tourist attractions in Moscow?")
100 | ```
101 |
102 |
103 |
104 | ## How It Works
105 |
106 | GigaSmol provides two layers of functionality:
107 |
108 | ```
109 | ┌───────────────────────────────────────────────────┐
110 | │ gigasmol │
111 | ├───────────────────────────────────────────────────┤
112 | │ ┌───────────────┐ ┌───────────────────┐ │
113 | │ │ Direct │ │ smolagents │ │
114 | │ │ GigaChat API │ │ compatibility │ │
115 | │ │ access │ │ layer │ │
116 | │ └───────────────┘ └───────────────────┘ │
117 | └───────────────────────────────────────────────────┘
118 | │ │
119 | ▼ ▼
120 | ┌─────────────┐ ┌────────────────┐
121 | │ GigaChat API│ │ Agent systems │
122 | └─────────────┘ └────────────────┘
123 | ```
124 |
125 | 1. **Direct API Access**: Use `GigaChat` for clean, direct access to the API
126 | 2. **smolagents Integration**: Use `GigaChatSmolModel` to plug GigaChat into smolagents
127 |
128 |
129 | ## Examples
130 |
131 | Check the `examples` directory:
132 | - `structured_output.ipynb`: Using GigaChat API and function_calling for structured output
133 | - `agents.ipynb`: Building code and tool agents with GigaChat and smolagents
134 |
135 | ## Acknowledgements
136 |
137 | - [SberDevices](https://gigachat.ru/) for creating the GigaChat API
138 | - [Hugging Face](https://huggingface.co/) for the smolagents framework
139 |
--------------------------------------------------------------------------------
/src/gigasmol/gigachat_api/auth.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | import datetime
3 | import logging
4 | from typing import Literal, Optional
5 | import uuid
6 |
7 | import requests
8 | import urllib3
9 | from urllib3.exceptions import InsecureRequestWarning
10 |
11 |
class APIAuthorize(ABC):
    """Common contract for objects that authorize access to an API.

    A concrete implementation takes the user's credentials and exposes the
    resulting access token together with the certificate location.
    """

    @property
    @abstractmethod
    def cert_path(self) -> str:
        """Path to the certificate file used for API access."""

    @property
    @abstractmethod
    def token(self) -> str:
        """Valid authentication token for API access."""
37 |
38 |
class LLMAuthorizeAdvanced(APIAuthorize):
    """Access to GigaChat via username/password authentication."""

    def __init__(
        self,
        username: str,
        password: str,
        auth_endpoint: str = "https://beta.saluteai.sberdevices.ru/v1/"
    ) -> None:
        """Initialize with GigaChat API access parameters and fetch a token.

        Args:
            username: GigaChat API username
            password: GigaChat API password
            auth_endpoint: Base URL of the authentication API. The token is
                requested from ``<auth_endpoint>token``, so the value is
                expected to end with a slash.

        Raises:
            ValueError: If the API answers with a non-200 status or the
                response lacks the ``tok``/``exp`` fields.
            requests.RequestException: On a network-level failure.
        """
        self.auth_endpoint = auth_endpoint
        self.__username = username
        self.__password = password
        # datetime.min guarantees the first `_check_token` sees an expired token.
        self.__token_expiration_time = datetime.datetime.min
        self._token = ""
        self._set_token()

        # NOTE: this silences the warning process-wide, not just for this object.
        urllib3.disable_warnings(
            urllib3.exceptions.InsecureRequestWarning
        )

    @property
    def token(self) -> str:
        """Get the current valid token, refreshing if necessary.

        Returns:
            str: Valid authentication token.
        """
        self._check_token()
        return self._token

    @token.setter
    def token(self, value: str) -> None:
        """Set the token value.

        Args:
            value: Token value to set.
        """
        self._token = value

    @property
    def cert_path(self) -> str:
        """Get the certificate path (not used in this implementation).

        Returns:
            str: Empty string as this implementation doesn't use certificates.
        """
        return ""

    @cert_path.setter
    def cert_path(self, value: str) -> None:
        """Certificate path setter (not used in this implementation).

        Args:
            value: Certificate path; ignored.
        """
        pass

    def _check_token(self) -> None:
        """Check if the current token is expired and refresh if needed."""
        if datetime.datetime.now() > self.__token_expiration_time:
            self._set_token()

    def _set_token(self) -> None:
        """Obtain a fresh access token from the API.

        Raises:
            ValueError: On a non-200 status or a malformed response body.
            requests.RequestException: On a network-level failure.
        """
        logger = logging.getLogger(__name__)
        # The endpoint configured in __init__ is the base URL for the token call.
        token_url = f"{self.auth_endpoint}token"
        try:
            logger.info(
                "Getting token from GigaChat API",
                extra={"url": token_url, "username": self.__username}
            )
            response = requests.post(token_url, auth=(self.__username, self.__password))
            if response.status_code == 200:
                data = response.json()
                if "tok" not in data or "exp" not in data:
                    raise ValueError("Incorrect response format from GigaChat API")

                # Stored as naive *local* time because `_check_token` compares
                # against the naive local `datetime.now()`.
                self.__token_expiration_time = datetime.datetime.fromtimestamp(data["exp"])
                self._token = data["tok"]
                logger.info("Successfully obtained authentication token")
            else:
                raise ValueError(
                    f"Failed to get token: HTTP {response.status_code} - {response.text}"
                )
        except requests.RequestException as e:
            logger.error(f"Network error during token retrieval: {str(e)}")
            raise
        except ValueError as e:
            logger.error(f"Error processing token response: {str(e)}")
            raise
        except Exception as e:
            logger.error(f"Unexpected error during token retrieval: {str(e)}")
            raise
137 |
138 |
class LLMAuthorizeEnablers(APIAuthorize):
    """Access to GigaChat via client_id/client_secret with certificate."""

    def __init__(
        self,
        auth_data: str,
        client_id: Optional[str] = None,
        auth_endpoint: str = "https://ngw.devices.sberbank.ru:9443/api/v2/oauth",
        auth_scope: Literal[
            "GIGACHAT_API_PERS", "GIGACHAT_API_CORP", "GIGACHAT_API_B2B"
        ] = "GIGACHAT_API_PERS",
        cert_path: Optional[str] = None
    ) -> None:
        """Initialize with GigaChat API access parameters and fetch a token.

        Args:
            auth_data: Authorization key for exchanging messages with GigaChat API
            client_id: GigaChat API client ID (used as RqUID). A random UUID4
                is generated when omitted.
            auth_endpoint: The authentication endpoint URL
            auth_scope: The authentication scope. Contains information about the API version
                being accessed. If you are using the API version for individual
                entrepreneurs or legal entities, specify this explicitly in the scope
                parameter (GIGACHAT_API_PERS, GIGACHAT_API_CORP, or GIGACHAT_API_B2B)
            cert_path: Path to the certificate for GigaChat API access

        Raises:
            ValueError: If the API answers with a non-200 status or the
                response lacks ``access_token``/``expires_at``.
            requests.RequestException: On a network-level failure.
        """
        self.auth_endpoint = auth_endpoint
        self.auth_scope = auth_scope
        self.__client_id = client_id if client_id is not None else str(uuid.uuid4())
        self.__auth_data = auth_data
        # datetime.min guarantees the first `_check_token` sees an expired token.
        self.__token_expiration_time = datetime.datetime.min
        self._cert_path = cert_path if cert_path is not None else ""
        self._token = ""
        self._set_token()

    @property
    def token(self) -> str:
        """Get the current valid token, refreshing if necessary.

        Returns:
            str: Valid authentication token.
        """
        self._check_token()
        return self._token

    @token.setter
    def token(self, value: str) -> None:
        """Set the token value.

        Args:
            value: Token value to set.
        """
        self._token = value

    @property
    def cert_path(self) -> str:
        """Get the path to the certificate.

        Returns:
            str: Path to the certificate file ("" when none was given).
        """
        return self._cert_path

    @cert_path.setter
    def cert_path(self, value: str) -> None:
        """Set the certificate path.

        Args:
            value: Certificate path to set.
        """
        self._cert_path = value

    def _check_token(self) -> None:
        """Check if the current token is expired and refresh if needed."""
        if datetime.datetime.now() > self.__token_expiration_time:
            self._set_token()

    def _set_token(self) -> None:
        """Obtain a fresh access token from the auth endpoint.

        Raises:
            ValueError: On a non-200 status or a malformed response body.
            requests.RequestException: On a network-level failure.
        """
        logger = logging.getLogger(__name__)
        headers = {
            # The authorization key is sent with the Basic scheme; Bearer is
            # only for the access token this call returns.
            "Authorization": f"Basic {self.__auth_data}",
            "RqUID": self.__client_id,
            "Content-Type": "application/x-www-form-urlencoded",
        }
        try:
            logger.info(
                "Getting token from GigaChat Auth API",
                extra={"url": self.auth_endpoint, "client_id": self.__client_id}
            )
            # NOTE(review): when no cert_path was supplied, `verify` receives
            # an empty string — confirm how requests treats that falsy value
            # (it may disable TLS verification).
            response = requests.post(
                self.auth_endpoint,
                data={"scope": self.auth_scope},
                headers=headers,
                verify=self._cert_path
            )
            if response.status_code == 200:
                data = response.json()
                if "access_token" not in data or "expires_at" not in data:
                    raise ValueError("Incorrect response format from GigaChat Auth API")

                self._token = data["access_token"]
                # `expires_at` is divided by 1000 before conversion, i.e. it is
                # treated as a millisecond timestamp.
                expiry_timestamp = int(data["expires_at"]) / 1000
                self.__token_expiration_time = datetime.datetime.fromtimestamp(expiry_timestamp)

                logger.info("Successfully obtained authentication token")
            else:
                raise ValueError(
                    f"Failed to get token: HTTP {response.status_code} - {response.text}"
                )
        except requests.RequestException as e:
            logger.error(f"Network error during token retrieval: {str(e)}")
            raise
        except ValueError as e:
            logger.error(f"Error processing token response: {str(e)}")
            raise
        except Exception as e:
            logger.error(f"Unexpected error during token retrieval: {str(e)}")
            raise
254 |
255 |
class SberDSAuthorize(APIAuthorize):
    """Authorization with fixed values: an empty token and no certificate."""

    @property
    def token(self) -> str:
        """Always the empty string."""
        return ""

    @property
    def cert_path(self) -> bool:
        """Always ``False`` (a flag rather than a path)."""
        return False
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/examples/structured_output.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "# !pip install gigasmol -q"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": 2,
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "%load_ext autoreload\n",
19 | "%autoreload 2\n",
20 | " \n",
21 | "import warnings\n",
22 | "warnings.filterwarnings('ignore')\n",
23 | "\n",
24 | "import json\n",
25 | "from gigasmol import GigaChat"
26 | ]
27 | },
28 | {
29 | "cell_type": "markdown",
30 | "metadata": {},
31 | "source": [
32 | "### Initialize API model for structured outputs via function calling (without agents)"
33 | ]
34 | },
35 | {
36 | "cell_type": "code",
37 | "execution_count": 3,
38 | "metadata": {},
39 | "outputs": [],
40 | "source": [
41 | "credentials = json.load(open('credentials.json'))\n",
42 | "\n",
43 | "model = GigaChat(\n",
44 | " auth_data=credentials['auth_data'],\n",
45 | " model_name=\"GigaChat-Max\",\n",
46 | " api_endpoint=\"https://gigachat.devices.sberbank.ru/api/v1/\", # \"https://gigachat-preview.devices.sberbank.ru/api/v1/\" \n",
47 | " temperature=0.0000001,\n",
48 | " top_p=0.1,\n",
49 | " repetition_penalty=1.1,\n",
50 | " max_tokens=1024,\n",
51 | " profanity_check=False,\n",
52 | " auth_scope=\"GIGACHAT_API_CORP\",\n",
53 | ")"
54 | ]
55 | },
56 | {
57 | "cell_type": "markdown",
58 | "metadata": {},
59 | "source": [
60 | "### Define schema for structured entity extraction"
61 | ]
62 | },
63 | {
64 | "cell_type": "code",
65 | "execution_count": 4,
66 | "metadata": {},
67 | "outputs": [],
68 | "source": [
69 | "schema = {\n",
70 | " \"name\": \"extract_complex_entities\",\n",
71 | " \"description\": \"Extracts complex, nested entity information from text.\",\n",
72 | " \"parameters\": {\n",
73 | " \"type\": \"object\",\n",
74 | " \"properties\": {\n",
75 | " \"persons\": {\n",
76 | " \"type\": \"array\",\n",
77 | " \"description\": \"List of persons found in the text with nested properties.\",\n",
78 | " \"items\": {\n",
79 | " \"type\": \"object\",\n",
80 | " \"properties\": {\n",
81 | " \"name\": {\n",
82 | " \"type\": \"string\",\n",
83 | " \"description\": \"Full name of the person\"\n",
84 | " },\n",
85 | " \"birthPlace\": {\n",
86 | " \"type\": \"string\",\n",
87 | " \"description\": \"Birthplace of the person\"\n",
88 | " },\n",
89 | " \"roles\": {\n",
90 | " \"type\": \"array\",\n",
91 | " \"description\": \"Roles or titles the person has held\",\n",
92 | " \"items\": {\n",
93 | " \"type\": \"string\"\n",
94 | " }\n",
95 | " },\n",
96 | " \"education\": {\n",
97 | " \"type\": \"array\",\n",
98 | " \"description\": \"List of institutions where the person studied\",\n",
99 | " \"items\": {\n",
100 | " \"type\": \"string\"\n",
101 | " }\n",
102 | " }\n",
103 | " },\n",
104 | " \"required\": [\"name\"]\n",
105 | " }\n",
106 | " },\n",
107 | " \"organizations\": {\n",
108 | " \"type\": \"array\",\n",
109 | " \"description\": \"List of organizations with additional details.\",\n",
110 | " \"items\": {\n",
111 | " \"type\": \"object\",\n",
112 | " \"properties\": {\n",
113 | " \"orgName\": {\n",
114 | " \"type\": \"string\",\n",
115 | " \"description\": \"Name of the organization\"\n",
116 | " },\n",
117 | " \"orgType\": {\n",
118 | " \"type\": \"string\",\n",
119 | " \"description\": \"Type of organization (e.g., university, government, company)\"\n",
120 | " },\n",
121 | " \"location\": {\n",
122 | " \"type\": \"string\",\n",
123 | " \"description\": \"Location of the organization\"\n",
124 | " }\n",
125 | " },\n",
126 | " \"required\": [\"orgName\"]\n",
127 | " }\n",
128 | " },\n",
129 | " \"locations\": {\n",
130 | " \"type\": \"array\",\n",
131 | " \"description\": \"List of location entities with details.\",\n",
132 | " \"items\": {\n",
133 | " \"type\": \"object\",\n",
134 | " \"properties\": {\n",
135 | " \"placeName\": {\n",
136 | " \"type\": \"string\",\n",
137 | " \"description\": \"Name of the place\"\n",
138 | " },\n",
139 | " \"country\": {\n",
140 | " \"type\": \"string\",\n",
141 | " \"description\": \"Country of the place\"\n",
142 | " }\n",
143 | " },\n",
144 | " \"required\": [\"placeName\"]\n",
145 | " }\n",
146 | " }\n",
147 | " },\n",
148 | " \"required\": [\"persons\", \"organizations\", \"locations\"]\n",
149 | " }\n",
150 | "}"
151 | ]
152 | },
153 | {
154 | "cell_type": "markdown",
155 | "metadata": {},
156 | "source": [
157 | "### Extracting Structured Entities from Text\n"
158 | ]
159 | },
160 | {
161 | "cell_type": "code",
162 | "execution_count": 5,
163 | "metadata": {},
164 | "outputs": [],
165 | "source": [
166 | "system = (\n",
167 | " \"You are a helpful assistant that extracts complex entity information from text \"\n",
168 | " \"and returns ONLY JSON.\"\n",
169 | ")\n",
170 | "\n",
171 | "text = \"\"\"\n",
172 | "“Bill Gates, the co-founder of Microsoft, was born in Seattle, Washington, in 1955. He initially enrolled at Harvard University to study pre-law, but he shifted his focus to mathematics and computer science before dropping out. Along with Paul Allen, he founded Microsoft in 1975, which rapidly grew into one of the world’s largest software companies. Over the years, Bill Gates became a notable philanthropist through the Bill & Melinda Gates Foundation, focusing on global health and education initiatives.\n",
173 | "\n",
174 | "\tMeanwhile, Mark Zuckerberg, born in 1984, is the co-founder and CEO of Facebook, now known as Meta. He developed the initial version of the platform while studying computer science at Harvard University, though he never completed his degree. Under his leadership, Facebook rapidly expanded into a worldwide social media giant. Zuckerberg also launched the Chan Zuckerberg Initiative in collaboration with his wife, Dr. Priscilla Chan, to tackle issues related to education, healthcare, and scientific research.\n",
175 | "\n",
176 | "\tOn the other hand, Elon Musk, born in Pretoria, South Africa, in 1971, is known for founding SpaceX and co-founding Tesla. After moving to Canada, he later studied at the University of Pennsylvania in the United States, where he earned degrees in physics and economics. Musk’s ventures have often focused on emerging technologies and innovation—from electric vehicles and solar energy solutions at Tesla to space exploration and rocket technology at SpaceX. In recent years, he has also been actively involved in artificial intelligence research, among other futuristic endeavors.”\n",
177 | "\"\"\"\n",
178 | "\n",
179 | "\n",
180 | "prompt = (\n",
181 | " \"The user provides the following text:\\n\\n\"\n",
182 | " f\"{text}\\n\"\n",
183 | " \"Extract the following complex information:\\n\"\n",
184 | " \" - persons: with name, birthPlace, roles (array of strings), and education (array of schools),\\n\"\n",
185 | " \" - organizations: with orgName, orgType, and location,\\n\"\n",
186 | " \" - locations: with placeName, country.\\n\\n\"\n",
187 | " \"Return ONLY JSON containing arrays for 'persons', 'organizations', and 'locations'. \"\n",
188 | " \"Each 'person' should have nested fields such as 'name', 'birthPlace', 'roles', \"\n",
189 | " \"and an array of 'education' items. Each 'organization' has 'orgName', 'orgType', 'location'. \"\n",
190 | " \"Each 'location' has 'placeName' and 'country'.\"\n",
191 | ")"
192 | ]
193 | },
194 | {
195 | "cell_type": "code",
196 | "execution_count": 6,
197 | "metadata": {},
198 | "outputs": [],
199 | "source": [
200 | "messages = [\n",
201 | " {\"role\": \"system\", \"content\": system},\n",
202 | " {\"role\": \"user\", \"content\": prompt}\n",
203 | "]"
204 | ]
205 | },
206 | {
207 | "cell_type": "code",
208 | "execution_count": 7,
209 | "metadata": {},
210 | "outputs": [],
211 | "source": [
212 | "response_complete = model.chat(\n",
213 | " messages=messages,\n",
214 | " functions=[schema],\n",
215 | " function_call={\"name\": \"extract_complex_entities\"}\n",
216 | ")"
217 | ]
218 | },
219 | {
220 | "cell_type": "code",
221 | "execution_count": 8,
222 | "metadata": {},
223 | "outputs": [
224 | {
225 | "data": {
226 | "text/plain": [
227 | "{'locations': [{'country': 'United States',\n",
228 | " 'placeName': 'Seattle, Washington'},\n",
229 | " {'country': 'United States', 'placeName': 'White Plains, New York'},\n",
230 | " {'country': 'South Africa', 'placeName': 'Pretoria, South Africa'},\n",
231 | " {'country': 'United States', 'placeName': 'Menlo Park, California'},\n",
232 | " {'country': 'United States', 'placeName': 'Hawthorne, California'},\n",
233 | " {'country': 'United States', 'placeName': 'Palo Alto, California'}],\n",
234 | " 'organizations': [{'location': 'Seattle, Washington',\n",
235 | " 'orgName': 'Microsoft',\n",
236 | " 'orgType': 'software company'},\n",
237 | " {'location': 'Menlo Park, California',\n",
238 | " 'orgName': 'Facebook (Meta)',\n",
239 | " 'orgType': 'social media company'},\n",
240 | " {'location': 'Hawthorne, California',\n",
241 | " 'orgName': 'SpaceX',\n",
242 | " 'orgType': 'aerospace manufacturer and space transportation services company'},\n",
243 | " {'location': 'Palo Alto, California',\n",
244 | " 'orgName': 'Tesla',\n",
245 | " 'orgType': 'electric vehicle and clean energy company'}],\n",
246 | " 'persons': [{'birthPlace': 'Seattle, Washington',\n",
247 | " 'education': ['Harvard University'],\n",
248 | " 'name': 'Bill Gates',\n",
249 | " 'roles': ['co-founder of Microsoft', 'philanthropist']},\n",
250 | " {'birthPlace': 'White Plains, New York',\n",
251 | " 'education': ['Harvard University'],\n",
252 | " 'name': 'Mark Zuckerberg',\n",
253 | " 'roles': ['co-founder and CEO of Facebook (Meta)']},\n",
254 | " {'birthPlace': 'Pretoria, South Africa',\n",
255 | " 'education': ['University of Pennsylvania'],\n",
256 | " 'name': 'Elon Musk',\n",
257 | " 'roles': ['founder of SpaceX', 'co-founder of Tesla']}]}"
258 | ]
259 | },
260 | "execution_count": 8,
261 | "metadata": {},
262 | "output_type": "execute_result"
263 | }
264 | ],
265 | "source": [
266 | "response_complete['response']['choices'][0]['message']['function_call']['arguments']"
267 | ]
268 | }
269 | ],
270 | "metadata": {
271 | "kernelspec": {
272 | "display_name": "Python 3",
273 | "language": "python",
274 | "name": "python3"
275 | },
276 | "language_info": {
277 | "codemirror_mode": {
278 | "name": "ipython",
279 | "version": 3
280 | },
281 | "file_extension": ".py",
282 | "mimetype": "text/x-python",
283 | "name": "python",
284 | "nbconvert_exporter": "python",
285 | "pygments_lexer": "ipython3",
286 | "version": "3.12.3"
287 | }
288 | },
289 | "nbformat": 4,
290 | "nbformat_minor": 2
291 | }
292 |
--------------------------------------------------------------------------------
/src/gigasmol/models.py:
--------------------------------------------------------------------------------
1 | import uuid
2 | import json
3 | import logging
4 | from copy import deepcopy
5 | from typing import List, Dict, Optional, Any, Tuple, Union, Literal
6 |
7 | from smolagents.tools import Tool
8 | from smolagents.models import Model, MessageRole, ChatMessage
9 | from smolagents.monitoring import TokenUsage
10 | from huggingface_hub import ChatCompletionOutputToolCall
11 |
12 | from .gigachat_api.api_model import DialogRole, GigaChat, MessageList
13 |
14 |
# Lookup table from smolagents message roles to the GigaChat dialog roles
# used when building API messages. Tool calls are sent as assistant turns and
# tool responses as user turns; the other roles map to their namesakes.
TOOL_ROLE_CONVERSIONS = {
    MessageRole.TOOL_CALL: DialogRole.ASSISTANT,
    MessageRole.TOOL_RESPONSE: DialogRole.USER,
    MessageRole.ASSISTANT: DialogRole.ASSISTANT,
    MessageRole.USER: DialogRole.USER,
    MessageRole.SYSTEM: DialogRole.SYSTEM,
}
22 |
23 |
def remove_stop_sequences(content: str, stop_sequences: List[str]) -> str:
    """Strip trailing stop sequences from a piece of generated text.

    The sequences are tried one after another, in list order. Whenever the
    text (as trimmed so far) ends with the sequence being tried, that suffix
    is cut off before the next sequence is examined.

    Args:
        content: The text to clean up.
        stop_sequences: Suffixes that should not remain at the end of the text.

    Returns:
        str: The text without the matched trailing stop sequences.
    """
    trimmed = content
    for marker in stop_sequences:
        width = len(marker)
        # Compare the last `width` characters against the marker itself.
        if trimmed[-width:] == marker:
            trimmed = trimmed[:-width]
    return trimmed
41 |
42 |
def parse_json_if_needed(arguments: Union[str, dict]) -> Union[str, dict]:
    """Decode ``arguments`` from JSON unless it is already a dictionary.

    Dictionaries are passed through untouched. Anything else is handed to
    ``json.loads``; when decoding fails for any reason the input is returned
    exactly as it was received.

    Args:
        arguments: A dictionary, or a value that may hold a JSON document.

    Returns:
        Union[str, dict]: The decoded JSON value on success, otherwise the
        original input.
    """
    if not isinstance(arguments, dict):
        try:
            arguments = json.loads(arguments)
        except Exception:
            # Best effort by design: undecodable input is returned unchanged.
            pass
    return arguments
62 |
63 |
def parse_tool_args_if_needed(message: ChatMessage) -> ChatMessage:
    """Make sure every tool call on ``message`` carries decoded arguments.

    When the message has tool calls, the ``function.arguments`` of each call is
    run through ``parse_json_if_needed`` and written back in place. Messages
    whose ``tool_calls`` is ``None`` are returned untouched.

    Args:
        message: The chat message to normalise; it is modified in place.

    Returns:
        ChatMessage: The very same message object that was passed in.
    """
    calls = message.tool_calls
    if calls is None:
        return message
    for call in calls:
        call.function.arguments = parse_json_if_needed(call.function.arguments)
    return message
81 |
82 |
def get_tool_json_schema_gigachat(tool: Tool) -> Dict:
    """Describe a smolagents tool as a GigaChat function definition.

    The tool's ``inputs`` are deep-copied (the tool itself is never modified)
    and used as the ``properties`` of an object schema. Inputs typed ``"any"``
    are declared as ``"string"``, and every input that is not flagged with a
    truthy ``"nullable"`` entry is listed as required.

    Args:
        tool: The tool to describe; its ``name``, ``description`` and
            ``inputs`` attributes are read.

    Returns:
        Dict: A dictionary with ``name``, ``description`` and ``parameters``
        keys, where ``parameters`` is the object schema built from the inputs.
    """
    properties = deepcopy(tool.inputs)
    required = []
    for input_name, spec in properties.items():
        if spec["type"] == "any":
            spec["type"] = "string"
        if spec.get("nullable"):
            continue
        required.append(input_name)

    parameters = {
        "type": "object",
        "properties": properties,
        "required": required,
    }
    return {
        "name": tool.name,
        "description": tool.description,
        "parameters": parameters,
    }
112 |
113 |
def map_message_roles_to_api_format(messages: List[ChatMessage]) -> List[Dict[str, Any]]:
    """Convert smolagents messages into GigaChat API message dictionaries.

    Each message's role is translated through ``TOOL_ROLE_CONVERSIONS`` and its
    text is extracted from ``content``. ``content`` may be a plain string, which
    is forwarded as is, or a list of content parts, in which case the ``'text'``
    entry of the first part is used.

    Args:
        messages: Message objects exposing ``role`` and ``content`` attributes.

    Returns:
        List[Dict[str, Any]]: One ``{"role": ..., "content": ...}`` dictionary
        per input message, in the same order.

    Raises:
        KeyError: If a role has no entry in ``TOOL_ROLE_CONVERSIONS``.
    """
    converted_messages = []
    for message in messages:
        message_role = TOOL_ROLE_CONVERSIONS[message.role]
        content = message.content
        # A plain-string content has no parts to index into; indexing it the
        # way a parts list is indexed would fail with a TypeError.
        if isinstance(content, str):
            message_content = content
        else:
            message_content = content[0]['text']
        converted_messages.append({"role": message_role, "content": message_content})
    return converted_messages
135 |
136 |
def extract_tool_calls(response: Dict[str, Any]) -> Optional[ChatCompletionOutputToolCall]:
    """Collect the function calls found in a GigaChat chat result.

    Every choice under ``response['response']['choices']`` whose message holds
    a ``function_call`` is turned into a tool-call record with a freshly
    generated ``call_<8 hex chars>`` id. Dictionary arguments are serialised to
    a JSON string; any other argument value is kept as received.

    Args:
        response: The result dictionary whose ``'response'`` entry holds the
            raw API payload.

    Returns:
        ``None`` when no choice carries a function call; otherwise the result
        of ``ChatCompletionOutputToolCall.parse_obj`` applied to the list of
        collected records (presumably a list of tool calls - verify against
        the huggingface_hub documentation).
    """
    collected = []
    for choice in response['response']['choices']:
        if 'message' not in choice or 'function_call' not in choice['message']:
            continue

        call = choice['message']['function_call']
        identifier = f"call_{str(uuid.uuid4())[:8]}"
        raw_arguments = call['arguments']
        serialised = json.dumps(raw_arguments) if isinstance(raw_arguments, dict) else raw_arguments

        collected.append({
            "id": identifier,
            "type": "function",
            "function": {
                "name": call['name'],
                "arguments": serialised
            }
        })

    if not collected:
        return None
    return ChatCompletionOutputToolCall.parse_obj(collected)
168 |
169 |
def create_final_answer_tool_call(answer: str) -> ChatCompletionOutputToolCall:
    """Wrap a plain answer into a ``final_answer`` tool call.

    A single tool-call record is built with a freshly generated
    ``call_<8 hex chars>`` id, the function name ``final_answer`` and the
    answer serialised as the JSON object ``{"answer": ...}``.

    Args:
        answer: The text to deliver as the final answer.

    Returns:
        The result of ``ChatCompletionOutputToolCall.parse_obj`` applied to a
        one-element list holding that record.
    """
    identifier = f"call_{str(uuid.uuid4())[:8]}"
    record = {
        "id": identifier,
        "type": "function",
        "function": {
            "name": "final_answer",
            "arguments": json.dumps({"answer": answer})
        }
    }
    return ChatCompletionOutputToolCall.parse_obj([record])
192 |
193 |
class GigaChatSmolModel(Model):
    """A wrapper for the GigaChat model that implements the smolagents Model interface.

    The class owns a ``GigaChat`` client, converts smolagents messages and tools
    into the format that client expects, and turns its replies back into
    ``ChatMessage`` objects.

    Attributes:
        model_name: The name of the GigaChat model to use.
        temperature: Controls randomness in generation (0.0-1.0).
        top_p: Controls diversity via nucleus sampling (0.0-1.0).
        repetition_penalty: Penalizes repetition in generated text (>= 1.0).
        max_tokens: Maximum number of tokens to generate.
        profanity_check: Whether to enable profanity filtering.
        gigachat_instance: The underlying GigaChat client.
    """

    def __init__(
        self,
        auth_data: str,
        model_name: str = "GigaChat",
        api_endpoint: str = "https://gigachat.devices.sberbank.ru/api/v1/",
        temperature: float = 0.1,
        top_p: float = 0.1,
        repetition_penalty: float = 1.0,
        max_tokens: int = 1500,
        profanity_check: bool = True,
        client_id: Optional[str] = None,
        auth_endpoint: str = "https://ngw.devices.sberbank.ru:9443/api/v2/oauth",
        auth_scope: Literal["GIGACHAT_API_PERS", "GIGACHAT_API_CORP", "GIGACHAT_API_B2B"] = "GIGACHAT_API_CORP",
        cert_path: Optional[str] = None,
    ) -> None:
        """Initialize a new GigaChatSmolModel instance.

        Args:
            auth_data: Authorization key for exchanging messages with GigaChat API
            model_name: The name of the GigaChat model to use.
            api_endpoint: The GigaChat API endpoint URL.
            temperature: Controls randomness in generation (0.0-1.0).
            top_p: Controls diversity via nucleus sampling (0.0-1.0).
            repetition_penalty: Penalizes repetition in generated text (>= 1.0).
            max_tokens: Maximum number of tokens to generate.
            profanity_check: Whether to enable profanity filtering.
            client_id: The client ID for API authentication.
            auth_endpoint: The authentication endpoint URL.
            auth_scope: The authentication scope.
            cert_path: Path to the certificate file for API authentication.
        """
        super().__init__()
        self.model_name = model_name
        self.temperature = temperature
        self.top_p = top_p
        self.repetition_penalty = repetition_penalty
        self.max_tokens = max_tokens
        self.profanity_check = profanity_check
        self.gigachat_instance = GigaChat(
            auth_data=auth_data,
            model_name=self.model_name,
            api_endpoint=api_endpoint,
            temperature=self.temperature,
            top_p=self.top_p,
            repetition_penalty=self.repetition_penalty,
            max_tokens=self.max_tokens,
            profanity_check=self.profanity_check,
            client_id=client_id,
            auth_endpoint=auth_endpoint,
            auth_scope=auth_scope,
            cert_path=cert_path
        )

    def generate(
        self,
        messages: List[Dict[str, Any]],
        stop_sequences: Optional[List[str]] = None,
        grammar: Optional[str] = None,
        tools_to_call_from: Optional[List[Tool]] = None,
    ) -> ChatMessage:
        """Run one chat completion and return it as a smolagents message.

        Args:
            messages: The conversation so far, in smolagents format.
            stop_sequences: Suffixes to strip from the end of the answer. They
                are applied to the returned text only and are not sent to the API.
            grammar: Accepted for interface compatibility; not used.
            tools_to_call_from: Tools offered to the model as callable functions.
                When given and the model calls none of them, the answer is
                wrapped into a ``final_answer`` tool call.

        Returns:
            ChatMessage: The assistant message with content, tool calls, the raw
            API payload and token usage. If anything fails, the error is logged
            and an assistant message describing it is returned instead of
            raising.
        """
        try:
            messages = map_message_roles_to_api_format(messages)
            functions = [get_tool_json_schema_gigachat(tool) for tool in tools_to_call_from] if tools_to_call_from else None
            response = self.chat(messages=messages, functions=functions)
            answer = response.get('answer', '')
            tool_calls = extract_tool_calls(response)

            # Trim before building the fallback call, so the final answer and
            # the message content agree and neither keeps a stop marker.
            if stop_sequences and isinstance(stop_sequences, list):
                answer = remove_stop_sequences(answer, stop_sequences)

            if tool_calls is None and tools_to_call_from is not None:
                tool_calls = create_final_answer_tool_call(answer)

            return parse_tool_args_if_needed(
                ChatMessage(
                    role="assistant",
                    content=answer,
                    tool_calls=tool_calls,
                    raw=response['response'],
                    token_usage=TokenUsage(
                        input_tokens=response['response']['usage']['prompt_tokens'],
                        output_tokens=response['response']['usage']['completion_tokens']
                    )
                )
            )
        except Exception as e:
            # Deliberate best effort: the agent receives the failure as text.
            logging.error(f"Critical error in generate: {str(e)}", exc_info=True)
            return ChatMessage(
                role="assistant",
                content=f"Error in model execution: {str(e)}"
            )

    def chat(
        self,
        messages: MessageList,
        params: Optional[Dict[str, Any]] = None,
        functions: Optional[List[Dict[str, Any]]] = None,
        function_call: Optional[Union[str, Dict[str, str]]] = None,
    ) -> Dict[str, Any]:
        """Forward a chat request to the underlying GigaChat client.

        Args:
            messages: Messages already in GigaChat API format.
            params: Per-request overrides of the generation settings.
            functions: Function definitions the model may call.
            function_call: Function call specification passed to the API.

        Returns:
            Dict[str, Any]: Whatever ``GigaChat.chat`` returns.
        """
        return self.gigachat_instance.chat(messages, params, functions, function_call)

    def get_available_models(self) -> List[str]:
        """Return the model list reported by the underlying GigaChat client."""
        return self.gigachat_instance._get_list_model()
--------------------------------------------------------------------------------
/src/gigasmol/gigachat_api/api_model.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | from datetime import datetime
4 | from typing import Any, Iterator, Optional, Union, Literal, Dict, List, Tuple
5 |
6 | import requests
7 | from sseclient import SSEClient
8 |
9 | from .auth import APIAuthorize, LLMAuthorizeEnablers, SberDSAuthorize
10 |
# Prefer the standard-library StrEnum; when the running interpreter's enum
# module does not provide it, define a minimal stand-in whose members are
# both str and Enum instances.
# NOTE(review): the stand-in does not override __str__, so str(member) may
# differ from the standard-library StrEnum - confirm whether that matters.
try:
    from enum import StrEnum
except ImportError:
    from enum import Enum
    class StrEnum(str, Enum):
        pass
17 |
18 |
class DialogRole(StrEnum):
    """Roles for GigaChat conversations.

    Each member is a string, so it can be used directly as the ``role`` value
    of a message dictionary.
    """
    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"
    # Presumably the role for function results - not used elsewhere in this module; verify.
    FUNCTION = "function"
25 |
26 |
# A conversation: a list of message dictionaries such as
# {"role": "user", "content": "..."}.
MessageList = List[Dict[str, Any]]
28 |
29 |
30 | class GigaChat:
31 | """Access to LLM based on GigaChat API."""
32 |
def __init__(
    self,
    auth_data: Optional[str] = None,
    model_name: str = 'GigaChat',
    api_endpoint: str = "https://gigachat.devices.sberbank.ru/api/v1/",
    sber_ds: bool = False,
    authorize: Optional[APIAuthorize] = None,
    temperature: float = 0.1,
    top_p: float = 0.1,
    repetition_penalty: float = 1.0,
    max_tokens: int = 5000,
    n: int = 1,
    n_stream: int = 1,
    profanity_check: bool = True,
    client_id: Optional[str] = None,
    auth_endpoint: Optional[str] = "https://ngw.devices.sberbank.ru:9443/api/v2/oauth",
    auth_scope: Literal["GIGACHAT_API_PERS", "GIGACHAT_API_CORP", "GIGACHAT_API_B2B"] = "GIGACHAT_API_CORP",
    cert_path: Optional[str] = None,
) -> None:
    """Initialize with GigaChat API access parameters and response generation settings.

    Args:
        auth_data: Authorization key for exchanging messages with GigaChat API.
            Required unless ``sber_ds`` is True or ``authorize`` is supplied.
        model_name: Model name of the GigaChat to use.
        api_endpoint: GigaChat API URL
        sber_ds: Using on sber-ds platform
        authorize: Authorization method for GigaChat API. If not provided, LLMAuthorizeEnablers
            (or SberDSAuthorize when ``sber_ds`` is set) will be used.
        temperature: Temperature parameter; higher values produce more diverse outputs (typical value 0.7).
        top_p: Another parameter for output diversity (typical value 0.1).
        repetition_penalty: Controls word repetition. Value 1.0 is neutral, 0-1 increases repetition, >1 decreases repetition.
        max_tokens: Maximum number of tokens to generate in the response.
        n: Number of completions to generate (non-streaming mode).
        n_stream: Number of completions to generate (streaming mode).
        profanity_check: Whether to enable profanity checking.
        client_id: GigaChat API client ID (used as RqUID).
        auth_endpoint: The authentication endpoint URL.
        auth_scope: The authentication scope. Contains information about the API version being accessed.
        cert_path: Path to the certificate for GigaChat API access.

    Raises:
        AssertionError: If no ``authorize`` object is supplied, ``auth_data`` is
            missing and ``sber_ds`` is not True.
    """
    self.api_endpoint = api_endpoint
    self.sber_ds = sber_ds
    self.temperature = temperature
    self.top_p = top_p
    self.repetition_penalty = repetition_penalty
    self.model_name = model_name
    self.max_tokens = max_tokens
    self.n = n
    self.n_stream = n_stream
    self.profanity_check = profanity_check

    # Raised explicitly (same exception type as the former assert) so the
    # check also runs under ``python -O``. It is skipped when the caller
    # supplies an authorizer, because auth_data is then never used.
    if authorize is None and auth_data is None and sber_ds is not True:
        raise AssertionError("auth_data is required for non-sber_ds mode")

    if authorize is not None:
        self.__authorize = authorize
    elif not sber_ds:
        self.__authorize = LLMAuthorizeEnablers(
            auth_data=auth_data,
            client_id=client_id,
            auth_endpoint=auth_endpoint,
            auth_scope=auth_scope,
            cert_path=cert_path
        )
    else:
        self.__authorize = SberDSAuthorize()
    self.__token_expiration_time = datetime.min
98 |
def _get_list_model(self) -> List[str]:
    """Fetch the models offered by the configured API endpoint.

    Sends a GET request to ``<api_endpoint>models`` and returns the ``data``
    entry of the JSON reply. A bearer token is attached unless the client runs
    in sber_ds mode, in which case only an ``Accept`` header is sent.

    Returns:
        list[str]: The ``data`` entry of the reply.

    Raises:
        RuntimeError: If the status code is not 200 or the reply has no
            ``data`` entry. Any failure is logged before being re-raised.
    """
    try:
        url = f"{self.api_endpoint}models"
        if self.sber_ds:
            headers = {"Accept": "application/json"}
        else:
            headers = {"Authorization": f"Bearer {self.__authorize.token}"}

        response = requests.get(url, headers=headers, verify=self.__authorize.cert_path)
        if response.status_code != 200:
            raise RuntimeError(f"Error in getting list models: {response.status_code} - {response.content}")

        models = json.loads(response.content)
        if "data" not in models:
            raise RuntimeError("Incorrect response from GigaChat API")
        available = models["data"]
        logging.info(f"Available models: {available}")
        return available
    except Exception as e:
        logging.error(f"{str(e)}")
        raise
132 |
def _prepare_request(
    self,
    params: Dict[str, Any],
    messages: MessageList,
    stream: bool,
    functions: Optional[List[Dict[str, Any]]] = None,
    function_call: Optional[Union[str, Dict[str, str]]] = None,
) -> Tuple[str, Dict[str, Any], str, str]:
    """Build everything needed to POST a chat completion request.

    Generation settings are taken from ``params`` when present and from the
    instance defaults otherwise. The number of completions comes from
    ``n_stream`` in streaming mode and from ``n`` otherwise.

    Args:
        params: Per-request overrides (``temperature``, ``top_p``,
            ``repetition_penalty``, ``model_name``, ``max_tokens``,
            ``profanity_check``).
        messages: List of formatted messages.
        stream: Whether to use streaming mode.
        functions: Optional list of function definitions.
        function_call: Optional function call specification; only used when
            ``functions`` is given, defaulting to ``"auto"``.

    Returns:
        tuple: URL, headers, the request body serialised as JSON, and the
        model name.
    """
    model_name: str = params.get("model_name", self.model_name)

    headers = {"Content-Type": "application/json"}
    if self.sber_ds:
        headers["Accept"] = "application/json"
    else:
        headers["Authorization"] = f"Bearer {self.__authorize.token}"

    payload = {
        "model": f"{model_name}",
        "messages": messages,
        "temperature": params.get("temperature", self.temperature),
        "top_p": params.get("top_p", self.top_p),
        "n": self.n_stream if stream else self.n,
        "repetition_penalty": params.get("repetition_penalty", self.repetition_penalty),
        "profanity_check": bool(params.get("profanity_check", self.profanity_check)),
        "stream": bool(stream),
        "max_tokens": int(params.get("max_tokens", self.max_tokens)),
    }
    if functions is not None:
        payload["functions"] = functions
        payload["function_call"] = "auto" if function_call is None else function_call
    if stream:
        payload["update_interval"] = 0

    endpoint = f"{self.api_endpoint}chat/completions"
    return endpoint, headers, json.dumps(payload, ensure_ascii=False), model_name
191 |
def get_model_name(self) -> str:
    """Report the model name configured on this client.

    Returns:
        str: The value of ``self.model_name``.
    """
    configured_model = self.model_name
    return configured_model
199 |
def complete(
    self,
    prompt: str,
    system_prompt: Optional[str] = None,
    params: Optional[Dict[str, Any]] = None,
    functions: Optional[List[Dict[str, Any]]] = None,
    function_call: Optional[Union[str, Dict[str, str]]] = None,
) -> Dict[str, Any]:
    """Ask the model a single question and return its reply.

    The prompt becomes a one-turn conversation (preceded by a system message
    when ``system_prompt`` is given) that is passed on to ``chat``.

    Args:
        prompt: The user's prompt text.
        system_prompt: Optional system instructions.
        params: Additional parameters for controlling the response generation.
        functions: Optional list of function definitions for function calling.
        function_call: Optional function call specification.

    Returns:
        dict: Whatever ``chat`` returns for the assembled conversation.
    """
    user_turn = {"role": "user", "content": prompt}
    if system_prompt is None:
        dialog = [user_turn]
    else:
        dialog = [{"role": "system", "content": system_prompt}, user_turn]
    return self.chat(messages=dialog, params=params, functions=functions, function_call=function_call)
225 |
def complete_stream(
    self,
    prompt: str,
    system_prompt: Optional[str] = None,
    params: Optional[Dict[str, Any]] = None,
    functions: Optional[List[Dict[str, Any]]] = None,
    function_call: Optional[Union[str, Dict[str, str]]] = None,
) -> Iterator[str]:
    """Ask the model a single question and stream its reply.

    The prompt becomes a one-turn conversation (preceded by a system message
    when ``system_prompt`` is given) that is passed on to ``chat_stream``.

    Args:
        prompt: The user's prompt text.
        system_prompt: Optional system instructions.
        params: Additional parameters for controlling the response generation.
        functions: Optional list of function definitions for function calling.
        function_call: Optional function call specification.

    Returns:
        Iterator[str]: Whatever ``chat_stream`` returns for the assembled
        conversation.
    """
    user_turn = {"role": "user", "content": prompt}
    if system_prompt is None:
        dialog = [user_turn]
    else:
        dialog = [{"role": "system", "content": system_prompt}, user_turn]
    return self.chat_stream(messages=dialog, params=params, functions=functions, function_call=function_call)
251 |
def tokens(self, text: str, params: Optional[Dict[str, Any]] = None) -> int:
    """Count the number of tokens in the given text string.

    Posts the text to ``<api_endpoint>tokens/count`` and returns the ``tokens``
    value of the first entry in the reply.

    Args:
        text: The text string to count tokens for.
        params: Optional overrides; only ``model_name`` is read. When omitted,
            the instance's default model is used.

    Returns:
        int: The number of tokens in the provided text.

    Raises:
        RuntimeError: If the status code is not 200 or the reply does not have
            the expected shape. Any failure is logged before being re-raised.
    """
    try:
        # Optional like ``params`` on the other request methods of this class.
        params = params if params is not None else {}
        model_name = params.get("model_name", self.model_name)
        headers = {"Content-Type": "application/json"}

        if not self.sber_ds:
            headers["Authorization"] = f"Bearer {self.__authorize.token}"
        else:
            headers["Accept"] = "application/json"

        query = json.dumps({"model": f"{model_name}", "input": [text]})
        url = f"{self.api_endpoint}tokens/count"
        logging.debug(f"Count tokens on LLM '{model_name}' and prompt: {text}...")
        response = requests.post(url, data=query, headers=headers, verify=self.__authorize.cert_path)
        if response.status_code == 200:
            answer = json.loads(response.content)
            # Check the overall shape first so a malformed reply surfaces as the
            # documented RuntimeError rather than an IndexError/KeyError/TypeError.
            if (
                not isinstance(answer, list)
                or not answer
                or not isinstance(answer[0], dict)
                or "tokens" not in answer[0]
                or "characters" not in answer[0]
            ):
                raise RuntimeError("Incorrect response from GigaChat API")
            return answer[0]["tokens"]
        else:
            raise RuntimeError(f"Error in running tokens method: {response.status_code} - {response.content}")
    except Exception as e:
        logging.error(f"{str(e)}")
        raise
288 |
def chat(
    self,
    messages: MessageList,
    params: Optional[Dict[str, Any]] = None,
    functions: Optional[List[Dict[str, Any]]] = None,
    function_call: Optional[Union[str, Dict[str, str]]] = None,
) -> Dict[str, Any]:
    """Run a non-streaming chat completion and summarise the reply.

    Args:
        messages: The conversation as a list of dictionaries with 'role' and
            'content' keys.
        params: Additional parameters for controlling the response generation.
        functions: Optional list of function definitions for function calling.
        function_call: Optional function call specification.

    Returns:
        Dict: A dictionary with the keys ``answer`` (content of the first
        choice), ``response`` (the full decoded reply), ``prompt_tokens``,
        ``answer_tokens``, ``finish_reason`` and ``info`` (a JSON string with
        the model name and the ``x-request-id`` / ``x-session-id`` headers).

    Raises:
        RuntimeError: If the status code is not 200 or the reply lacks the
            ``choices`` or ``usage`` entries. Any failure is logged before
            being re-raised.
    """
    try:
        if params is None:
            params = {}
        url, headers, query, model_name = self._prepare_request(
            params=params,
            messages=messages,
            stream=False,
            functions=functions,
            function_call=function_call,
        )
        logging.debug(f"Run chat on LLM '{model_name}' and chat: {messages}...")
        response = requests.post(url, data=query, headers=headers, verify=self.__authorize.cert_path)
        if response.status_code != 200:
            raise RuntimeError(f"Error in running complete method: {response.status_code} - {response.content}")

        answer = json.loads(response.content)
        if "choices" not in answer or "usage" not in answer:
            raise RuntimeError("Incorrect response from GigaChat API")

        first_choice = answer["choices"][0]
        usage = answer["usage"]
        complete = first_choice["message"]["content"]
        prompt_tokens = usage["prompt_tokens"]
        completion_tokens = usage["completion_tokens"]
        finish_reason = first_choice["finish_reason"]
        logging.debug(f"answer: {complete}")
        logging.debug(f"prompt tokens number: {prompt_tokens} complete tokens number: {completion_tokens}")

        info = json.dumps({'model':response.json()['model'],
                           'x-request-id':response.headers.get('x-request-id'),
                           'x-session-id':response.headers.get('x-session-id')})
        return {
            "answer": complete,
            "response": answer,
            "prompt_tokens": prompt_tokens,
            "answer_tokens": completion_tokens,
            "finish_reason": finish_reason,
            "info": info,
        }
    except Exception as e:
        logging.error(f"{str(e)}")
        raise
347 |
def chat_stream(
    self,
    messages: MessageList,
    params: Optional[Dict[str, Any]] = None,
    functions: Optional[List[Dict[str, Any]]] = None,
    function_call: Optional[Union[str, Dict[str, str]]] = None,
) -> Iterator[str]:
    """Send a list of messages to the model and return the model's response as a stream.

    This is a generator: the request is only sent once iteration starts. Each
    server-sent event yields the ``delta.content`` of its first choice. When
    the ``[DONE]`` event arrives the generator finishes, and the last seen
    finish reason is available as ``StopIteration.value``.

    Args:
        messages: The conversation as a list of dictionaries with 'role' and
            'content' keys.
        params: Additional parameters for controlling the response generation.
        functions: Optional list of function definitions for function calling.
        function_call: Optional function call specification.

    Returns:
        Iterator[str]: An iterator that yields response tokens as they're generated.

    Raises:
        RuntimeError: If the status code is not 200 or an event has no
            ``choices`` entry. Any failure is logged before being re-raised.
    """
    try:
        params = params if params is not None else {}
        url, headers, query, model_name = self._prepare_request(
            params=params,
            messages=messages,
            stream=True,
            functions=functions,
            function_call=function_call,
        )
        logging.debug(f"Run chat on LLM '{model_name}' and chat: {messages}...")
        # With stream=True the connection stays open until the body is fully
        # read or the response is closed. The context manager releases it even
        # when the consumer stops iterating early or an error occurs.
        with requests.post(
            url, data=query, headers=headers, verify=self.__authorize.cert_path, stream=True
        ) as response:
            if response.status_code != 200:
                raise RuntimeError(f"Error in running complete method: {response.status_code} - {response.content}")
            finish_reason = str()
            client = SSEClient(response)
            for event in client.events():
                if event.data == "[DONE]":
                    return {"finish_reason": finish_reason}
                event_data = json.loads(event.data)
                if "choices" not in event_data:
                    raise RuntimeError("No choices field in response from GigaChat API")
                if "finish_reason" in event_data["choices"][0]:
                    finish_reason = event_data["choices"][0]["finish_reason"]
                token = event_data["choices"][0]["delta"]["content"]
                logging.debug(f"token: {token}")
                yield token
    except Exception as e:
        logging.error(f"{str(e)}")
        raise
400 |
def check_chat_profanity(self, messages: MessageList) -> bool:
    """Tell whether a message history is rejected by the profanity check.

    The history is sent through ``self.chat`` with ``profanity_check``
    enabled; the content counts as prohibited when the returned
    ``finish_reason`` equals ``"blacklist"``.

    Args:
        messages: The message history as a list of (role, message) tuples.

    Returns:
        bool: True if the history contains prohibited content, False otherwise.
    """
    try:
        outcome = self.chat(messages=messages, params={"profanity_check": True})
        is_blacklisted = outcome["finish_reason"] == "blacklist"
    except Exception as err:
        # Log with the offending input for diagnosis, then propagate.
        logging.error(f"Error in checking censorship for chat: '{messages}', error: '{err}'")
        raise err
    return is_blacklisted
416 |
def check_question_profanity(self, question: str) -> bool:
    """Tell whether a single question is rejected by the profanity check.

    The question is sent through ``self.complete`` with ``profanity_check``
    enabled; the content counts as prohibited when the returned
    ``finish_reason`` equals ``"blacklist"``.

    Args:
        question: The question text.

    Returns:
        bool: True if the question contains prohibited content, False otherwise.
    """
    try:
        outcome = self.complete(prompt=question, params={"profanity_check": True})
        is_blacklisted = outcome["finish_reason"] == "blacklist"
    except Exception as err:
        # Log with the offending input for diagnosis, then propagate.
        logging.error(f"Error in checking censorship for question: '{question}', error: '{err}'")
        raise err
    return is_blacklisted
432 |
433 |
class GigaFilter:
    """Class for direct interaction with GigaFilter for profanity checking."""

    def __init__(self, api_endpoint: str, authorize: APIAuthorize) -> None:
        """Initialize GigaFilter.

        Args:
            api_endpoint: The API endpoint URL. ``filter/check`` is appended
                to it verbatim, so it must end with a trailing slash.
            authorize: Authorization method for the API; its ``token`` and
                ``cert_path`` attributes are read on every request.
        """
        self.api_endpoint = api_endpoint
        self.__authorize = authorize

    def check_profanity(self, question: str, return_json: bool = False) -> Union[bool, Dict[str, Any]]:
        """Check if the given text contains profanity.

        Args:
            question: The text to check.
            return_json: If True, return the full JSON response; if False, return only a boolean result.

        Returns:
            Union[bool, dict]: Either the value of the ``is_profane`` field of the
            response, or the complete JSON response if return_json is True.

        Raises:
            RuntimeError: If the API answers with a non-200 status code, or the
                response lacks the ``is_profane`` field when a boolean result
                was requested.
        """
        url = f"{self.api_endpoint}filter/check"
        headers = {
            "Authorization": f"Bearer {self.__authorize.token}",
            "Content-Type": "application/json",
        }
        payload = json.dumps({
            "model": "GigaFilter",
            "messages": [{
                "content": question,
                "role": "user"
            }]
        })
        response = requests.post(url, data=payload, headers=headers, verify=self.__authorize.cert_path)
        # Fail loudly on HTTP errors (same style as the GigaChat methods in this
        # file) instead of surfacing an opaque KeyError / JSON decode error or
        # handing an error body back to the caller as if it were a result.
        if response.status_code != 200:
            raise RuntimeError(
                f"Error in running check_profanity method: {response.status_code} - {response.content}"
            )
        result = response.json()
        if return_json:
            return result
        if "is_profane" not in result:
            raise RuntimeError("Incorrect response from GigaFilter API")
        return result["is_profane"]
476 |
--------------------------------------------------------------------------------