├── libs └── gigachat │ ├── tests │ ├── __init__.py │ └── unit_tests │ │ ├── __init__.py │ │ ├── utils │ │ ├── __init__.py │ │ └── test_function_calling.py │ │ ├── test_imports.py │ │ ├── test_utils.py │ │ ├── stubs.py │ │ └── test_gigachat.py │ ├── langchain_gigachat │ ├── py.typed │ ├── tools │ │ ├── __init__.py │ │ ├── load_prompt.py │ │ └── giga_tool.py │ ├── utils │ │ ├── __init__.py │ │ ├── pydantic_generator.py │ │ └── function_calling.py │ ├── output_parsers │ │ ├── __init__.py │ │ └── gigachat_functions.py │ ├── chat_models │ │ ├── __init__.py │ │ ├── base_gigachat.py │ │ └── gigachat.py │ ├── embeddings │ │ ├── __init__.py │ │ └── gigachat.py │ └── __init__.py │ ├── scripts │ ├── lint_imports.sh │ └── check_imports.py │ ├── LICENSE │ ├── Makefile │ ├── pyproject.toml │ ├── README.md │ └── README-ru_RU.md ├── README.md ├── .pre-commit-config.yaml ├── LICENSE ├── .github ├── scripts │ ├── check_diff.py │ └── get_min_versions.py ├── workflows │ ├── _test.yml │ ├── check_diffs.yml │ └── _lint.yml └── actions │ └── poetry_setup │ └── action.yml └── .gitignore /libs/gigachat/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /libs/gigachat/langchain_gigachat/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /libs/gigachat/tests/unit_tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /libs/gigachat/tests/unit_tests/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /libs/gigachat/langchain_gigachat/tools/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /libs/gigachat/langchain_gigachat/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /libs/gigachat/langchain_gigachat/output_parsers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /libs/gigachat/langchain_gigachat/chat_models/__init__.py: -------------------------------------------------------------------------------- 1 | from langchain_gigachat.chat_models.gigachat import GigaChat 2 | 3 | __all__ = ["GigaChat"] 4 | -------------------------------------------------------------------------------- /libs/gigachat/langchain_gigachat/embeddings/__init__.py: -------------------------------------------------------------------------------- 1 | from langchain_gigachat.embeddings.gigachat import GigaChatEmbeddings 2 | 3 | __all__ = ["GigaChatEmbeddings"] 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🦜️🔗 LangChain GigaChat 2 | 3 | This repository contains 1 package with GigaChat integrations with LangChain: 4 | 5 | - [langchain-gigachat](https://pypi.org/project/langchain-gigachat/) -------------------------------------------------------------------------------- /libs/gigachat/langchain_gigachat/__init__.py: -------------------------------------------------------------------------------- 1 | from langchain_gigachat.chat_models import GigaChat 2 | from langchain_gigachat.embeddings import GigaChatEmbeddings 3 | 4 | __all__ = ["GigaChat", "GigaChatEmbeddings"] 5 | 
-------------------------------------------------------------------------------- /libs/gigachat/tests/unit_tests/test_imports.py: -------------------------------------------------------------------------------- 1 | from langchain_gigachat import __all__ 2 | 3 | EXPECTED_ALL = ["GigaChat", "GigaChatEmbeddings"] 4 | 5 | 6 | def test_all_imports() -> None: 7 | assert sorted(EXPECTED_ALL) == sorted(__all__) 8 | -------------------------------------------------------------------------------- /libs/gigachat/scripts/lint_imports.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -eu 4 | 5 | # Initialize a variable to keep track of errors 6 | errors=0 7 | 8 | # make sure not importing from langchain or langchain_experimental 9 | git --no-pager grep '^from langchain\.' . && errors=$((errors+1)) 10 | git --no-pager grep '^from langchain_experimental\.' . && errors=$((errors+1)) 11 | 12 | # Decide on an exit status based on the errors 13 | if [ "$errors" -gt 0 ]; then 14 | exit 1 15 | else 16 | exit 0 17 | fi -------------------------------------------------------------------------------- /libs/gigachat/scripts/check_imports.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import traceback 3 | from importlib.machinery import SourceFileLoader 4 | 5 | if __name__ == "__main__": 6 | files = sys.argv[1:] 7 | has_failure = False 8 | for file in files: 9 | try: 10 | SourceFileLoader("x", file).load_module() 11 | except Exception: 12 | has_failure = True 13 | print(file) # noqa: T201 14 | traceback.print_exc() 15 | print() # noqa: T201 16 | 17 | sys.exit(1 if has_failure else 0) 18 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_install_hook_types: 2 | - pre-commit 3 | - commit-msg 4 | - pre-push 5 | repos: 6 | - 
repo: local 7 | hooks: 8 | - id: format 9 | name: format 10 | language: system 11 | entry: make -C libs/gigachat format 12 | pass_filenames: false 13 | stages: [pre-commit] 14 | - id: lint 15 | name: lint 16 | language: system 17 | entry: make -C libs/gigachat lint 18 | pass_filenames: false 19 | stages: [pre-commit] 20 | - id: test 21 | name: test 22 | language: system 23 | entry: make -C libs/gigachat test 24 | pass_filenames: false 25 | stages: [pre-commit] 26 | - repo: https://github.com/commitizen-tools/commitizen 27 | rev: v3.30.0 28 | hooks: 29 | - id: commitizen 30 | stages: [commit-msg] 31 | - id: commitizen-branch 32 | stages: [pre-push] 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 GigaChain 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /libs/gigachat/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 GigaChain 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /libs/gigachat/langchain_gigachat/utils/pydantic_generator.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue, core_schema 4 | 5 | 6 | class GigaChatJsonSchema(GenerateJsonSchema): 7 | def field_is_required( 8 | self, 9 | field: Union[ 10 | core_schema.ModelField, 11 | core_schema.DataclassField, 12 | core_schema.TypedDictField, 13 | ], 14 | total: bool, 15 | ) -> bool: 16 | """ 17 | Makers nullable fields not required 18 | """ 19 | if field["schema"]["type"] == "nullable": 20 | return False 21 | return super().field_is_required(field, total) 22 | 23 | def nullable_schema(self, schema: core_schema.NullableSchema) -> JsonSchemaValue: 24 | """ 25 | Remove anyOf if field is nullable 26 | """ 27 | null_schema = {"type": "null"} 28 | inner_json_schema = self.generate_inner(schema["schema"]) 29 | 30 | if inner_json_schema == null_schema: 31 | return null_schema 32 | else: 33 | return inner_json_schema 34 | -------------------------------------------------------------------------------- /libs/gigachat/tests/unit_tests/test_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Generator 2 | 3 | import pytest 4 | import requests_mock 5 | from langchain_core.prompts.prompt import PromptTemplate 6 | 7 | from langchain_gigachat.tools.load_prompt import load_from_giga_hub 8 | 9 | 10 | @pytest.fixture 11 | def mock_requests_get() -> Generator: 12 | with requests_mock.Mocker() as mocker: 13 | mocker.get( 14 | "https://raw.githubusercontent.com/ai-forever/gigachain/master/hub/prompts/entertainment/meditation.yaml", 15 | text=( 16 | "input_variables: [background, topic]\n" 17 | "output_parser: null\n" 18 | "template: 'Create mediation for {topic} with {background}'\n" 19 | "template_format: 
f-string\n" 20 | "_type: prompt" 21 | ), 22 | ) 23 | yield mocker 24 | 25 | 26 | def test__load_from_giga_hub(mock_requests_get: Generator) -> None: 27 | template = load_from_giga_hub("lc://prompts/entertainment/meditation.yaml") 28 | assert isinstance(template, PromptTemplate) 29 | assert template.template == "Create mediation for {topic} with {background}" 30 | assert "background" in template.input_variables 31 | assert "topic" in template.input_variables 32 | -------------------------------------------------------------------------------- /.github/scripts/check_diff.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | from typing import Dict 4 | 5 | LIB_DIRS = ["libs/gigachat"] 6 | 7 | if __name__ == "__main__": 8 | files = sys.argv[1:] 9 | 10 | dirs_to_run: Dict[str, set] = { 11 | "lint": set(), 12 | "test": set(), 13 | } 14 | 15 | for file in files: 16 | if any( 17 | file.startswith(dir_) 18 | for dir_ in ( 19 | ".github/workflows", 20 | ".github/actions", 21 | ".github/scripts/check_diff.py", 22 | ) 23 | ): 24 | # add all LANGCHAIN_DIRS for infra changes 25 | dirs_to_run["test"].update(LIB_DIRS) 26 | 27 | if any(file.startswith(dir_) for dir_ in LIB_DIRS): 28 | for dir_ in LIB_DIRS: 29 | if file.startswith(dir_): 30 | dirs_to_run["test"].add(dir_) 31 | elif file.startswith("libs/"): 32 | raise ValueError( 33 | f"Unknown lib: {file}. check_diff.py likely needs " 34 | "an update for this new library!" 
35 | ) 36 | 37 | outputs = { 38 | "dirs-to-lint": list(dirs_to_run["lint"] | dirs_to_run["test"]), 39 | "dirs-to-test": list(dirs_to_run["test"]), 40 | } 41 | for key, value in outputs.items(): 42 | json_output = json.dumps(value) 43 | print(f"{key}={json_output}") # noqa: T201 44 | -------------------------------------------------------------------------------- /.github/workflows/_test.yml: -------------------------------------------------------------------------------- 1 | name: test 2 | 3 | on: 4 | workflow_call: 5 | inputs: 6 | working-directory: 7 | required: true 8 | type: string 9 | description: "From which folder this pipeline executes" 10 | 11 | env: 12 | POETRY_VERSION: "1.7.1" 13 | 14 | jobs: 15 | build: 16 | defaults: 17 | run: 18 | working-directory: ${{ inputs.working-directory }} 19 | runs-on: ubuntu-latest 20 | strategy: 21 | matrix: 22 | python-version: 23 | - "3.9" 24 | - "3.10" 25 | - "3.11" 26 | - "3.12" 27 | name: "make test #${{ matrix.python-version }}" 28 | steps: 29 | - uses: actions/checkout@v4 30 | 31 | - name: Set up Python ${{ matrix.python-version }} + Poetry ${{ env.POETRY_VERSION }} 32 | uses: "./.github/actions/poetry_setup" 33 | with: 34 | python-version: ${{ matrix.python-version }} 35 | poetry-version: ${{ env.POETRY_VERSION }} 36 | working-directory: ${{ inputs.working-directory }} 37 | cache-key: core 38 | 39 | - name: Install dependencies 40 | shell: bash 41 | run: poetry install --with test 42 | 43 | - name: Run core tests 44 | shell: bash 45 | run: | 46 | make test 47 | 48 | - name: Ensure the tests did not create any additional files 49 | shell: bash 50 | run: | 51 | set -eu 52 | 53 | STATUS="$(git status)" 54 | echo "$STATUS" 55 | 56 | # grep will exit non-zero if the target message isn't found, 57 | # and `set -e` above will cause the step to fail. 
58 | echo "$STATUS" | grep 'nothing to commit, working tree clean' 59 | -------------------------------------------------------------------------------- /libs/gigachat/Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: all format lint test tests check_imports dev 2 | 3 | # Default target executed when no arguments are given to make. 4 | all: help 5 | 6 | # Define a variable for the test file path. 7 | TEST_FILE ?= tests/unit_tests/ 8 | integration_test integration_tests: TEST_FILE = tests/integration_tests/ 9 | 10 | test tests integration_test integration_tests: 11 | poetry run pytest $(TEST_FILE) 12 | 13 | check_imports: $(shell find langchain_gigachat -name '*.py') 14 | poetry run python ./scripts/check_imports.py $^ 15 | 16 | dev: 17 | poetry run pre-commit install && \ 18 | git remote set-head origin -a 19 | 20 | ###################### 21 | # LINTING AND FORMATTING 22 | ###################### 23 | 24 | # Define a variable for Python and notebook files. 25 | PYTHON_FILES=. 26 | MYPY_CACHE=.mypy_cache 27 | lint format: PYTHON_FILES=. 28 | lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/gigachat --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$') 29 | lint_package: PYTHON_FILES=langchain_gigachat 30 | lint_tests: PYTHON_FILES=tests 31 | lint_tests: MYPY_CACHE=.mypy_cache_test 32 | 33 | lint lint_diff lint_package lint_tests: 34 | ./scripts/lint_imports.sh 35 | poetry run ruff check . 
36 | poetry run ruff format $(PYTHON_FILES) --diff 37 | poetry run ruff check --select I $(PYTHON_FILES) 38 | mkdir -p $(MYPY_CACHE); poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE) 39 | 40 | format format_diff: 41 | poetry run ruff format $(PYTHON_FILES) 42 | poetry run ruff check --select I --fix $(PYTHON_FILES) 43 | 44 | ###################### 45 | # HELP 46 | ###################### 47 | 48 | help: 49 | @echo '----' 50 | @echo 'check_imports - check imports' 51 | @echo 'format - run code formatters' 52 | @echo 'lint - run linters' 53 | @echo 'test - run unit tests' 54 | @echo 'tests - run unit tests' 55 | @echo 'test TEST_FILE= - run all tests in file' 56 | @echo 'dev - configure development environment' -------------------------------------------------------------------------------- /libs/gigachat/pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "langchain-gigachat" 3 | version = "0.3.12" 4 | description = "An integration package connecting GigaChat and LangChain" 5 | authors = [] 6 | readme = "README.md" 7 | repository = "https://github.com/ai-forever/langchain-gigachat" 8 | license = "MIT" 9 | 10 | [tool.poetry.urls] 11 | "Source Code" = "https://github.com/ai-forever/langchain-gigachat/tree/master/libs/gigachat" 12 | 13 | [tool.poetry.dependencies] 14 | python = ">=3.9,<4.0" 15 | langchain-core = "^0.3" 16 | gigachat = "^0.1.41.post1" 17 | types-requests = "^2.32" 18 | 19 | [tool.poetry.group.dev] 20 | optional = true 21 | 22 | [tool.poetry.group.dev.dependencies] 23 | pre-commit = "^4.0.1" 24 | 25 | [tool.poetry.group.lint] 26 | optional = true 27 | 28 | [tool.poetry.group.lint.dependencies] 29 | ruff = "^0.7.0" 30 | 31 | [tool.poetry.group.typing] 32 | optional = true 33 | 34 | [tool.poetry.group.typing.dependencies] 35 | mypy = "^1.13.0" 36 | 37 | [tool.poetry.group.test] 38 | optional = true 39 | 40 | [tool.poetry.group.test.dependencies] 41 | pytest = "^8.3.3" 42 | 
pytest-cov = "^5.0.0" 43 | pytest-asyncio = "^0.24.0" 44 | pytest-mock = "^3.14.0" 45 | requests_mock = "^1.12.1" 46 | 47 | [build-system] 48 | requires = ["poetry-core>=1.0.0"] 49 | build-backend = "poetry.core.masonry.api" 50 | 51 | [tool.ruff.lint] 52 | select = [ 53 | "E", # pycodestyle 54 | "F", # pyflakes 55 | "I", # isort 56 | "T201", # print 57 | ] 58 | 59 | [tool.ruff.format] 60 | 61 | [tool.mypy] 62 | ignore_missing_imports = "True" 63 | disallow_untyped_defs = "True" 64 | 65 | [tool.pytest.ini_options] 66 | addopts = "--strict-markers --strict-config --durations=5 --cov=langchain_gigachat -vv" 67 | markers = [ 68 | "requires: mark tests as requiring a specific library", 69 | "compile: mark placeholder test used to compile integration tests without running them", 70 | "scheduled: mark tests to run in scheduled testing", 71 | ] 72 | asyncio_mode = "auto" 73 | filterwarnings = [ 74 | "ignore::langchain_core._api.beta_decorator.LangChainBetaWarning", 75 | ] 76 | 77 | [tool.coverage.run] 78 | omit = ["tests/*"] 79 | -------------------------------------------------------------------------------- /libs/gigachat/README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | [![GitHub Release](https://img.shields.io/github/v/release/ai-forever/langchain-gigachat?style=flat-square)](https://github.com/ai-forever/langchain-gigachat/releases) 4 | [![GitHub Actions Workflow Status](https://img.shields.io/github/actions/workflow/status/ai-forever/langchain-gigachat/check_diffs.yml?style=flat-square)](https://github.com/ai-forever/langchain-gigachat/actions/workflows/check_diffs.yml) 5 | [![GitHub License](https://img.shields.io/github/license/ai-forever/langchain-gigachat?style=flat-square)](https://opensource.org/license/MIT) 6 | [![GitHub Downloads (all assets, all releases)](https://img.shields.io/pypi/dm/langchain-gigachat?style=flat-square?style=flat-square)](https://pypistats.org/packages/langchain-gigachat) 7 | [![GitHub Repo stars](https://img.shields.io/github/stars/ai-forever/langchain-gigachat?style=flat-square)](https://star-history.com/#ai-forever/langchain-gigachat) 8 | [![GitHub Open Issues](https://img.shields.io/github/issues-raw/ai-forever/langchain-gigachat)](https://github.com/ai-forever/langchain-gigachat/issues) 9 | 10 | [English](README.md) | [Русский](README-ru_RU.md) 11 | 12 |
13 | 14 | # langchain-gigachat 15 | 16 | This is a library integration with [GigaChat](https://giga.chat/). 17 | 18 | ## Installation 19 | 20 | ```bash 21 | pip install -U langchain-gigachat 22 | ``` 23 | 24 | ## Quickstart 25 | Follow these simple steps to get up and running quickly. 26 | 27 | ### Installation 28 | 29 | To install the package use following command: 30 | 31 | ```shell 32 | pip install -U langchain-gigachat 33 | ``` 34 | 35 | ### Initialization 36 | 37 | To initialize chat model: 38 | 39 | ```python 40 | from langchain_gigachat.chat_models import GigaChat 41 | 42 | giga = GigaChat(credentials="YOUR_AUTHORIZATION_KEY", verify_ssl_certs=False) 43 | ``` 44 | 45 | To initialize embeddings: 46 | 47 | ```python 48 | from langchain_gigachat.embeddings import GigaChatEmbeddings 49 | 50 | embedding = GigaChatEmbeddings( 51 | credentials="YOUR_AUTHORIZATION_KEY", 52 | verify_ssl_certs=False 53 | ) 54 | ``` 55 | 56 | ### Usage 57 | 58 | Use the GigaChat object to generate responses: 59 | 60 | ```python 61 | print(giga.invoke("Hello, world!")) 62 | ``` 63 | 64 | Now you can use the GigaChat object with LangChain's standard primitives to create LLM-applications. -------------------------------------------------------------------------------- /.github/workflows/check_diffs.yml: -------------------------------------------------------------------------------- 1 | --- 2 | name: CI 3 | 4 | on: 5 | push: 6 | branches: 7 | - master 8 | - dev 9 | pull_request: 10 | 11 | # If another push to the same PR or branch happens while this workflow is still running, 12 | # cancel the earlier run in favor of the next run. 13 | # 14 | # There's no point in testing an outdated version of the code. GitHub only allows 15 | # a limited number of job runners to be active at the same time, so it's better to cancel 16 | # pointless jobs early so that more useful jobs can run sooner. 
17 | concurrency: 18 | group: ${{ github.workflow }}-${{ github.ref }} 19 | cancel-in-progress: true 20 | 21 | env: 22 | POETRY_VERSION: "1.7.1" 23 | 24 | jobs: 25 | build: 26 | runs-on: ubuntu-latest 27 | steps: 28 | - uses: actions/checkout@v4 29 | - uses: actions/setup-python@v5 30 | with: 31 | python-version: "3.10" 32 | - id: files 33 | uses: Ana06/get-changed-files@v2.3.0 34 | - id: set-matrix 35 | run: | 36 | python .github/scripts/check_diff.py ${{ steps.files.outputs.all }} >> $GITHUB_OUTPUT 37 | outputs: 38 | dirs-to-lint: ${{ steps.set-matrix.outputs.dirs-to-lint }} 39 | dirs-to-test: ${{ steps.set-matrix.outputs.dirs-to-test }} 40 | 41 | lint: 42 | name: cd ${{ matrix.working-directory }} 43 | needs: [build] 44 | if: ${{ needs.build.outputs.dirs-to-lint != '[]' }} 45 | strategy: 46 | matrix: 47 | working-directory: ${{ fromJson(needs.build.outputs.dirs-to-lint) }} 48 | uses: ./.github/workflows/_lint.yml 49 | with: 50 | working-directory: ${{ matrix.working-directory }} 51 | secrets: inherit 52 | 53 | test: 54 | name: cd ${{ matrix.working-directory }} 55 | needs: [build] 56 | if: ${{ needs.build.outputs.dirs-to-test != '[]' }} 57 | strategy: 58 | matrix: 59 | working-directory: ${{ fromJson(needs.build.outputs.dirs-to-test) }} 60 | uses: ./.github/workflows/_test.yml 61 | with: 62 | working-directory: ${{ matrix.working-directory }} 63 | secrets: inherit 64 | 65 | ci_success: 66 | name: "CI Success" 67 | needs: [build, lint, test] 68 | if: | 69 | always() 70 | runs-on: ubuntu-latest 71 | env: 72 | JOBS_JSON: ${{ toJSON(needs) }} 73 | RESULTS_JSON: ${{ toJSON(needs.*.result) }} 74 | EXIT_CODE: ${{!contains(needs.*.result, 'failure') && !contains(needs.*.result, 'cancelled') && '0' || '1'}} 75 | steps: 76 | - name: "CI Success" 77 | run: | 78 | echo $JOBS_JSON 79 | echo $RESULTS_JSON 80 | echo "Exiting with $EXIT_CODE" 81 | exit $EXIT_CODE 82 | -------------------------------------------------------------------------------- 
/libs/gigachat/langchain_gigachat/tools/load_prompt.py: -------------------------------------------------------------------------------- 1 | """Utilities for loading templates from gigachain 2 | github-based hub or other extenal sources.""" 3 | 4 | import os 5 | import re 6 | import tempfile 7 | from pathlib import Path, PurePosixPath 8 | from typing import Any, Callable, Optional, Set, TypeVar, Union 9 | from urllib.parse import urljoin 10 | 11 | import requests 12 | from langchain_core.prompts.base import BasePromptTemplate 13 | from langchain_core.prompts.loading import _load_prompt_from_file 14 | 15 | DEFAULT_REF = os.environ.get("GIGACHAIN_HUB_DEFAULT_REF", "master") 16 | URL_BASE = os.environ.get( 17 | "GIGACHAIN_HUB_DEFAULT_REF", 18 | "https://raw.githubusercontent.com/ai-forever/gigachain/{ref}/hub/", 19 | ) 20 | HUB_PATH_RE = re.compile(r"lc(?P@[^:]+)?://(?P.*)") 21 | 22 | T = TypeVar("T") 23 | 24 | 25 | def _load_from_giga_hub( 26 | path: Union[str, Path], 27 | loader: Callable[[str], T], 28 | valid_prefix: str, 29 | valid_suffixes: Set[str], 30 | **kwargs: Any, 31 | ) -> Optional[T]: 32 | """Load configuration from hub. Returns None if path is not a hub path.""" 33 | if not isinstance(path, str) or not (match := HUB_PATH_RE.match(path)): 34 | return None 35 | ref, remote_path_str = match.groups() 36 | ref = ref[1:] if ref else DEFAULT_REF 37 | remote_path = Path(remote_path_str) 38 | if remote_path.parts[0] != valid_prefix: 39 | return None 40 | if remote_path.suffix[1:] not in valid_suffixes: 41 | raise ValueError(f"Unsupported file type, must be one of {valid_suffixes}.") 42 | 43 | # Using Path with URLs is not recommended, because on Windows 44 | # the backslash is used as the path separator, which can cause issues 45 | # when working with URLs that use forward slashes as the path separator. 46 | # Instead, use PurePosixPath to ensure that forward slashes are used as the 47 | # path separator, regardless of the operating system. 
48 | full_url = urljoin(URL_BASE.format(ref=ref), PurePosixPath(remote_path).__str__()) 49 | 50 | r = requests.get(full_url, timeout=5) 51 | if r.status_code != 200: 52 | raise ValueError(f"Could not find file at {full_url}") 53 | with tempfile.TemporaryDirectory() as tmpdirname: 54 | file = Path(tmpdirname) / remote_path.name 55 | with open(file, "wb") as f: 56 | f.write(r.content) 57 | return loader(str(file), **kwargs) 58 | 59 | 60 | def load_from_giga_hub(path: Union[str, Path]) -> BasePromptTemplate: 61 | """Unified method for loading a prompt from GigaChain repo or local fs.""" 62 | if hub_result := _load_from_giga_hub( 63 | path, _load_prompt_from_file, "prompts", {"py", "json", "yaml"} 64 | ): 65 | return hub_result 66 | else: 67 | raise ValueError("Prompt not found in GigaChain Hub.") 68 | -------------------------------------------------------------------------------- /.github/scripts/get_min_versions.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | if sys.version_info >= (3, 11): 4 | import tomllib 5 | else: 6 | # for python 3.10 and below, which doesnt have stdlib tomllib 7 | import tomli as tomllib 8 | 9 | import re 10 | 11 | from packaging.version import parse as parse_version 12 | 13 | MIN_VERSION_LIBS = ["langchain-core"] 14 | 15 | SKIP_IF_PULL_REQUEST = ["langchain-core"] 16 | 17 | 18 | def get_min_version(version: str) -> str: 19 | # base regex for x.x.x with cases for rc/post/etc 20 | # valid strings: https://peps.python.org/pep-0440/#public-version-identifiers 21 | vstring = r"\d+(?:\.\d+){0,2}(?:(?:a|b|rc|\.post|\.dev)\d+)?" 
22 | # case ^x.x.x 23 | _match = re.match(f"^\\^({vstring})$", version) 24 | if _match: 25 | return _match.group(1) 26 | 27 | # case >=x.x.x,=({vstring}),<({vstring})$", version) 29 | if _match: 30 | _min = _match.group(1) 31 | _max = _match.group(2) 32 | assert parse_version(_min) < parse_version(_max) 33 | return _min 34 | 35 | # case x.x.x 36 | _match = re.match(f"^({vstring})$", version) 37 | if _match: 38 | return _match.group(1) 39 | 40 | raise ValueError(f"Unrecognized version format: {version}") 41 | 42 | 43 | def get_min_version_from_toml(toml_path: str, versions_for: str): 44 | # Parse the TOML file 45 | with open(toml_path, "rb") as file: 46 | toml_data = tomllib.load(file) 47 | 48 | # Get the dependencies from tool.poetry.dependencies 49 | dependencies = toml_data["tool"]["poetry"]["dependencies"] 50 | 51 | # Initialize a dictionary to store the minimum versions 52 | min_versions = {} 53 | 54 | # Iterate over the libs in MIN_VERSION_LIBS 55 | for lib in MIN_VERSION_LIBS: 56 | if versions_for == "pull_request" and lib in SKIP_IF_PULL_REQUEST: 57 | # some libs only get checked on release because of simultaneous 58 | # changes 59 | continue 60 | # Check if the lib is present in the dependencies 61 | if lib in dependencies: 62 | # Get the version string 63 | version_string = dependencies[lib] 64 | 65 | if isinstance(version_string, dict): 66 | version_string = version_string["version"] 67 | 68 | # Use parse_version to get the minimum supported version from version_string 69 | min_version = get_min_version(version_string) 70 | 71 | # Store the minimum version in the min_versions dictionary 72 | min_versions[lib] = min_version 73 | 74 | return min_versions 75 | 76 | 77 | if __name__ == "__main__": 78 | # Get the TOML file path from the command line argument 79 | toml_file = sys.argv[1] 80 | versions_for = sys.argv[2] 81 | assert versions_for in ["release", "pull_request"] 82 | 83 | # Call the function to get the minimum versions 84 | min_versions = 
get_min_version_from_toml(toml_file, versions_for) 85 | 86 | print(" ".join([f"{lib}=={version}" for lib, version in min_versions.items()])) 87 | -------------------------------------------------------------------------------- /.github/actions/poetry_setup/action.yml: -------------------------------------------------------------------------------- 1 | # An action for setting up poetry install with caching. 2 | # Using a custom action since the default action does not 3 | # take poetry install groups into account. 4 | # Action code from: 5 | # https://github.com/actions/setup-python/issues/505#issuecomment-1273013236 6 | name: poetry-install-with-caching 7 | description: Poetry install with support for caching of dependency groups. 8 | 9 | inputs: 10 | python-version: 11 | description: Python version, supporting MAJOR.MINOR only 12 | required: true 13 | 14 | poetry-version: 15 | description: Poetry version 16 | required: true 17 | 18 | cache-key: 19 | description: Cache key to use for manual handling of caching 20 | required: true 21 | 22 | working-directory: 23 | description: Directory whose poetry.lock file should be cached 24 | required: true 25 | 26 | runs: 27 | using: composite 28 | steps: 29 | - uses: actions/setup-python@v5 30 | name: Setup python ${{ inputs.python-version }} 31 | id: setup-python 32 | with: 33 | python-version: ${{ inputs.python-version }} 34 | 35 | - uses: actions/cache@v4 36 | id: cache-bin-poetry 37 | name: Cache Poetry binary - Python ${{ inputs.python-version }} 38 | env: 39 | SEGMENT_DOWNLOAD_TIMEOUT_MIN: "1" 40 | with: 41 | path: | 42 | /opt/pipx/venvs/poetry 43 | # This step caches the poetry installation, so make sure it's keyed on the poetry version as well. 
44 | key: bin-poetry-${{ runner.os }}-${{ runner.arch }}-py-${{ inputs.python-version }}-${{ inputs.poetry-version }} 45 | 46 | - name: Refresh shell hashtable and fixup softlinks 47 | if: steps.cache-bin-poetry.outputs.cache-hit == 'true' 48 | shell: bash 49 | env: 50 | POETRY_VERSION: ${{ inputs.poetry-version }} 51 | PYTHON_VERSION: ${{ inputs.python-version }} 52 | run: | 53 | set -eux 54 | 55 | # Refresh the shell hashtable, to ensure correct `which` output. 56 | hash -r 57 | 58 | # `actions/cache@v3` doesn't always seem able to correctly unpack softlinks. 59 | # Delete and recreate the softlinks pipx expects to have. 60 | rm /opt/pipx/venvs/poetry/bin/python 61 | cd /opt/pipx/venvs/poetry/bin 62 | ln -s "$(which "python$PYTHON_VERSION")" python 63 | chmod +x python 64 | cd /opt/pipx_bin/ 65 | ln -s /opt/pipx/venvs/poetry/bin/poetry poetry 66 | chmod +x poetry 67 | 68 | # Ensure everything got set up correctly. 69 | /opt/pipx/venvs/poetry/bin/python --version 70 | /opt/pipx_bin/poetry --version 71 | 72 | - name: Install poetry 73 | if: steps.cache-bin-poetry.outputs.cache-hit != 'true' 74 | shell: bash 75 | env: 76 | POETRY_VERSION: ${{ inputs.poetry-version }} 77 | PYTHON_VERSION: ${{ inputs.python-version }} 78 | # Install poetry using the python version installed by setup-python step. 79 | run: pipx install "poetry==$POETRY_VERSION" --python '${{ steps.setup-python.outputs.python-path }}' --verbose 80 | 81 | - name: Restore pip and poetry cached dependencies 82 | uses: actions/cache@v4 83 | env: 84 | SEGMENT_DOWNLOAD_TIMEOUT_MIN: "4" 85 | WORKDIR: ${{ inputs.working-directory == '' && '.' 
|| inputs.working-directory }} 86 | with: 87 | path: | 88 | ~/.cache/pip 89 | ~/.cache/pypoetry/virtualenvs 90 | ~/.cache/pypoetry/cache 91 | ~/.cache/pypoetry/artifacts 92 | ${{ env.WORKDIR }}/.venv 93 | key: py-deps-${{ runner.os }}-${{ runner.arch }}-py-${{ inputs.python-version }}-poetry-${{ inputs.poetry-version }}-${{ inputs.cache-key }}-${{ hashFiles(format('{0}/**/poetry.lock', env.WORKDIR)) }} 94 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .mypy_cache_test/ 146 | .dmypy.json 147 | dmypy.json 148 | 149 | # Pyre type checker 150 | .pyre/ 151 | 152 | # pytype static type analyzer 153 | .pytype/ 154 | 155 | # Cython debug symbols 156 | cython_debug/ 157 | 158 | # PyCharm 159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 161 | # and can be added to the global gitignore or merged into this file. For a more nuclear 162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
163 | #.idea/ 164 | 165 | # VS Code configuration 166 | .vscode/ 167 | 168 | # ruff 169 | .ruff_cache/ -------------------------------------------------------------------------------- /.github/workflows/_lint.yml: -------------------------------------------------------------------------------- 1 | name: lint 2 | 3 | on: 4 | workflow_call: 5 | inputs: 6 | working-directory: 7 | required: true 8 | type: string 9 | description: "From which folder this pipeline executes" 10 | 11 | env: 12 | POETRY_VERSION: "1.7.1" 13 | WORKDIR: ${{ inputs.working-directory == '' && '.' || inputs.working-directory }} 14 | 15 | # This env var allows us to get inline annotations when ruff has complaints. 16 | RUFF_OUTPUT_FORMAT: github 17 | 18 | jobs: 19 | build: 20 | name: "make lint #${{ matrix.python-version }}" 21 | runs-on: ubuntu-latest 22 | strategy: 23 | matrix: 24 | # Only lint on the min and max supported Python versions. 25 | # It's extremely unlikely that there's a lint issue on any version in between 26 | # that doesn't show up on the min or max versions. 27 | # 28 | # GitHub rate-limits how many jobs can be running at any one time. 29 | # Starting new jobs is also relatively slow, 30 | # so linting on fewer versions makes CI faster. 
31 | python-version: 32 | - "3.9" 33 | - "3.12" 34 | steps: 35 | - uses: actions/checkout@v4 36 | 37 | - name: Set up Python ${{ matrix.python-version }} + Poetry ${{ env.POETRY_VERSION }} 38 | uses: "./.github/actions/poetry_setup" 39 | with: 40 | python-version: ${{ matrix.python-version }} 41 | poetry-version: ${{ env.POETRY_VERSION }} 42 | working-directory: ${{ inputs.working-directory }} 43 | cache-key: lint-with-extras 44 | 45 | - name: Check Poetry File 46 | shell: bash 47 | working-directory: ${{ inputs.working-directory }} 48 | run: | 49 | poetry check 50 | 51 | - name: Check lock file 52 | shell: bash 53 | working-directory: ${{ inputs.working-directory }} 54 | run: | 55 | poetry lock --check 56 | 57 | - name: Install dependencies 58 | # Also installs dev/lint/test/typing dependencies, to ensure we have 59 | # type hints for as many of our libraries as possible. 60 | # This helps catch errors that require dependencies to be spotted, for example: 61 | # https://github.com/langchain-ai/langchain/pull/10249/files#diff-935185cd488d015f026dcd9e19616ff62863e8cde8c0bee70318d3ccbca98341 62 | # 63 | # If you change this configuration, make sure to change the `cache-key` 64 | # in the `poetry_setup` action above to stop using the old cache. 65 | # It doesn't matter how you change it, any change will cause a cache-bust. 
66 | working-directory: ${{ inputs.working-directory }} 67 | run: | 68 | poetry install --with lint,typing 69 | 70 | - name: Get .mypy_cache to speed up mypy 71 | uses: actions/cache@v4 72 | env: 73 | SEGMENT_DOWNLOAD_TIMEOUT_MIN: "2" 74 | with: 75 | path: | 76 | ${{ env.WORKDIR }}/.mypy_cache 77 | key: mypy-lint-${{ runner.os }}-${{ runner.arch }}-py${{ matrix.python-version }}-${{ inputs.working-directory }}-${{ hashFiles(format('{0}/poetry.lock', inputs.working-directory)) }} 78 | 79 | - name: Analysing the code with our lint 80 | working-directory: ${{ inputs.working-directory }} 81 | run: | 82 | make lint_package 83 | 84 | - name: Install unit test dependencies 85 | working-directory: ${{ inputs.working-directory }} 86 | run: | 87 | poetry install --with test 88 | 89 | - name: Get .mypy_cache_test to speed up mypy 90 | uses: actions/cache@v4 91 | env: 92 | SEGMENT_DOWNLOAD_TIMEOUT_MIN: "2" 93 | with: 94 | path: | 95 | ${{ env.WORKDIR }}/.mypy_cache_test 96 | key: mypy-test-${{ runner.os }}-${{ runner.arch }}-py${{ matrix.python-version }}-${{ inputs.working-directory }}-${{ hashFiles(format('{0}/poetry.lock', inputs.working-directory)) }} 97 | 98 | - name: Analysing the code with our lint 99 | working-directory: ${{ inputs.working-directory }} 100 | run: | 101 | make lint_tests 102 | -------------------------------------------------------------------------------- /libs/gigachat/langchain_gigachat/output_parsers/gigachat_functions.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from types import GenericAlias 3 | from typing import Any, Dict, List, Type, Union 4 | 5 | from langchain_core.exceptions import OutputParserException 6 | from langchain_core.output_parsers import BaseGenerationOutputParser 7 | from langchain_core.outputs import ChatGeneration, Generation 8 | from pydantic import BaseModel, model_validator 9 | 10 | 11 | class OutputFunctionsParser(BaseGenerationOutputParser[Any]): 12 | """Parse an 
def _validate_obj(schema: Type[BaseModel], obj: Any) -> BaseModel:
    """Instantiate *schema* from raw data, supporting pydantic v1 and v2.

    pydantic v2 models expose ``model_validate``; v1-style models only
    provide ``parse_obj``. Centralized here so the three branches that
    previously duplicated this check stay in sync.
    """
    if hasattr(schema, "model_validate"):
        return schema.model_validate(obj)
    return schema.parse_obj(obj)  # type: ignore[attr-defined]


class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
    """Parse the ``function_call`` emitted by a chat generation."""

    args_only: bool = True
    """Whether to only return the arguments to the function call."""

    def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
        """Extract the function call from the first generation.

        Args:
            result: Generations from the LLM call; only the first is used.
            partial: Accepted for interface compatibility; unused here.

        Returns:
            The function-call dict, or only its ``arguments`` value when
            ``args_only`` is set.

        Raises:
            OutputParserException: If the generation is not a chat generation
                or carries no ``function_call`` in ``additional_kwargs``.
        """
        generation = result[0]
        if not isinstance(generation, ChatGeneration):
            raise OutputParserException(
                "This output parser can only be used with a chat generation."
            )
        message = generation.message
        try:
            # Deep-copy so callers can mutate the parsed result without
            # corrupting the original message.
            func_call = copy.deepcopy(message.additional_kwargs["function_call"])
        except KeyError as exc:
            raise OutputParserException(
                f"Could not parse function call: {exc}"
            ) from exc

        if self.args_only:
            return func_call["arguments"]
        return func_call


class PydanticOutputFunctionsParser(OutputFunctionsParser):
    """Parse an output as a pydantic object."""

    pydantic_schema: Union[Type[BaseModel], Dict[str, Type[BaseModel]]]
    """The pydantic schema to parse the output with.

    If multiple schemas are provided, then the function name will be used to
    determine which schema to use.
    """

    @model_validator(mode="before")
    @classmethod
    def validate_schema(cls, values: dict) -> Any:
        """Validate the pydantic schema.

        Args:
            values: The values to validate.

        Returns:
            The validated values.

        Raises:
            ValueError: If multiple schemas are combined with ``args_only``.
        """
        schema = values["pydantic_schema"]
        if "args_only" not in values:
            # Default args_only=True only for a single concrete BaseModel
            # subclass; a mapping of schemas needs the function name too.
            values["args_only"] = (
                isinstance(schema, type)
                and not isinstance(schema, GenericAlias)
                and issubclass(schema, BaseModel)
            )
        elif values["args_only"] and isinstance(schema, dict):
            msg = (
                "If multiple pydantic schemas are provided then args_only should be"
                " False."
            )
            raise ValueError(msg)
        return values

    def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
        """Parse the result of an LLM call into a pydantic object.

        Args:
            result: The result of the LLM call.
            partial: Whether to parse partial JSON objects. Default is False.

        Returns:
            The parsed pydantic object.
        """
        # Forward `partial` instead of silently dropping it.
        _result = super().parse_result(result, partial=partial)
        if self.args_only:
            # validate_schema guarantees pydantic_schema is a single model here.
            return _validate_obj(self.pydantic_schema, _result)  # type: ignore[arg-type]
        fn_name = _result["name"]
        _args = _result["arguments"]
        if isinstance(self.pydantic_schema, dict):
            pydantic_schema = self.pydantic_schema[fn_name]
        else:
            pydantic_schema = self.pydantic_schema
        return _validate_obj(pydantic_schema, _args)


class PydanticAttrOutputFunctionsParser(PydanticOutputFunctionsParser):
    """Parse an output as an attribute of a pydantic object."""

    attr_name: str
    """The name of the attribute to return."""

    def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
        """Return a single attribute of the parsed pydantic object."""
        # Use a fresh name instead of rebinding `result` with a different type.
        parsed = super().parse_result(result, partial=partial)
        return getattr(parsed, self.attr_name)
from langchain_core.embeddings import Embeddings
from langchain_core.utils import pre_init
from langchain_core.utils.pydantic import get_fields
from pydantic import BaseModel

logger = logging.getLogger(__name__)

# Hard per-request limits of the embeddings endpoint
# (values carried over from the original code — TODO confirm against API docs).
MAX_BATCH_SIZE_CHARS = 1000000
MAX_BATCH_SIZE_PARTS = 90


class GigaChatEmbeddings(BaseModel, Embeddings):
    """GigaChat Embeddings models.

    Example:
        .. code-block:: python
            from langchain_community.embeddings.gigachat import GigaChatEmbeddings

            embeddings = GigaChatEmbeddings(
                credentials=..., scope=..., verify_ssl_certs=False
            )
    """

    """ DEPRECATED: Send texts one-by-one to server (to increase token limit) """
    one_by_one_mode: bool = False
    """ DEPRECATED: Debug timeout for limit rps to server """
    _debug_delay: float = 0

    base_url: Optional[str] = None
    """ Base API URL """
    auth_url: Optional[str] = None
    """ Auth URL """
    credentials: Optional[str] = None
    """ Auth Token """
    scope: Optional[str] = None
    """ Permission scope for access token """

    access_token: Optional[str] = None
    """ Access token for GigaChat """

    model: Optional[str] = None
    """Model name to use."""
    user: Optional[str] = None
    """ Username for authenticate """
    password: Optional[str] = None
    """ Password for authenticate """

    timeout: Optional[float] = 600
    """ Timeout for request. By default it works for long requests. """
    verify_ssl_certs: Optional[bool] = None
    """ Check certificates for all requests """

    ssl_context: Optional[ssl.SSLContext] = None

    class Config:
        arbitrary_types_allowed = True

    ca_bundle_file: Optional[str] = None
    cert_file: Optional[str] = None
    key_file: Optional[str] = None
    key_file_password: Optional[str] = None
    # Support for connection to GigaChat through SSL certificates

    prefix_query: str = (
        "Дано предложение, необходимо найти его парафраз \nпредложение: "
    )

    use_prefix_query: bool = False

    @cached_property
    def _client(self) -> Any:
        """Returns a lazily-constructed GigaChat API client."""
        import gigachat

        return gigachat.GigaChat(
            base_url=self.base_url,
            auth_url=self.auth_url,
            credentials=self.credentials,
            scope=self.scope,
            access_token=self.access_token,
            model=self.model,
            user=self.user,
            password=self.password,
            timeout=self.timeout,
            ssl_context=self.ssl_context,
            verify_ssl_certs=self.verify_ssl_certs,
            ca_bundle_file=self.ca_bundle_file,
            cert_file=self.cert_file,
            key_file=self.key_file,
            key_file_password=self.key_file_password,
        )

    @pre_init
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate authenticate data in environment and python package is installed."""
        try:
            import gigachat  # noqa: F401
        except ImportError:
            raise ImportError(
                "Could not import gigachat python package. "
                "Please install it with `pip install gigachat`."
            )
        fields = set(get_fields(cls).keys())
        diff = set(values.keys()) - fields
        if diff:
            logger.warning(f"Extra fields {diff} in GigaChat class")
        return values

    @staticmethod
    def _make_batches(texts: List[str]) -> List[List[str]]:
        """Split *texts* into request batches honoring the API limits.

        Fix over the previous implementation: a batch was flushed only
        *after* a limit had already been exceeded, so a single request
        could carry up to 91 parts (limit: 90) or overshoot the character
        budget. Here a batch is flushed *before* adding a text that would
        break either limit. A single text longer than
        ``MAX_BATCH_SIZE_CHARS`` is still sent in a batch of its own, since
        it cannot be split.
        """
        batches: List[List[str]] = []
        batch: List[str] = []
        batch_chars = 0
        for text in texts:
            if batch and (
                len(batch) >= MAX_BATCH_SIZE_PARTS
                or batch_chars + len(text) > MAX_BATCH_SIZE_CHARS
            ):
                batches.append(batch)
                batch = []
                batch_chars = 0
            batch.append(text)
            batch_chars += len(text)
        if batch:
            batches.append(batch)
        return batches

    def _embed_kwargs(self) -> Dict[str, Any]:
        """Per-request keyword arguments shared by the sync and async paths."""
        return {} if self.model is None else {"model": self.model}

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed documents using a GigaChat embeddings models.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text (order preserved).
        """
        result: List[List[float]] = []
        embed_kwargs = self._embed_kwargs()
        for batch in self._make_batches(texts):
            response = self._client.embeddings(texts=batch, **embed_kwargs)
            result.extend(item.embedding for item in response.data)
        return result

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed documents using a GigaChat embeddings models (async).

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text (order preserved).
        """
        result: List[List[float]] = []
        embed_kwargs = self._embed_kwargs()
        for batch in self._make_batches(texts):
            response = await self._client.aembeddings(texts=batch, **embed_kwargs)
            result.extend(item.embedding for item in response.data)
        return result

    def embed_query(self, text: str) -> List[float]:
        """Embed a query using a GigaChat embeddings models.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        if self.use_prefix_query:
            text = self.prefix_query + text
        return self.embed_documents(texts=[text])[0]

    async def aembed_query(self, text: str) -> List[float]:
        """Embed a query using a GigaChat embeddings models (async).

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        if self.use_prefix_query:
            text = self.prefix_query + text
        docs = await self.aembed_documents(texts=[text])
        return docs[0]
class _BaseGigaChat(Serializable):
    """Shared configuration and client plumbing for GigaChat chat models.

    Holds connection/auth settings, lazily constructs the ``gigachat`` SDK
    client, and exposes thin sync/async wrappers around utility endpoints
    (token counting, model listing, file upload/retrieval).
    """

    base_url: Optional[str] = None
    """ Base API URL """
    auth_url: Optional[str] = None
    """ Auth URL """
    credentials: Optional[str] = None
    """ Auth Token """
    scope: Optional[str] = None
    """ Permission scope for access token """

    access_token: Optional[str] = None
    """ Access token for GigaChat """

    model: Optional[str] = None
    """Model name to use."""
    user: Optional[str] = None
    """ Username for authenticate """
    password: Optional[str] = None
    """ Password for authenticate """

    timeout: Optional[float] = None
    """ Timeout for request """
    verify_ssl_certs: Optional[bool] = None
    """ Check certificates for all requests """

    ssl_context: Optional[ssl.SSLContext] = None

    class Config:
        # ssl.SSLContext is not a pydantic-validatable type.
        arbitrary_types_allowed = True

    ca_bundle_file: Optional[str] = None
    cert_file: Optional[str] = None
    key_file: Optional[str] = None
    key_file_password: Optional[str] = None
    # Support for connection to GigaChat through SSL certificates

    profanity: bool = True
    """ DEPRECATED: Check for profanity """
    profanity_check: Optional[bool] = None
    """ Check for profanity """
    streaming: bool = False
    """ Whether to stream the results or not. """
    temperature: Optional[float] = None
    """ What sampling temperature to use. """
    max_tokens: Optional[int] = None
    """ Maximum number of tokens to generate """
    use_api_for_tokens: bool = False
    """ Use GigaChat API for tokens count """
    verbose: bool = False
    """ Verbose logging """
    flags: Optional[List[str]] = None
    """ Feature flags """
    top_p: Optional[float] = None
    """ top_p value to use for nucleus sampling. Must be between 0.0 and 1.0 """
    repetition_penalty: Optional[float] = None
    """ The penalty applied to repeated tokens """
    update_interval: Optional[float] = None
    """ Minimum interval in seconds that elapses between sending tokens """

    @property
    def _llm_type(self) -> str:
        return "giga-chat-model"

    @property
    def lc_secrets(self) -> Dict[str, str]:
        # Field -> environment-variable mapping used by LangChain
        # serialization to avoid persisting secrets.
        return {
            "credentials": "GIGACHAT_CREDENTIALS",
            "access_token": "GIGACHAT_ACCESS_TOKEN",
            "password": "GIGACHAT_PASSWORD",
            "key_file_password": "GIGACHAT_KEY_FILE_PASSWORD",
        }

    @classmethod
    def is_lc_serializable(cls) -> bool:
        return True

    @cached_property
    def _client(self) -> gigachat.GigaChat:
        """Returns GigaChat API client"""
        # Imported lazily; availability is verified in validate_environment.
        import gigachat

        return gigachat.GigaChat(
            base_url=self.base_url,
            auth_url=self.auth_url,
            credentials=self.credentials,
            scope=self.scope,
            access_token=self.access_token,
            model=self.model,
            profanity_check=self.profanity_check,
            user=self.user,
            password=self.password,
            timeout=self.timeout,
            ssl_context=self.ssl_context,
            verify_ssl_certs=self.verify_ssl_certs,
            ca_bundle_file=self.ca_bundle_file,
            cert_file=self.cert_file,
            key_file=self.key_file,
            key_file_password=self.key_file_password,
            verbose=self.verbose,
            flags=self.flags,
        )

    @pre_init
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate authenticate data in environment and python package is installed."""
        try:
            import gigachat  # noqa: F401
        except ImportError:
            raise ImportError(
                "Could not import gigachat python package. "
                "Please install it with `pip install gigachat`."
            )
        fields = set(get_fields(cls).keys())
        diff = set(values.keys()) - fields
        if diff:
            # Unknown kwargs are tolerated but surfaced, to catch typos.
            logger.warning(f"Extra fields {diff} in GigaChat class")
        # Back-compat: map the deprecated `profanity` flag onto
        # `profanity_check` unless the latter was set explicitly.
        if "profanity" in fields and values.get("profanity") is False:
            logger.warning(
                "'profanity' field is deprecated. Use 'profanity_check' instead."
            )
            if values.get("profanity_check") is None:
                values["profanity_check"] = values.get("profanity")
        return values

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {
            "temperature": self.temperature,
            "model": self.model,
            "profanity": self.profanity_check,
            "streaming": self.streaming,
            "max_tokens": self.max_tokens,
            "top_p": self.top_p,
            "repetition_penalty": self.repetition_penalty,
        }

    def tokens_count(
        self, input_: List[str], model: Optional[str] = None
    ) -> List[gm.TokensCount]:
        """Get tokens of string list"""
        return self._client.tokens_count(input_, model)

    async def atokens_count(
        self, input_: List[str], model: Optional[str] = None
    ) -> List[gm.TokensCount]:
        """Get tokens of strings list (async)"""
        return await self._client.atokens_count(input_, model)

    def get_models(self) -> gm.Models:
        """Get available models of Gigachat"""
        return self._client.get_models()

    async def aget_models(self) -> gm.Models:
        """Get available models of Gigachat (async)"""
        return await self._client.aget_models()

    def get_model(self, model: str) -> gm.Model:
        """Get info about model"""
        return self._client.get_model(model)

    async def aget_model(self, model: str) -> gm.Model:
        """Get info about model (async)"""
        return await self._client.aget_model(model)

    def get_num_tokens(self, text: str) -> int:
        """Count approximate number of tokens"""
        if self.use_api_for_tokens:
            return self.tokens_count([text])[0].tokens  # type: ignore
        else:
            # Heuristic fallback: fixed chars-per-token divisor; rough
            # approximation only, used when API counting is disabled.
            return round(len(text) / 4.6)

    def upload_file(
        self, file: FileTypes, purpose: Literal["general", "assistant"] = "general"
    ) -> gm.UploadedFile:
        return self._client.upload_file(file, purpose)

    async def aupload_file(
        self, file: FileTypes, purpose: Literal["general", "assistant"] = "general"
    ) -> gm.UploadedFile:
        return await self._client.aupload_file(file, purpose)

    def get_file(self, file_id: str) -> gm.Image:
        # NOTE(review): retrieval delegates to the *image* endpoint and
        # returns gm.Image — presumably files are fetched as images; confirm
        # against the gigachat SDK before relying on this for other files.
        return self._client.get_image(file_id)

    async def aget_file(self, file_id: str) -> gm.Image:
        # NOTE(review): see get_file — same image-endpoint delegation.
        return await self._client.aget_image(file_id)
__eq__(self, other: Any) -> bool: 14 | return isinstance(other, str) 15 | 16 | 17 | class BaseFakeCallbackHandler(BaseModel): 18 | """Base fake callback handler for testing.""" 19 | 20 | starts: int = 0 21 | ends: int = 0 22 | errors: int = 0 23 | text: int = 0 24 | ignore_llm_: bool = False 25 | ignore_chain_: bool = False 26 | ignore_agent_: bool = False 27 | ignore_retriever_: bool = False 28 | ignore_chat_model_: bool = False 29 | 30 | # to allow for similar callback handlers that are not technically equal 31 | fake_id: Union[str, None] = None 32 | 33 | # add finer-grained counters for easier debugging of failing tests 34 | chain_starts: int = 0 35 | chain_ends: int = 0 36 | llm_starts: int = 0 37 | llm_ends: int = 0 38 | llm_streams: int = 0 39 | tool_starts: int = 0 40 | tool_ends: int = 0 41 | agent_actions: int = 0 42 | agent_ends: int = 0 43 | chat_model_starts: int = 0 44 | retriever_starts: int = 0 45 | retriever_ends: int = 0 46 | retriever_errors: int = 0 47 | retries: int = 0 48 | 49 | 50 | class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler): 51 | """Base fake callback handler mixin for testing.""" 52 | 53 | def on_llm_start_common(self) -> None: 54 | self.llm_starts += 1 55 | self.starts += 1 56 | 57 | def on_llm_end_common(self) -> None: 58 | self.llm_ends += 1 59 | self.ends += 1 60 | 61 | def on_llm_error_common(self) -> None: 62 | self.errors += 1 63 | 64 | def on_llm_new_token_common(self) -> None: 65 | self.llm_streams += 1 66 | 67 | def on_retry_common(self) -> None: 68 | self.retries += 1 69 | 70 | def on_chain_start_common(self) -> None: 71 | self.chain_starts += 1 72 | self.starts += 1 73 | 74 | def on_chain_end_common(self) -> None: 75 | self.chain_ends += 1 76 | self.ends += 1 77 | 78 | def on_chain_error_common(self) -> None: 79 | self.errors += 1 80 | 81 | def on_tool_start_common(self) -> None: 82 | self.tool_starts += 1 83 | self.starts += 1 84 | 85 | def on_tool_end_common(self) -> None: 86 | self.tool_ends += 1 87 | 
self.ends += 1 88 | 89 | def on_tool_error_common(self) -> None: 90 | self.errors += 1 91 | 92 | def on_agent_action_common(self) -> None: 93 | self.agent_actions += 1 94 | self.starts += 1 95 | 96 | def on_agent_finish_common(self) -> None: 97 | self.agent_ends += 1 98 | self.ends += 1 99 | 100 | def on_chat_model_start_common(self) -> None: 101 | self.chat_model_starts += 1 102 | self.starts += 1 103 | 104 | def on_text_common(self) -> None: 105 | self.text += 1 106 | 107 | def on_retriever_start_common(self) -> None: 108 | self.starts += 1 109 | self.retriever_starts += 1 110 | 111 | def on_retriever_end_common(self) -> None: 112 | self.ends += 1 113 | self.retriever_ends += 1 114 | 115 | def on_retriever_error_common(self) -> None: 116 | self.errors += 1 117 | self.retriever_errors += 1 118 | 119 | 120 | class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin): 121 | """Fake callback handler for testing.""" 122 | 123 | @property 124 | def ignore_llm(self) -> bool: 125 | """Whether to ignore LLM callbacks.""" 126 | return self.ignore_llm_ 127 | 128 | @property 129 | def ignore_chain(self) -> bool: 130 | """Whether to ignore chain callbacks.""" 131 | return self.ignore_chain_ 132 | 133 | @property 134 | def ignore_agent(self) -> bool: 135 | """Whether to ignore agent callbacks.""" 136 | return self.ignore_agent_ 137 | 138 | @property 139 | def ignore_retriever(self) -> bool: 140 | """Whether to ignore retriever callbacks.""" 141 | return self.ignore_retriever_ 142 | 143 | def on_llm_start(self, *args: Any, **kwargs: Any) -> Any: 144 | self.on_llm_start_common() 145 | 146 | def on_llm_new_token(self, *args: Any, **kwargs: Any) -> Any: 147 | self.on_llm_new_token_common() 148 | 149 | def on_llm_end(self, *args: Any, **kwargs: Any) -> Any: 150 | self.on_llm_end_common() 151 | 152 | def on_llm_error(self, *args: Any, **kwargs: Any) -> Any: 153 | self.on_llm_error_common() 154 | 155 | def on_retry(self, *args: Any, **kwargs: Any) -> Any: 156 | 
self.on_retry_common() 157 | 158 | def on_chain_start(self, *args: Any, **kwargs: Any) -> Any: 159 | self.on_chain_start_common() 160 | 161 | def on_chain_end(self, *args: Any, **kwargs: Any) -> Any: 162 | self.on_chain_end_common() 163 | 164 | def on_chain_error(self, *args: Any, **kwargs: Any) -> Any: 165 | self.on_chain_error_common() 166 | 167 | def on_tool_start(self, *args: Any, **kwargs: Any) -> Any: 168 | self.on_tool_start_common() 169 | 170 | def on_tool_end(self, *args: Any, **kwargs: Any) -> Any: 171 | self.on_tool_end_common() 172 | 173 | def on_tool_error(self, *args: Any, **kwargs: Any) -> Any: 174 | self.on_tool_error_common() 175 | 176 | def on_agent_action(self, *args: Any, **kwargs: Any) -> Any: 177 | self.on_agent_action_common() 178 | 179 | def on_agent_finish(self, *args: Any, **kwargs: Any) -> Any: 180 | self.on_agent_finish_common() 181 | 182 | def on_text(self, *args: Any, **kwargs: Any) -> Any: 183 | self.on_text_common() 184 | 185 | def on_retriever_start(self, *args: Any, **kwargs: Any) -> Any: 186 | self.on_retriever_start_common() 187 | 188 | def on_retriever_end(self, *args: Any, **kwargs: Any) -> Any: 189 | self.on_retriever_end_common() 190 | 191 | def on_retriever_error(self, *args: Any, **kwargs: Any) -> Any: 192 | self.on_retriever_error_common() 193 | 194 | # def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler": 195 | # return self 196 | 197 | 198 | class FakeCallbackHandlerWithChatStart(FakeCallbackHandler): 199 | def on_chat_model_start( 200 | self, 201 | serialized: Dict[str, Any], 202 | messages: List[List[BaseMessage]], 203 | *, 204 | run_id: UUID, 205 | parent_run_id: Optional[UUID] = None, 206 | **kwargs: Any, 207 | ) -> Any: 208 | assert all(isinstance(m, BaseMessage) for m in chain(*messages)) 209 | self.on_chat_model_start_common() 210 | 211 | 212 | class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixin): 213 | """Fake async callback handler for testing.""" 214 | 215 | @property 216 | 
    def ignore_llm(self) -> bool:
        """Whether to ignore LLM callbacks."""
        return self.ignore_llm_

    @property
    def ignore_chain(self) -> bool:
        """Whether to ignore chain callbacks."""
        return self.ignore_chain_

    @property
    def ignore_agent(self) -> bool:
        """Whether to ignore agent callbacks."""
        return self.ignore_agent_

    # Each async hook below is a no-op apart from bumping the matching
    # `*_common` counter inherited from BaseFakeCallbackHandlerMixin, so
    # tests can assert how many times a given callback fired.
    async def on_retry(self, *args: Any, **kwargs: Any) -> Any:
        self.on_retry_common()

    async def on_llm_start(self, *args: Any, **kwargs: Any) -> None:
        self.on_llm_start_common()

    async def on_llm_new_token(self, *args: Any, **kwargs: Any) -> None:
        self.on_llm_new_token_common()

    async def on_llm_end(self, *args: Any, **kwargs: Any) -> None:
        self.on_llm_end_common()

    async def on_llm_error(self, *args: Any, **kwargs: Any) -> None:
        self.on_llm_error_common()

    async def on_chain_start(self, *args: Any, **kwargs: Any) -> None:
        self.on_chain_start_common()

    async def on_chain_end(self, *args: Any, **kwargs: Any) -> None:
        self.on_chain_end_common()

    async def on_chain_error(self, *args: Any, **kwargs: Any) -> None:
        self.on_chain_error_common()

    async def on_tool_start(self, *args: Any, **kwargs: Any) -> None:
        self.on_tool_start_common()

    async def on_tool_end(self, *args: Any, **kwargs: Any) -> None:
        self.on_tool_end_common()

    async def on_tool_error(self, *args: Any, **kwargs: Any) -> None:
        self.on_tool_error_common()

    async def on_agent_action(self, *args: Any, **kwargs: Any) -> None:
        self.on_agent_action_common()

    async def on_agent_finish(self, *args: Any, **kwargs: Any) -> None:
        self.on_agent_finish_common()

    async def on_text(self, *args: Any, **kwargs: Any) -> None:
        self.on_text_common()

    # def __deepcopy__(self, memo: dict) ->
"FakeAsyncCallbackHandler": 273 | # return self 274 | -------------------------------------------------------------------------------- /libs/gigachat/README-ru_RU.md: -------------------------------------------------------------------------------- 1 |
 2 | 3 | [![GitHub Release](https://img.shields.io/github/v/release/ai-forever/langchain-gigachat?style=flat-square)](https://github.com/ai-forever/langchain-gigachat/releases) 4 | [![GitHub Actions Workflow Status](https://img.shields.io/github/actions/workflow/status/ai-forever/langchain-gigachat/check_diffs.yml?style=flat-square)](https://github.com/ai-forever/langchain-gigachat/actions/workflows/check_diffs.yml) 5 | [![GitHub License](https://img.shields.io/github/license/ai-forever/langchain-gigachat?style=flat-square)](https://opensource.org/license/MIT) 6 | [![GitHub Downloads (all assets, all releases)](https://img.shields.io/pypi/dm/langchain-gigachat?style=flat-square)](https://pypistats.org/packages/langchain-gigachat) 7 | [![GitHub Repo stars](https://img.shields.io/github/stars/ai-forever/langchain-gigachat?style=flat-square)](https://star-history.com/#ai-forever/langchain-gigachat) 8 | [![GitHub Open Issues](https://img.shields.io/github/issues-raw/ai-forever/langchain-gigachat)](https://github.com/ai-forever/langchain-gigachat/issues) 9 | 10 | [English](README.md) | [Русский](README-ru_RU.md) 11 | 12 | </div>
13 | 14 | # langchain-gigachat 15 | 16 | Библиотека `langchain-gigachat` позволяет использовать нейросетевые модели GigaChat при разработке LLM-приложений с помощью фреймворков LangChain и LangGraph. 17 | 18 | Библиотека входит в набор решений [GigaChain](https://github.com/ai-forever/gigachain). 19 | 20 | ## Требования 21 | 22 | Для работы с библиотекой и обмена сообщениями с моделями GigaChat понадобятся: 23 | 24 | * Python версии 3.9 и выше; 25 | * [сертификат НУЦ Минцифры](https://developers.sber.ru/docs/ru/gigachat/certificates); 26 | * [ключ авторизации](https://developers.sber.ru/docs/ru/gigachat/quickstart/ind-using-api#poluchenie-avtorizatsionnyh-dannyh) GigaChat API. 27 | 28 | > [!NOTE] 29 | > Вы также можете использовать другие [способы авторизации](#способы-авторизации). 30 | 31 | ## Установка 32 | 33 | Для установки библиотеки используйте менеджер пакетов pip: 34 | 35 | ```sh 36 | pip install -U langchain-gigachat 37 | ``` 38 | 39 | ## Быстрый старт 40 | 41 | ### Запрос на генерацию 42 | 43 | Пример запроса на генерацию: 44 | 45 | ```py 46 | from langchain_gigachat.chat_models import GigaChat 47 | 48 | giga = GigaChat( 49 | # Для авторизации запросов используйте ключ, полученный в проекте GigaChat API 50 | credentials="ваш_ключ_авторизации", 51 | verify_ssl_certs=False, 52 | ) 53 | 54 | print(giga.invoke("Hello, world!")) 55 | ``` 56 | 57 | ### Создание эмбеддингов 58 | 59 | Пример создания векторного представления текста: 60 | 61 | ```py 62 | from langchain_gigachat.embeddings import GigaChatEmbeddings 63 | 64 | embeddings = GigaChatEmbeddings(credentials="ключ_авторизации", verify_ssl_certs=False) 65 | result = embeddings.embed_documents(texts=["Привет!"]) 66 | print(result) 67 | ``` 68 | 69 | ## Параметры объекта GigaChat 70 | 71 | В таблице описаны параметры, которые можно передать при инициализации объекта GigaChat: 72 | 73 | | Параметр | Обязательный | Описание | 74 | | ------------------ | ------------ | 
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | 75 | | `credentials` | да | Ключ авторизации для обмена сообщениями с GigaChat API.
Ключ авторизации содержит информацию о версии API, к которой выполняются запросы. Если вы используете версию API для ИП или юрлиц, укажите это явно в параметре `scope` | 76 | | `verify_ssl_certs` | нет | Отключение проверки ssl-сертификатов.

Для обращения к GigaChat API нужно [установить корневой сертификат НУЦ Минцифры](#установка-корневого-сертификата-нуц-минцифры).

Используйте параметр ответственно, так как отключение проверки сертификатов снижает безопасность обмена данными | 77 | | `scope` | нет | Версия API, к которой будет выполнен запрос. По умолчанию запросы передаются в версию для физических лиц. Возможные значения:
  • `GIGACHAT_API_PERS` — версия API для физических лиц;
  • `GIGACHAT_API_B2B` — версия API для ИП и юрлиц при работе по предоплате.
  • `GIGACHAT_API_CORP` — версия API для ИП и юрлиц при работе по постоплате.
| 78 | | `model` | нет | необязательный параметр, в котором можно явно задать [модель GigaChat](https://developers.sber.ru/docs/ru/gigachat/models). Вы можете посмотреть список доступных моделей с помощью метода `get_models()`, который выполняет запрос [`GET /models`](https://developers.sber.ru/docs/ru/gigachat/api/reference#get-models).

Стоимость запросов к разным моделям отличается. Подробную информацию о тарификации запросов к той или иной модели вы можете найти в [официальной документации](https://developers.sber.ru/docs/ru/gigachat/api/tariffs) | 79 | | `base_url` | нет | Адрес API. По умолчанию запросы отправляются по адресу `https://gigachat.devices.sberbank.ru/api/v1/`, но если вы хотите использовать [модели в раннем доступе](https://developers.sber.ru/docs/ru/gigachat/models/preview-models), укажите адрес `https://gigachat-preview.devices.sberbank.ru/api/v1` | 80 | 81 | > [!TIP] 82 | > Чтобы не указывать параметры при каждой инициализации, задайте их в [переменных окружения](#настройка-переменных-окружения). 83 | 84 | ## Способы авторизации 85 | 86 | Для авторизации запросов, кроме ключа, полученного в личном кабинете, вы можете использовать: 87 | 88 | * имя пользователя и пароль для доступа к сервису; 89 | * сертификаты TLS; 90 | * токен доступа (access token), полученный в обмен на ключ авторизации в запросе [`POST /api/v2/oauth`](https://developers.sber.ru/docs/ru/gigachat/api/reference/rest/post-token). 91 | 92 | Для этого передайте соответствующие параметры при инициализации. 
93 | 94 | Пример авторизации с помощью логина и пароля: 95 | 96 | ```py 97 | giga = GigaChat( 98 | base_url="https://gigachat.devices.sberbank.ru/api/v1", 99 | user="имя_пользователя", 100 | password="пароль", 101 | ) 102 | ``` 103 | 104 | Авторизация с помощью сертификатов по протоколу TLS (mTLS): 105 | 106 | ```py 107 | giga = GigaChat( 108 | base_url="https://gigachat.devices.sberbank.ru/api/v1", 109 | ca_bundle_file="certs/ca.pem", # chain_pem.txt 110 | cert_file="certs/tls.pem", # published_pem.txt 111 | key_file="certs/tls.key", 112 | key_file_password="123456", 113 | ssl_context=context # optional ssl.SSLContext instance 114 | ) 115 | ``` 116 | 117 | Авторизация с помощью токена доступа: 118 | 119 | ```py 120 | giga = GigaChat( 121 | access_token="ваш_токен_доступа", 122 | ) 123 | ``` 124 | 125 | > [!NOTE] 126 | > Токен действителен в течение 30 минут. 127 | > При использовании такого способа авторизации, в приложении нужно реализовать механизм обновления токена. 128 | 129 | ### Предварительная авторизация 130 | 131 | По умолчанию, библиотека GigaChat получает токен доступа при первом запросе к API. 132 | 133 | Если вам нужно получить токен и авторизоваться до выполнения запроса, инициализируйте объект GigaChat и вызовите метод `get_token()`. 134 | 135 | ```py 136 | giga = GigaChat( 137 | base_url="https://gigachat.devices.sberbank.ru/api/v1", 138 | user="имя_пользователя", 139 | password="пароль", 140 | ) 141 | giga.get_token() 142 | ``` 143 | 144 | ## Настройка переменных окружения 145 | 146 | Чтобы задать параметры с помощью переменных окружения, в названии переменной используйте префикс `GIGACHAT_`. 147 | 148 | Пример переменных окружения, которые задают ключ авторизации, версию API и отключают проверку сертификатов. 149 | 150 | ```sh 151 | export GIGACHAT_CREDENTIALS=... 152 | export GIGACHAT_SCOPE=... 153 | export GIGACHAT_VERIFY_SSL_CERTS=False 154 | ``` 155 | 156 | Пример переменных окружения, которые задают адрес API, имя пользователя и пароль. 
157 | 158 | ```sh 159 | export GIGACHAT_BASE_URL=https://gigachat.devices.sberbank.ru/api/v1 160 | export GIGACHAT_USER=... 161 | export GIGACHAT_PASSWORD=... 162 | ``` 163 | -------------------------------------------------------------------------------- /libs/gigachat/langchain_gigachat/tools/giga_tool.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import inspect 3 | import textwrap 4 | from typing import ( 5 | Annotated, 6 | Any, 7 | Awaitable, 8 | Callable, 9 | Dict, 10 | List, 11 | Literal, 12 | Optional, 13 | Type, 14 | Union, 15 | get_type_hints, 16 | ) 17 | 18 | from langchain_core.callbacks import Callbacks 19 | from langchain_core.runnables import Runnable, RunnableConfig 20 | from langchain_core.tools import ( 21 | FILTERED_ARGS, 22 | BaseTool, 23 | StructuredTool, 24 | Tool, 25 | create_schema_from_function, 26 | ) 27 | from langchain_core.utils.pydantic import TypeBaseModel 28 | from pydantic import BaseModel 29 | from pydantic.functional_validators import SkipValidation 30 | 31 | from langchain_gigachat.utils.function_calling import create_return_schema_from_function 32 | 33 | FewShotExamples = Optional[List[Dict[str, Any]]] 34 | 35 | 36 | class GigaBaseTool(BaseTool): # type: ignore[override] 37 | """Interface of GigaChat tools with additional properties, that GigaChat supports""" 38 | 39 | return_schema: Annotated[Optional[TypeBaseModel], SkipValidation()] = None 40 | """Return schema of JSON that function returns""" 41 | few_shot_examples: FewShotExamples = None 42 | """Few-shot examples to help the model understand how to use the tool.""" 43 | 44 | 45 | class GigaTool(GigaBaseTool, Tool): # type: ignore[override] 46 | pass 47 | 48 | 49 | def _get_type_hints(func: Callable) -> Optional[dict[str, type]]: 50 | if isinstance(func, functools.partial): 51 | func = func.func 52 | try: 53 | return get_type_hints(func) 54 | except Exception: 55 | return None 56 | 57 | 58 | def 
def _get_runnable_config_param(func: Callable) -> Optional[str]:
    """Return the name of the parameter annotated as RunnableConfig, if any.

    Such a parameter is injected by LangChain at call time and must not
    appear in the tool's generated argument schema.
    """
    type_hints = _get_type_hints(func)
    if not type_hints:
        return None
    for name, type_ in type_hints.items():
        if type_ is RunnableConfig:
            return name
    return None


def _filter_schema_args(func: Callable) -> list[str]:
    """Arguments to exclude from the inferred args schema.

    The standard LangChain filtered args plus the RunnableConfig
    parameter (if the function declares one).
    """
    filter_args = list(FILTERED_ARGS)
    if config_param := _get_runnable_config_param(func):
        filter_args.append(config_param)
    return filter_args


class GigaStructuredTool(GigaBaseTool, StructuredTool):  # type: ignore[override]
    """StructuredTool variant carrying the GigaChat-specific ``return_schema``
    and ``few_shot_examples`` fields declared on GigaBaseTool."""

    @classmethod
    def from_function(
        cls,
        func: Optional[Callable] = None,
        coroutine: Optional[Callable[..., Awaitable[Any]]] = None,
        name: Optional[str] = None,
        description: Optional[str] = None,
        return_direct: bool = False,
        args_schema: Union[type[BaseModel], dict[str, Any], None] = None,
        infer_schema: bool = True,
        return_schema: Optional[Type[BaseModel]] = None,
        few_shot_examples: FewShotExamples = None,
        *,
        response_format: Literal["content", "content_and_artifact"] = "content",
        parse_docstring: bool = False,
        error_on_invalid_docstring: bool = False,
        **kwargs: Any,
    ) -> StructuredTool:
        """Create tool from a given function.

        A classmethod that helps to create a tool from a function.

        Args:
            func: The function from which to create a tool.
            coroutine: The async function from which to create a tool.
            name: The name of the tool. Defaults to the function name.
            description: The description of the tool.
                Defaults to the function docstring.
            return_direct: Whether to return the result directly or as a callback.
                Defaults to False.
            args_schema: The schema of the tool's input arguments. Defaults to None.
            infer_schema: Whether to infer the schema from the function's signature.
                Defaults to True.
            return_schema: The return schema of tool output. Defaults to None
            few_shot_examples: Few shot examples of tool usage
            response_format: The tool response format. If "content" then the output of
                the tool is interpreted as the contents of a ToolMessage. If
                "content_and_artifact" then the output is expected to be a two-tuple
                corresponding to the (content, artifact) of a ToolMessage.
                Defaults to "content".
            parse_docstring: if ``infer_schema`` and ``parse_docstring``, will attempt
                to parse parameter descriptions from Google Style function docstrings.
                Defaults to False.
            error_on_invalid_docstring: if ``parse_docstring`` is provided, configure
                whether to raise ValueError on invalid Google Style docstrings.
                Defaults to False.
            kwargs: Additional arguments to pass to the tool

        Raises:
            ValueError: if neither ``func`` nor ``coroutine`` is given, or no
                description can be derived from arguments or docstrings.
        """

        if func is not None:
            source_function = func
        elif coroutine is not None:
            source_function = coroutine
        else:
            msg = "Function and/or coroutine must be provided"
            raise ValueError(msg)
        name = name or source_function.__name__
        if args_schema is None and infer_schema:
            # schema name is appended within function
            args_schema = create_schema_from_function(
                name,
                source_function,
                parse_docstring=parse_docstring,
                error_on_invalid_docstring=error_on_invalid_docstring,
                filter_args=_filter_schema_args(source_function),
            )
        if return_schema is None and infer_schema:
            # schema name is appended within function
            return_schema = create_return_schema_from_function(source_function)
        # Resolve the description: explicit argument > function docstring >
        # args_schema docstring; fail loudly if none is available.
        description_ = description
        if description is None and not parse_docstring:
            description_ = source_function.__doc__ or None
        if description_ is None and args_schema:
            description_ = args_schema.__doc__ or None
        if description_ is None:
            msg = "Function must have a docstring if description not provided."
            raise ValueError(msg)
        if description is None:
            # Only apply if using the function's docstring
            description_ = textwrap.dedent(description_).strip()

        # Description example:
        # search_api(query: str) - Searches the API for the query.
        description_ = description_.strip()  # was a pointless f-string wrapper
        return cls(
            name=name,
            func=func,
            coroutine=coroutine,
            args_schema=args_schema,  # type: ignore[arg-type]
            description=description_,
            return_direct=return_direct,
            response_format=response_format,
            return_schema=return_schema,
            few_shot_examples=few_shot_examples,
            **kwargs,
        )


def giga_tool(
    *args: Union[str, Callable, Runnable],
    return_direct: bool = False,
    args_schema: Union[type[BaseModel], dict[str, Any], None] = None,
    infer_schema: bool = True,
    response_format: Literal["content", "content_and_artifact"] = "content",
    parse_docstring: bool = False,
    error_on_invalid_docstring: bool = True,
    return_schema: Optional[type] = None,
    few_shot_examples: FewShotExamples = None,
) -> Callable:
    """Make tools out of functions, can be used with or without arguments.

    Args:
        *args: The arguments to the tool.
        return_direct: Whether to return directly from the tool rather
            than continuing the agent loop. Defaults to False.
        args_schema: optional argument schema for user to specify.
            Defaults to None.
        infer_schema: Whether to infer the schema of the arguments from
            the function's signature. This also makes the resultant tool
            accept a dictionary input to its `run()` function.
            Defaults to True.
        response_format: The tool response format. If "content" then the output of
            the tool is interpreted as the contents of a ToolMessage. If
            "content_and_artifact" then the output is expected to be a two-tuple
            corresponding to the (content, artifact) of a ToolMessage.
            Defaults to "content".
        parse_docstring: if ``infer_schema`` and ``parse_docstring``, will attempt to
            parse parameter descriptions from Google Style function docstrings.
            Defaults to False.
        error_on_invalid_docstring: if ``parse_docstring`` is provided, configure
            whether to raise ValueError on invalid Google Style docstrings.
            Defaults to True.
        return_schema: The return schema of tool output. Defaults to None
        few_shot_examples: Few shot examples of tool usage
    """

    def _make_with_name(tool_name: str) -> Callable:
        def _make_tool(dec_func: Union[Callable, Runnable]) -> BaseTool:
            if isinstance(dec_func, Runnable):
                runnable = dec_func

                if runnable.input_schema.model_json_schema().get("type") != "object":
                    msg = "Runnable must have an object schema."
                    raise ValueError(msg)

                async def ainvoke_wrapper(
                    callbacks: Optional[Callbacks] = None, **kwargs: Any
                ) -> Any:
                    return await runnable.ainvoke(kwargs, {"callbacks": callbacks})

                def invoke_wrapper(
                    callbacks: Optional[Callbacks] = None, **kwargs: Any
                ) -> Any:
                    return runnable.invoke(kwargs, {"callbacks": callbacks})

                coroutine = ainvoke_wrapper
                func = invoke_wrapper
                # (stray trailing comma removed from the annotation below)
                schema: Optional[Union[type[BaseModel], dict[str, Any]]] = (
                    runnable.input_schema
                )
                description = repr(runnable)
            elif inspect.iscoroutinefunction(dec_func):
                coroutine = dec_func
                func = None
                schema = args_schema
                description = None
            else:
                coroutine = None
                func = dec_func
                schema = args_schema
                description = None

            if infer_schema or args_schema is not None:
                return GigaStructuredTool.from_function(
                    func,
                    coroutine,
                    name=tool_name,
                    description=description,
                    return_direct=return_direct,
                    args_schema=schema,
                    infer_schema=infer_schema,
                    response_format=response_format,
                    parse_docstring=parse_docstring,
                    error_on_invalid_docstring=error_on_invalid_docstring,
                    return_schema=return_schema,
                    few_shot_examples=few_shot_examples,
                )
            # If someone doesn't want a schema applied, we must treat it as
            # a simple string->string function
            if dec_func.__doc__ is None:
                msg = (
                    "Function must have a docstring if "
                    "description not provided and infer_schema is False."
                )
                raise ValueError(msg)
            return GigaTool(
                name=tool_name,
                func=func,
                description=f"{tool_name} tool",
                return_direct=return_direct,
                coroutine=coroutine,
                response_format=response_format,
                return_schema=return_schema,
                few_shot_examples=few_shot_examples,
            )

        return _make_tool

    if len(args) == 2 and isinstance(args[0], str) and isinstance(args[1], Runnable):
        return _make_with_name(args[0])(args[1])
    elif len(args) == 1 and isinstance(args[0], str):
        # if the argument is a string, then we use the string as the tool name
        # Example usage: @tool("search", return_direct=True)
        return _make_with_name(args[0])
    elif len(args) == 1 and callable(args[0]):
        # if the argument is a function, then we use the function name as the tool name
        # Example usage: @tool
        return _make_with_name(args[0].__name__)(args[0])
    elif len(args) == 0:
        # if there are no arguments, then we use the function name as the tool name
        # Example usage: @tool(return_direct=True)
        def _partial(func: Callable[[str], str]) -> BaseTool:
            return _make_with_name(func.__name__)(func)

        return _partial
    else:
        msg = "Too many arguments for tool decorator"
        raise ValueError(msg)
# ====== /libs/gigachat/langchain_gigachat/utils/function_calling.py ======
# NOTE(review): the module-level third-party imports (langchain_core,
# pydantic, typing_extensions) and the GigaFunctionDescription TypedDict
# belong at the top of this module alongside these helpers.
import typing
from typing import Any, Union

SCHEMA_DO_NOT_SUPPORT_MESSAGE = """Incorrect function schema!
{schema}
GigaChat currently do not support these typings:
Union[X, Y, ...]"""


class IncorrectSchemaException(Exception):
    """Raised for JSON schemas using constructs GigaChat cannot accept."""


def gigachat_fix_schema(schema: Any, prev_key: str = "") -> Any:
    """
    GigaChat do not support allOf/anyOf in JSON schema.
    We need to fix this in case of allOf with one object or
    in case with optional parameter.
    In other cases throw exception that we do not support this types of schemas
    """
    if isinstance(schema, dict):
        obj_out: Any = {}
        for k, v in schema.items():
            if k == "title":
                # Drop bare schema titles; keep only a *property* that is
                # literally named "title" (a dict under "properties" with
                # its own "title" key).
                if isinstance(v, dict) and prev_key == "properties" and "title" in v:
                    obj_out[k] = gigachat_fix_schema(v, k)
                continue
            if k == "allOf":
                if len(v) > 1:
                    raise IncorrectSchemaException()
                obj = gigachat_fix_schema(v[0], k)
                outer_description = schema.get("description")
                obj_out = {**obj_out, **obj}
                if outer_description:
                    # The outer description takes priority over the one
                    # coming from the inlined $ref.
                    obj_out["description"] = outer_description
                # BUGFIX: without this continue, control fell through to the
                # generic isinstance() branch below and re-added the very
                # "allOf" key this function is supposed to remove.
                continue
            if k == "anyOf":
                if len(v) > 1:
                    raise IncorrectSchemaException()
                # Single-element anyOf (optional-parameter marker) is dropped.
                continue
            if isinstance(v, (list, dict)):
                obj_out[k] = gigachat_fix_schema(v, k)
            else:
                obj_out[k] = v
        return obj_out
    elif isinstance(schema, list):
        return [gigachat_fix_schema(el) for el in schema]
    else:
        return schema


_MAX_TYPED_DICT_RECURSION = 25


def _is_optional(field: type) -> bool:
    # True for Optional[X] / Union[..., None] annotations.
    return typing.get_origin(field) is Union and type(None) in typing.get_args(field)
        # --- continuation of _convert_any_typed_dicts_to_pydantic: build a
        # pydantic.v1 model from a TypedDict's annotations ---
        typed_dict = type_
        docstring = inspect.getdoc(typed_dict)
        annotations_ = typed_dict.__annotations__
        description, arg_descriptions = _parse_google_docstring(
            docstring, list(annotations_)
        )
        fields: dict = {}
        for arg, arg_type in annotations_.items():
            if get_origin(arg_type) is Annotated:
                annotated_args = get_args(arg_type)
                new_arg_type = _convert_any_typed_dicts_to_pydantic(
                    annotated_args[0], depth=depth + 1, visited=visited
                )
                # Annotated[type, default, "description"] -> pydantic Field kwargs.
                field_kwargs = dict(zip(("default", "description"), annotated_args[1:]))
                if (field_desc := field_kwargs.get("description")) and not isinstance(
                    field_desc, str
                ):
                    msg = (
                        f"Invalid annotation for field {arg}. Third argument to "
                        f"Annotated must be a string description, received value of "
                        f"type {type(field_desc)}."
                    )
                    raise ValueError(msg)
                elif arg_desc := arg_descriptions.get(arg):
                    # Fall back to the docstring's Args: section.
                    field_kwargs["description"] = arg_desc
                else:
                    pass
                fields[arg] = (new_arg_type, Field_v1(**field_kwargs))
            else:
                new_arg_type = _convert_any_typed_dicts_to_pydantic(
                    arg_type, depth=depth + 1, visited=visited
                )
                # Optional fields default to None, required ones to Ellipsis.
                if _is_optional(new_arg_type):
                    field_kwargs = {"default": None}
                else:
                    field_kwargs = {"default": ...}
                if arg_desc := arg_descriptions.get(arg):
                    field_kwargs["description"] = arg_desc
                fields[arg] = (new_arg_type, Field_v1(**field_kwargs))
        model = create_model_v1(typed_dict.__name__, **fields)
        model.__doc__ = description
        visited[typed_dict] = model
        return model
    elif (origin := get_origin(type_)) and (type_args := get_args(type_)):
        # Parameterized generic: convert each type argument recursively and
        # re-subscribe it onto a subscriptable origin.
        subscriptable_origin = _py_38_safe_origin(origin)
        type_args = tuple(
            _convert_any_typed_dicts_to_pydantic(arg, depth=depth + 1, visited=visited)
            for arg in type_args  # type: ignore[index]
        )
        return subscriptable_origin[type_args]  # type: ignore[index]
    else:
        return type_


def _py_38_safe_origin(origin: type) -> type:
    # Map collections.abc (and types.UnionType, when available) origins to
    # subscriptable typing equivalents so they can be re-parameterized above.
    origin_union_type_map: dict[type, Any] = (
        {types.UnionType: Union} if hasattr(types, "UnionType") else {}
    )

    origin_map: dict[type, Any] = {
        dict: dict,
        list: list,
        tuple: tuple,
        set: set,
        collections.abc.Iterable: typing.Iterable,
        collections.abc.Mapping: typing.Mapping,
        collections.abc.Sequence: typing.Sequence,
        collections.abc.MutableMapping: typing.MutableMapping,
        **origin_union_type_map,
    }
    return cast(type, origin_map.get(origin, origin))


def _parse_google_docstring(
    docstring: Optional[str],
    args: list[str],
    *,
    error_on_invalid_docstring: bool = False,
) -> tuple[str, dict]:
    """Parse the function and argument descriptions from the docstring of a function.

    Assumes the function docstring follows Google Python style guide.
    """
    if docstring:
        docstring_blocks = docstring.split("\n\n")
        if error_on_invalid_docstring:
            filtered_annotations = {
                arg for arg in args if arg not in ("run_manager", "callbacks", "return")
            }
            if filtered_annotations and (
                len(docstring_blocks) < 2 or not docstring_blocks[1].startswith("Args:")
            ):
                msg = "Found invalid Google-Style docstring."
                raise ValueError(msg)
        descriptors = []
        args_block = None
        past_descriptors = False
        for block in docstring_blocks:
            if block.startswith("Args:"):
                args_block = block
                break
            elif block.startswith(("Returns:", "Example:")):
                # Don't break in case Args come after
                past_descriptors = True
            elif not past_descriptors:
                descriptors.append(block)
            else:
                continue
        description = " ".join(descriptors)
    else:
        if error_on_invalid_docstring:
            msg = "Found invalid Google-Style docstring."
            raise ValueError(msg)
        description = ""
        args_block = None
    arg_descriptions = {}
    if args_block:
        arg = None
        for line in args_block.split("\n")[1:]:
            if ":" in line:
                arg, desc = line.split(":", maxsplit=1)
                arg_descriptions[arg.strip()] = desc.strip()
            elif arg:
                # Continuation line of the previous argument's description.
                arg_descriptions[arg.strip()] += " " + line.strip()
    return description, arg_descriptions


def _get_python_function_name(function: Callable) -> str:
    """Get the name of a Python function."""
    return function.__name__


def _model_to_schema(model: Union[type[BaseModel], dict[str, Any]]) -> dict:
    # JSON schema for either pydantic major version; v2 goes through the
    # GigaChat-aware schema generator.
    if hasattr(model, "model_json_schema"):
        # Pydantic 2
        from langchain_gigachat.utils.pydantic_generator import GigaChatJsonSchema

        return model.model_json_schema(schema_generator=GigaChatJsonSchema)
    elif hasattr(model, "schema"):
        return model.schema()  # Pydantic 1
    else:
        msg = "Model must be a Pydantic model."
259 | raise TypeError(msg) 260 | 261 | 262 | def _convert_return_schema( 263 | return_model: Optional[Union[Type[BaseModel], dict[str, Any]]], 264 | ) -> Dict[str, Any]: 265 | if not return_model: 266 | return {} 267 | 268 | if isinstance(return_model, dict): 269 | return_schema = return_model 270 | else: 271 | return_schema = dereference_refs(_model_to_schema(return_model)) 272 | 273 | if "definitions" in return_schema: # pydantic 1 274 | return_schema.pop("definitions", None) 275 | if "$defs" in return_schema: # pydantic 2 276 | return_schema.pop("$defs", None) 277 | if "title" in return_schema: 278 | return_schema.pop("title", None) 279 | 280 | for key in return_schema["properties"]: 281 | if "type" not in return_schema["properties"][key]: 282 | return_schema["properties"][key]["type"] = "object" 283 | if "description" not in return_schema["properties"][key]: 284 | return_schema["properties"][key]["description"] = "" 285 | 286 | return return_schema 287 | 288 | 289 | def format_tool_to_gigachat_function(tool: BaseTool) -> GigaFunctionDescription: 290 | """Format tool into the GigaChat function API.""" 291 | if not tool.description or tool.description == "": 292 | raise RuntimeError( 293 | "Incorrect function or tool description. Description is required." 
294 | ) 295 | tool_schema = tool.args_schema 296 | if tool.tool_call_schema: 297 | tool_schema = tool.tool_call_schema 298 | 299 | if hasattr(tool, "return_schema") and tool.return_schema: 300 | # return_schema = _convert_return_schema(tool.return_schema) 301 | return_schema = tool.return_schema 302 | else: 303 | return_schema = None 304 | 305 | if hasattr(tool, "few_shot_examples") and tool.few_shot_examples: 306 | few_shot_examples = tool.few_shot_examples 307 | else: 308 | few_shot_examples = None 309 | 310 | is_simple_tool = isinstance(tool, Tool) and not tool.args_schema 311 | 312 | if tool_schema and not is_simple_tool: 313 | if isinstance(tool_schema, dict) and "properties" in tool_schema: 314 | tool_schema = dereference_refs(tool_schema) 315 | if "definitions" in tool_schema: # pydantic 1 316 | tool_schema.pop("definitions", None) 317 | if "$defs" in tool_schema: # pydantic 2 318 | tool_schema.pop("$defs", None) 319 | default_description = tool_schema.pop("description") 320 | return GigaFunctionDescription( 321 | name=tool.name, 322 | description=tool.description or default_description, 323 | parameters=tool_schema, 324 | few_shot_examples=few_shot_examples, 325 | return_parameters=return_schema, 326 | ) 327 | return convert_pydantic_to_gigachat_function( 328 | tool_schema, 329 | name=tool.name, 330 | description=tool.description, 331 | return_model=return_schema, 332 | few_shot_examples=few_shot_examples, 333 | ) 334 | else: 335 | if hasattr(tool, "return_schema") and tool.return_schema: 336 | return_schema = _convert_return_schema(tool.return_schema) 337 | else: 338 | return_schema = None 339 | 340 | return GigaFunctionDescription( 341 | name=tool.name, 342 | description=tool.description, 343 | parameters={"properties": {}, "type": "object"}, 344 | few_shot_examples=few_shot_examples, 345 | return_parameters=return_schema, 346 | ) 347 | 348 | 349 | def convert_pydantic_to_gigachat_function( 350 | model: Union[type[BaseModel], dict[str, Any]], 351 | *, 352 
def convert_pydantic_to_gigachat_function(
    model: Union[type[BaseModel], dict[str, Any]],
    *,
    name: Optional[str] = None,
    description: Optional[str] = None,
    return_model: Optional[Type[BaseModel]] = None,
    few_shot_examples: Optional[List[dict]] = None,
) -> GigaFunctionDescription:
    """Convert a Pydantic model (or JSON schema dict) to a GigaChat function description.

    Args:
        model: Pydantic model class or a pre-built JSON schema dict.
        name: Override for the function name; falls back to the schema title.
        description: Override for the description; falls back to the schema's
            own top-level description.
        return_model: Optional Pydantic model describing the return value.
        few_shot_examples: Optional call examples; when omitted, they are taken
            from a ``few_shot_examples`` function attribute on ``model`` if
            one is present.

    Returns:
        The GigaChat function description.

    Raises:
        ValueError: If neither ``description`` nor the schema provides a
            non-empty description.
    """
    schema = dereference_refs(_model_to_schema(model))
    # $refs are inlined above, so the definition tables are redundant.
    schema.pop("definitions", None)  # pydantic 1
    schema.pop("$defs", None)  # pydantic 2
    title = schema.pop("title", None)
    # GigaChat requires every property to carry a type and a description.
    for prop in schema.get("properties", {}).values():
        prop.setdefault("type", "object")
        prop.setdefault("description", "")

    return_schema = _convert_return_schema(return_model) if return_model else None

    default_description = schema.pop("description", "")
    if not description and not default_description:
        raise ValueError(
            "Incorrect function or tool description. Description is required."
        )

    # Allow the model itself to supply examples via a plain-function attribute
    # (e.g. a @staticmethod accessed through the class).
    if few_shot_examples is None:
        examples_attr = getattr(model, "few_shot_examples", None)
        if inspect.isfunction(examples_attr):
            few_shot_examples = examples_attr()

    return GigaFunctionDescription(
        name=name or title,
        description=description or default_description,
        parameters=schema,
        return_parameters=return_schema,
        few_shot_examples=few_shot_examples,
    )


def _get_type_hints(func: Callable) -> Optional[Dict[str, Type]]:
    """Best-effort type hints for ``func``; unwraps ``functools.partial``.

    Returns None (instead of raising) when hints cannot be resolved, e.g. for
    unresolvable forward references.
    """
    if isinstance(func, functools.partial):
        func = func.func
    try:
        return get_type_hints(func)
    except Exception:
        return None


def create_return_schema_from_function(func: Callable) -> Optional[Type[BaseModel]]:
    """Return the function's return annotation if it is a Pydantic model.

    Primitive return types (str/int/float), missing models and unresolvable
    annotations all yield None.
    """
    # Use the partial-safe helper for consistency; it returns None instead of
    # raising when the annotations cannot be resolved.
    hints = _get_type_hints(func)
    if hints is None:
        return None
    return_type = hints.get("return", Any)
    if (
        return_type is not str
        and return_type is not int
        and return_type is not float
        and return_type is not None
    ):
        try:
            if isinstance(return_type, type) and is_basemodel_subclass(return_type):
                return return_type
        except TypeError:  # It's normal for testing
            return None

    return None
def convert_python_function_to_gigachat_function(
    function: Callable,
) -> GigaFunctionDescription:
    """Convert a Python function to an GigaChat function-calling API compatible dict.

    Assumes the Python function has type hints and a docstring with a description. If
    the docstring has Google Python style argument descriptions, these will be
    included as well.

    Args:
        function: The Python function to convert.

    Returns:
        The GigaChat function description.
    """
    from langchain_core import tools

    func_name = _get_python_function_name(function)
    # Build a Pydantic schema from the signature plus (optionally Google-style)
    # docstring argument descriptions.
    schema_model = tools.create_schema_from_function(
        func_name,
        function,
        filter_args=(),
        parse_docstring=True,
        error_on_invalid_docstring=False,
        include_injected=False,
    )
    return convert_pydantic_to_gigachat_function(
        schema_model,
        name=func_name,
        description=schema_model.__doc__,
        return_model=create_return_schema_from_function(function),
    )


def convert_to_gigachat_function(
    function: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool, type],
) -> Dict[str, Any]:
    """Convert a raw function/class to an GigaChat function.

    Args:
        function:
            A dictionary, Pydantic BaseModel class, TypedDict class, a LangChain
            Tool object, or a Python function. If a dictionary is passed in, it is
            assumed to already be a valid GigaChat function.

    Returns:
        A dict version of the passed in function which is compatible with the
        GigaChat function-calling API.
    """
    from langchain_core.tools import BaseTool

    # Dicts are trusted as-is and bypass the schema fix-up below.
    if isinstance(function, dict):
        return function

    if isinstance(function, type) and is_basemodel_subclass(function):
        converted = cast(Dict, convert_pydantic_to_gigachat_function(function))
    elif isinstance(function, BaseTool):
        converted = cast(Dict, format_tool_to_gigachat_function(function))
    elif is_typeddict(function):
        converted = cast(
            dict, _convert_typed_dict_to_gigachat_function(cast(type, function))
        )
    elif callable(function):
        converted = cast(Dict, convert_python_function_to_gigachat_function(function))
    else:
        raise ValueError(
            f"Unsupported function type {type(function)}. Functions must be passed in"
            f" as Dict, pydantic.BaseModel, or Callable."
        )

    try:
        return gigachat_fix_schema(converted)
    except IncorrectSchemaException:
        raise IncorrectSchemaException(
            SCHEMA_DO_NOT_SUPPORT_MESSAGE.format(schema=converted)
        )


def convert_to_gigachat_tool(
    tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool],
) -> Dict[str, Any]:
    """Convert a raw function/class to an GigaChat tool.

    Args:
        tool: Either a dictionary, a pydantic.BaseModel class, Python function, or
            BaseTool. If a dictionary is passed in, it is assumed to already be a
            valid GigaChat tool, GigaChat function, or a JSON schema with
            top-level 'title' and 'description' keys specified.

    Returns:
        A dict version of the passed in tool which is compatible with the
        GigaChat tool-calling API.
    """
    already_a_tool = (
        isinstance(tool, dict) and tool.get("type") == "function" and "function" in tool
    )
    if already_a_tool:
        return tool
    return {"type": "function", "function": convert_to_gigachat_function(tool)}
@pytest.fixture
def chat_completion() -> ChatCompletion:
    """A canned non-streaming completion with a single assistant reply."""
    reply = Messages(id=None, role=MessagesRole.ASSISTANT, content="Bar Baz")
    return ChatCompletion(
        choices=[Choices(message=reply, index=0, finish_reason="stop")],
        created=1678878333,
        model="GigaChat:v1.2.19.2",
        usage=Usage(
            prompt_tokens=18,
            completion_tokens=68,
            total_tokens=86,
            precached_prompt_tokens=0,
        ),
        object="chat.completion",
    )


@pytest.fixture
def chat_completion_stream() -> List[ChatCompletionChunk]:
    """Two canned stream chunks: body text followed by a "stop" chunk."""
    first = ChatCompletionChunk(
        choices=[ChoicesChunk(delta=MessagesChunk(content="Bar Baz"), index=0)],
        created=1695802242,
        model="GigaChat:v1.2.19.2",
        object="chat.completion",
    )
    last = ChatCompletionChunk(
        choices=[
            ChoicesChunk(
                delta=MessagesChunk(content=" Stream"),
                index=0,
                finish_reason="stop",
            )
        ],
        created=1695802242,
        model="GigaChat:v1.2.19.2",
        object="chat.completion",
    )
    return [first, last]


@pytest.fixture
def patch_gigachat(
    mocker: MockerFixture,
    chat_completion: ChatCompletion,
    chat_completion_stream: List[ChatCompletionChunk],
) -> None:
    """Replace the gigachat client so chat/stream serve the canned responses."""
    client = mocker.Mock()
    client.chat.return_value = chat_completion
    client.stream.return_value = chat_completion_stream
    mocker.patch("gigachat.GigaChat", return_value=client)


@pytest.fixture
def patch_gigachat_achat(
    mocker: MockerFixture, chat_completion: ChatCompletion
) -> None:
    """Mock the async chat call to resolve to the canned completion."""

    async def _resolve(value: Any) -> Any:
        return value

    client = mocker.Mock()
    client.achat.return_value = _resolve(chat_completion)
    mocker.patch("gigachat.GigaChat", return_value=client)


@pytest.fixture
def patch_gigachat_astream(
    mocker: MockerFixture, chat_completion_stream: List[ChatCompletionChunk]
) -> None:
    """Mock the async stream call to yield the canned chunks."""

    async def _agen(values: Iterable) -> AsyncGenerator:
        for item in values:
            yield item

    client = mocker.Mock()
    client.astream.return_value = _agen(chat_completion_stream)
    mocker.patch("gigachat.GigaChat", return_value=client)


@pytest.fixture
def patch_gigachat_upload_file(
    mocker: MockerFixture,
    chat_completion: ChatCompletion,
    chat_completion_stream: List[ChatCompletionChunk],
) -> MagicMock:
    """Patch chat/stream and expose the mocked upload_file for call assertions."""
    mocker.patch("gigachat.GigaChat.chat", return_value=chat_completion)
    mocker.patch("gigachat.GigaChat.stream", return_value=chat_completion_stream)
    uploaded = UploadedFile(
        id="0", object="file", bytes=0, created_at=0, filename="", purpose=""
    )
    return mocker.patch("gigachat.GigaChat.upload_file", return_value=uploaded)
@pytest.fixture
def patch_gigachat_aupload_file(
    mocker: MockerFixture,
    chat_completion: ChatCompletion,
    chat_completion_stream: List[ChatCompletionChunk],
) -> MagicMock:
    """Patch async chat/stream and expose the mocked aupload_file."""
    achat_mock = mocker.AsyncMock()
    achat_mock.return_value = chat_completion
    mocker.patch("gigachat.GigaChat.achat", side_effect=achat_mock)

    async def _agen(values: Iterable) -> AsyncGenerator:
        for item in values:
            yield item

    mocker.patch(
        "gigachat.GigaChat.astream",
        return_value=_agen(chat_completion_stream),
    )

    upload_mock = mocker.AsyncMock()
    upload_mock.return_value = UploadedFile(
        id="0", object="file", bytes=0, created_at=0, filename="", purpose=""
    )
    return mocker.patch("gigachat.GigaChat.aupload_file", side_effect=upload_mock)


# (messages, sha256 of first image URL, sha256 of second image URL)
UploadDialog = Tuple[List[HumanMessage], str, str]


@pytest.fixture
def upload_images_dialog() -> UploadDialog:
    """Three user messages referencing two distinct base64 images (first reused)."""
    image_1 = f"data:image/jpeg;base64,{base64.b64encode('123'.encode()).decode()}"
    image_2 = f"data:image/jpeg;base64,{base64.b64encode('124'.encode()).decode()}"
    hashed_1 = hashlib.sha256(image_1.encode()).hexdigest()
    hashed_2 = hashlib.sha256(image_2.encode()).hexdigest()

    def _msg(text: str, url: str) -> HumanMessage:
        return HumanMessage(
            content=[
                {"type": "text", "text": text},
                {"type": "image_url", "image_url": {"url": url}},
            ]
        )

    messages = [_msg("1", image_1), _msg("2", image_2), _msg("3", image_1)]
    return messages, hashed_1, hashed_2


def test__convert_dict_to_message_system() -> None:
    source = Messages(id=None, role=MessagesRole.SYSTEM, content="foo")
    assert _convert_dict_to_message(source) == SystemMessage(content="foo")


def test__convert_dict_to_message_human() -> None:
    source = Messages(id=None, role=MessagesRole.USER, content="foo")
    assert _convert_dict_to_message(source) == HumanMessage(content="foo")


def test__convert_dict_to_message_ai() -> None:
    source = Messages(id=None, role=MessagesRole.ASSISTANT, content="foo")
    assert _convert_dict_to_message(source) == AIMessage(content="foo")


def test__convert_message_to_dict_system() -> None:
    expected = Messages(id=None, role=MessagesRole.SYSTEM, content="foo")
    assert _convert_message_to_dict(SystemMessage(content="foo")) == expected


def test__convert_message_to_dict_human() -> None:
    expected = Messages(id=None, role=MessagesRole.USER, content="foo")
    assert _convert_message_to_dict(HumanMessage(content="foo")) == expected


def test__convert_message_to_dict_ai() -> None:
    expected = Messages(id=None, role=MessagesRole.ASSISTANT, content="foo")
    assert _convert_message_to_dict(AIMessage(content="foo")) == expected


@pytest.mark.parametrize("pairs", (("{}", "{}"), ("abc", '"abc"'), ("[]", "[]")))
def test__convert_message_to_dict_function(pairs: Any) -> None:
    """Checks if string, that was not JSON was converted to JSON"""
    raw, as_json = pairs
    message = FunctionMessage(content=raw, id="1", name="func")
    expected = Messages(id=None, role=MessagesRole.FUNCTION, content=as_json)
    assert _convert_message_to_dict(message) == expected


@pytest.mark.parametrize(
    "role",
    (
        MessagesRole.SYSTEM,
        MessagesRole.USER,
        MessagesRole.ASSISTANT,
        MessagesRole.FUNCTION,
    ),
)
def test__convert_message_to_dict_chat(role: MessagesRole) -> None:
    message = ChatMessage(role=role, content="foo")
    assert _convert_message_to_dict(message) == Messages(
        id=None, role=role, content="foo"
    )
def test_gigachat_predict(patch_gigachat: None) -> None:
    llm = GigaChat()
    assert llm.predict("bar") == "Bar Baz"


def test_gigachat_predict_stream(patch_gigachat: None) -> None:
    handler = FakeCallbackHandler()
    llm = GigaChat()
    assert llm.predict("bar", stream=True, callbacks=[handler]) == "Bar Baz Stream"
    # one callback invocation per canned chunk
    assert handler.llm_streams == 2


@pytest.mark.asyncio()
async def test_gigachat_apredict(patch_gigachat_achat: None) -> None:
    llm = GigaChat()
    assert await llm.apredict("bar") == "Bar Baz"


@pytest.mark.asyncio()
async def test_gigachat_apredict_stream(patch_gigachat_astream: None) -> None:
    handler = FakeAsyncCallbackHandler()
    llm = GigaChat()
    result = await llm.apredict("bar", stream=True, callbacks=[handler])
    assert result == "Bar Baz Stream"
    assert handler.llm_streams == 2


def test_gigachat_stream(patch_gigachat: None) -> None:
    expected = [
        AIMessageChunk(content="Bar Baz", response_metadata={"x_headers": {}}, id=""),
        AIMessageChunk(
            content=" Stream",
            response_metadata={
                "model_name": "GigaChat:v1.2.19.2",
                "finish_reason": "stop",
            },
            id="",
        ),
    ]
    chunks = list(GigaChat().stream("bar"))
    for chunk in chunks:
        chunk.id = ""  # ids are random; blank them before comparing
    assert chunks == expected


@pytest.mark.asyncio()
async def test_gigachat_astream(patch_gigachat_astream: None) -> None:
    expected = [
        AIMessageChunk(content="Bar Baz", response_metadata={"x_headers": {}}, id=""),
        AIMessageChunk(
            content=" Stream",
            response_metadata={
                "model_name": "GigaChat:v1.2.19.2",
                "finish_reason": "stop",
            },
            id="",
        ),
    ]
    chunks = [chunk async for chunk in GigaChat().astream("bar")]
    for chunk in chunks:
        chunk.id = ""  # ids are random; blank them before comparing
    assert chunks == expected


def test_gigachat_build_payload_existing_parameter() -> None:
    payload = GigaChat()._build_payload([], max_tokens=1)
    assert payload.max_tokens == 1


def test_gigachat_build_payload_non_existing_parameter() -> None:
    payload = GigaChat()._build_payload([], fake_parameter=1)
    assert getattr(payload, "fake_param", None) is None


async def test_gigachat_bind_without_description() -> None:
    class Person(BaseModel):
        name: str = Field(..., title="Name", description="The person's name")

    llm = GigaChat()
    # a model without a docstring/description must be rejected
    with pytest.raises(ValueError):
        llm.bind_functions(functions=[Person], function_call="Person")
    with pytest.raises(ValueError):
        llm.bind_tools(tools=[Person], tool_choice="Person")


async def test_gigachat_bind_with_description() -> None:
    class Person(BaseModel):
        """Simple description"""

        name: str = Field(..., title="Name")

    llm = GigaChat()
    llm.bind_functions(functions=[Person], function_call="Person")
    llm.bind_tools(tools=[Person], tool_choice="Person")


@tool
def _test_tool(
    arg: str, config: RunnableConfig, injected: Annotated[str, InjectedToolArg]
) -> None:
    """Some description"""
    return


def test_gigachat_bind_with_injected_vars() -> None:
    # config/injected arguments must not leak into the schema's required list
    llm = GigaChat().bind_tools(tools=[_test_tool])
    assert llm.kwargs["tools"][0]["function"]["parameters"]["required"] == ["arg"]  # type: ignore[attr-defined]


class SendSmsResult(BaseModel):
    status: str = Field(description="status")
    message: str = Field(description="message")
few_shot_examples = [
    {
        "request": "Sms 'hello' to 123",
        "params": {"recipient": "123", "message": "hello"},
    }
]


@giga_tool(few_shot_examples=few_shot_examples)
def _test_send_sms(
    arg: str, config: RunnableConfig, injected: Annotated[str, InjectedToolArg]
) -> SendSmsResult:
    """Sends SMS message"""
    return SendSmsResult(status="success", message="SMS sent")


def test_gigachat_bind_gigatool() -> None:
    llm = GigaChat().bind_tools(tools=[_test_send_sms])
    bound = llm.kwargs["tools"][0]["function"]  # type: ignore[attr-defined]
    assert bound["few_shot_examples"] == few_shot_examples
    # the return annotation must be reflected as a GigaChat return schema
    assert bound["return_parameters"] == {
        "properties": {
            "status": {"description": "status", "type": "string"},
            "message": {"description": "message", "type": "string"},
        },
        "required": ["status", "message"],
        "type": "object",
    }


class SomeResult(BaseModel):
    """My desc"""

    @staticmethod
    def few_shot_examples() -> FewShotExamples:
        return [
            {
                "request": "request example",
                "params": {"is_valid": 1, "description": "correct message"},
            }
        ]

    value: int = Field(description="some value")
    description: str = Field(description="some descriptin")


def test_structured_output() -> None:
    llm = GigaChat().with_structured_output(SomeResult)
    step_kwargs = llm.steps[0].kwargs  # type: ignore[attr-defined]
    assert step_kwargs["function_call"] == {"name": "SomeResult"}
    assert step_kwargs["tools"][0]["function"] == {
        "name": "SomeResult",
        "description": "My desc",
        "parameters": {
            "properties": {
                "value": {"description": "some value", "type": "integer"},
                "description": {"description": "some descriptin", "type": "string"},
            },
            "required": ["value", "description"],
            "type": "object",
        },
        "return_parameters": None,
        "few_shot_examples": [
            {
                "request": "request example",
                "params": {"is_valid": 1, "description": "correct message"},
            }
        ],
    }


def test_structured_output_json() -> None:
    llm = GigaChat().with_structured_output(SomeResult.model_json_schema())
    step_kwargs = llm.steps[0].kwargs  # type: ignore[attr-defined]
    assert step_kwargs["function_call"] == {"name": "SomeResult"}
    assert step_kwargs["tools"][0]["function"] is not None
def test_structured_output_format_instructions() -> None:
    llm = GigaChat().with_structured_output(SomeResult, method="format_instructions")
    assert (
        llm.steps[0].invoke(input="Hello")  # type: ignore[attr-defined]
        == 'Hello\n\nThe output should be formatted as a JSON instance that conforms to the JSON schema below.\n\nAs an example, for the schema {"properties": {"foo": {"title": "Foo", "description": "a list of strings", "type": "array", "items": {"type": "string"}}}, "required": ["foo"]}\nthe object {"foo": ["bar", "baz"]} is a well-formatted instance of the schema. The object {"properties": {"foo": ["bar", "baz"]}} is not well-formatted.\n\nHere is the output schema:\n```\n{"description": "My desc", "properties": {"value": {"description": "some value", "title": "Value", "type": "integer"}, "description": {"description": "some descriptin", "title": "Description", "type": "string"}}, "required": ["value", "description"]}\n```'  # noqa: E501
    )


def test_ai_message_json_serialization(patch_gigachat: None) -> None:
    # The AIMessage returned by invoke must survive JSON serialization.
    GigaChat().invoke("hello").model_dump_json()


def test_ai_upload_image(
    patch_gigachat_upload_file: MagicMock, upload_images_dialog: UploadDialog
) -> None:
    llm = GigaChat(auto_upload_images=True)
    dialog, hashed_1, hashed_2 = upload_images_dialog
    llm.invoke(dialog)
    # two distinct images -> exactly two uploads, cached by content hash
    assert len(llm._cached_images) == 2
    assert patch_gigachat_upload_file.call_count == 2
    assert patch_gigachat_upload_file.call_args_list[0][0][0][1] == b"123"
    assert patch_gigachat_upload_file.call_args_list[1][0][0][1] == b"124"
    assert hashed_1 in llm._cached_images
    assert hashed_2 in llm._cached_images


async def test_ai_aupload_image(
    patch_gigachat_aupload_file: MagicMock, upload_images_dialog: UploadDialog
) -> None:
    llm = GigaChat(auto_upload_images=True)
    dialog, hashed_1, hashed_2 = upload_images_dialog
    await llm.ainvoke(dialog)
    assert len(llm._cached_images) == 2
    assert patch_gigachat_aupload_file.call_count == 2
    assert patch_gigachat_aupload_file.call_args_list[0][0][0][1] == b"123"
    assert patch_gigachat_aupload_file.call_args_list[1][0][0][1] == b"124"
    assert hashed_1 in llm._cached_images
    assert hashed_2 in llm._cached_images


def test_ai_upload_image_stream(
    patch_gigachat_upload_file: MagicMock, upload_images_dialog: UploadDialog
) -> None:
    llm = GigaChat(auto_upload_images=True)
    dialog, hashed_1, hashed_2 = upload_images_dialog
    list(llm.stream(dialog))
    assert len(llm._cached_images) == 2
    assert patch_gigachat_upload_file.call_count == 2
    assert patch_gigachat_upload_file.call_args_list[0][0][0][1] == b"123"
    assert patch_gigachat_upload_file.call_args_list[1][0][0][1] == b"124"
    assert hashed_1 in llm._cached_images
    assert hashed_2 in llm._cached_images


async def test_ai_aupload_image_stream(
    patch_gigachat_aupload_file: MagicMock, upload_images_dialog: UploadDialog
) -> None:
    llm = GigaChat(auto_upload_images=True)
    dialog, hashed_1, hashed_2 = upload_images_dialog
    async for _ in llm.astream(dialog):
        pass
    assert len(llm._cached_images) == 2
    assert patch_gigachat_aupload_file.call_count == 2
    assert patch_gigachat_aupload_file.call_args_list[0][0][0][1] == b"123"
    assert patch_gigachat_aupload_file.call_args_list[1][0][0][1] == b"124"
    assert hashed_1 in llm._cached_images
    assert hashed_2 in llm._cached_images


def test_ai_upload_disabled_image(
    patch_gigachat_upload_file: MagicMock, upload_images_dialog: UploadDialog
) -> None:
    # with auto_upload_images off, nothing is uploaded or cached
    llm = GigaChat()
    dialog, _, _ = upload_images_dialog
    llm.invoke(dialog)
    assert len(llm._cached_images) == 0
    assert patch_gigachat_upload_file.call_count == 0


async def test_ai_aupload_disabled_image(
    patch_gigachat_aupload_file: MagicMock, upload_images_dialog: UploadDialog
) -> None:
    llm = GigaChat()
    dialog, _, _ = upload_images_dialog
    await llm.ainvoke(dialog)
    assert len(llm._cached_images) == 0
    assert patch_gigachat_aupload_file.call_count == 0


def test_ai_upload_image_disabled_stream(
    patch_gigachat_upload_file: MagicMock, upload_images_dialog: UploadDialog
) -> None:
    llm = GigaChat()
    dialog, _, _ = upload_images_dialog
    list(llm.stream(dialog))
    assert len(llm._cached_images) == 0
    assert patch_gigachat_upload_file.call_count == 0


async def test_ai_aupload_image_disabled_stream(
    patch_gigachat_aupload_file: MagicMock, upload_images_dialog: UploadDialog
) -> None:
    llm = GigaChat()
    dialog, _, _ = upload_images_dialog
    async for _ in llm.astream(dialog):
        pass
    assert len(llm._cached_images) == 0
    assert patch_gigachat_aupload_file.call_count == 0


def test__convert_message_with_attachments_to_dict_system(
    upload_images_dialog: UploadDialog,
) -> None:
    dialog, hashed_1, _ = upload_images_dialog
    expected = Messages(id=None, role=MessagesRole.USER, attachments=["1"], content="1")
    # the image hash resolves to an uploaded-file id via the cache mapping
    assert _convert_message_to_dict(dialog[0], {hashed_1: "1"}) == expected


def test__convert_message_with_attachments_no_cache_to_dict_system(
    upload_images_dialog: UploadDialog,
) -> None:
    dialog, _, _ = upload_images_dialog
    expected = Messages(id=None, role=MessagesRole.USER, content="1")
    # without a cache mapping, the image is dropped and only the text remains
    assert _convert_message_to_dict(dialog[0]) == expected
from typing_extensions import TypedDict as ExtensionsTypedDict

from langchain_gigachat.tools.giga_tool import FewShotExamples, GigaBaseTool, giga_tool

try:
    from typing import Annotated as TypingAnnotated  # type: ignore[attr-defined]
except ImportError:
    TypingAnnotated = ExtensionsAnnotated

from langchain_core.runnables import Runnable, RunnableLambda
from langchain_core.tools import BaseTool, StructuredTool, Tool, tool
from pydantic import BaseModel, Field

from langchain_gigachat.utils.function_calling import (
    IncorrectSchemaException,
    convert_to_gigachat_function,
)


# NOTE: the inner class names, docstrings and field descriptions below are
# runtime-visible schema data (they become the converted function's name,
# description and parameter descriptions), so they must not be changed.
@pytest.fixture()
def pydantic() -> type[BaseModel]:
    # Pydantic model with docstring + described fields: the fully-annotated input.
    class dummy_function(BaseModel):  # noqa: N801
        """dummy function"""

        arg1: Optional[int] = Field(..., description="foo")
        arg2: Literal["bar", "baz"] = Field(..., description="one of 'bar', 'baz'")

    return dummy_function


@pytest.fixture()
def annotated_function() -> Callable:
    # Plain function whose argument descriptions come from Annotated metadata.
    def dummy_function(
        arg1: ExtensionsAnnotated[Optional[int], "foo"],
        arg2: ExtensionsAnnotated[Literal["bar", "baz"], "one of 'bar', 'baz'"],
    ) -> None:
        """dummy function"""

    return dummy_function


@pytest.fixture()
def function() -> Callable:
    # Plain function whose argument descriptions come from a Google-style docstring.
    def dummy_function(arg1: Optional[int], arg2: Literal["bar", "baz"]) -> None:
        """dummy function

        Args:
            arg1: foo
            arg2: one of 'bar', 'baz'
        """

    return dummy_function


@pytest.fixture()
def runnable() -> Runnable:
    # RunnableLambda over a TypedDict-typed function (runnables are not
    # directly convertible to GigaChat functions).
    class Args(ExtensionsTypedDict):
        arg1: ExtensionsAnnotated[Optional[int], "foo"]
        arg2: ExtensionsAnnotated[Literal["bar", "baz"], "one of 'bar', 'baz'"]

    def dummy_function(input_dict: Args) -> None:
        pass

    return RunnableLambda(dummy_function)


@pytest.fixture()
def dummy_tool() -> BaseTool:
    # Hand-written BaseTool subclass with an explicit args_schema.
    class Schema(BaseModel):
        arg1: Optional[int] = Field(..., description="foo")
        arg2: Literal["bar", "baz"] = Field(..., description="one of 'bar', 'baz'")

    class DummyFunction(BaseTool):  # type: ignore[override]
        args_schema: type[BaseModel] = Schema
        name: str = "dummy_function"
        description: str = "dummy function"

        def _run(self, *args: Any, **kwargs: Any) -> Any:
            pass

    return DummyFunction()


@pytest.fixture()
def dummy_structured_tool() -> StructuredTool:
    # StructuredTool built from a lambda with a Pydantic args_schema.
    class Schema(BaseModel):
        arg1: Optional[int] = Field(..., description="foo")
        arg2: Literal["bar", "baz"] = Field(..., description="one of 'bar', 'baz'")

    return StructuredTool.from_function(
        lambda x: None,
        name="dummy_function",
        description="dummy function",
        args_schema=Schema,
    )


@pytest.fixture()
def dummy_structured_tool_with_dict_args_schema() -> StructuredTool:
    # StructuredTool whose args_schema is a raw JSON schema dict, not a model.
    schema = {
        "properties": {
            "arg1": {"description": "foo", "type": "integer"},
            "arg2": {
                "description": "one of 'bar', 'baz'",
                "enum": ["bar", "baz"],
                "type": "string",
            },
        },
        "required": ["arg2"],
        "title": "dummy_function",
        "type": "object",
    }

    return StructuredTool(
        name="dummy_function",
        description="dummy function",
        args_schema=schema,  # type: ignore[arg-type]
    )


@pytest.fixture()
def dummy_pydantic() -> type[BaseModel]:
    # Same shape as the `pydantic` fixture; used for dual pydantic-version coverage.
    class dummy_function(BaseModel):  # noqa: N801
        """dummy function"""

        arg1: Optional[int] = Field(..., description="foo")
        arg2: Literal["bar", "baz"] = Field(..., description="one of 'bar', 'baz'")

    return dummy_function
Optional[int] = FieldV2Maybe(..., description="foo") 146 | arg2: Literal["bar", "baz"] = FieldV2Maybe( 147 | ..., description="one of 'bar', 'baz'" 148 | ) 149 | 150 | return dummy_function 151 | 152 | 153 | @pytest.fixture() 154 | def dummy_typing_typed_dict() -> type: 155 | class dummy_function(TypingTypedDict): # noqa: N801 156 | """dummy function""" 157 | 158 | arg1: TypingAnnotated[Optional[int], None, "foo"] # noqa: F821 159 | arg2: TypingAnnotated[Literal["bar", "baz"], ..., "one of 'bar', 'baz'"] # noqa: F722 160 | 161 | return dummy_function 162 | 163 | 164 | @pytest.fixture() 165 | def dummy_typing_typed_dict_docstring() -> type: 166 | class dummy_function(TypingTypedDict): # noqa: N801 167 | """dummy function 168 | 169 | Args: 170 | arg1: foo 171 | arg2: one of 'bar', 'baz' 172 | """ 173 | 174 | arg1: Optional[int] 175 | arg2: Literal["bar", "baz"] 176 | 177 | return dummy_function 178 | 179 | 180 | @pytest.fixture() 181 | def dummy_extensions_typed_dict() -> type: 182 | class dummy_function(ExtensionsTypedDict): # noqa: N801 183 | """dummy function""" 184 | 185 | arg1: ExtensionsAnnotated[int, None, "foo"] 186 | arg2: ExtensionsAnnotated[Literal["bar", "baz"], ..., "one of 'bar', 'baz'"] 187 | 188 | return dummy_function 189 | 190 | 191 | @pytest.fixture() 192 | def dummy_extensions_typed_dict_docstring() -> type: 193 | class dummy_function(ExtensionsTypedDict): # noqa: N801 194 | """dummy function 195 | 196 | Args: 197 | arg1: foo 198 | arg2: one of 'bar', 'baz' 199 | """ 200 | 201 | arg1: Optional[int] 202 | arg2: Literal["bar", "baz"] 203 | 204 | return dummy_function 205 | 206 | 207 | @pytest.fixture() 208 | def json_schema() -> dict: 209 | return { 210 | "name": "dummy_function", 211 | "description": "dummy function", 212 | "return_parameters": None, 213 | "few_shot_examples": None, 214 | "parameters": { 215 | "type": "object", 216 | "properties": { 217 | "arg1": {"description": "foo", "type": "integer"}, 218 | "arg2": { 219 | "description": "one 
of 'bar', 'baz'", 220 | "enum": ["bar", "baz"], 221 | "type": "string", 222 | }, 223 | }, 224 | "required": ["arg2"], 225 | }, 226 | } 227 | 228 | 229 | class Dummy: 230 | def dummy_function(self, arg1: Optional[int], arg2: Literal["bar", "baz"]) -> None: 231 | """dummy function 232 | 233 | Args: 234 | arg1: foo 235 | arg2: one of 'bar', 'baz' 236 | """ 237 | 238 | 239 | class DummyWithClassMethod: 240 | @classmethod 241 | def dummy_function(cls, arg1: Optional[int], arg2: Literal["bar", "baz"]) -> None: 242 | """dummy function 243 | 244 | Args: 245 | arg1: foo 246 | arg2: one of 'bar', 'baz' 247 | """ 248 | 249 | 250 | @pytest.fixture() 251 | def function_with_title_parameters() -> type[BaseModel]: 252 | class Resource(BaseModel): 253 | """ 254 | Represents a resource. Give it a good title and a short description. 255 | """ 256 | 257 | url: str 258 | title: str 259 | description: str 260 | 261 | class ExtractResources(BaseModel): 262 | """ 263 | Extract the 3-5 most relevant resources from a search result. 
264 | """ 265 | 266 | resources: TypingAnnotated[List[Resource], Field(description="массив ресурсов")] 267 | 268 | return ExtractResources 269 | 270 | 271 | @pytest.mark.parametrize( 272 | "func", 273 | [ 274 | "pydantic", 275 | "function", 276 | "dummy_structured_tool", 277 | "dummy_structured_tool_with_dict_args_schema", 278 | "dummy_tool", 279 | "dummy_typing_typed_dict", 280 | "dummy_typing_typed_dict_docstring", 281 | "dummy_extensions_typed_dict", 282 | "dummy_extensions_typed_dict_docstring", 283 | "annotated_function", 284 | "dummy_pydantic", 285 | "json_schema", 286 | Dummy.dummy_function, 287 | DummyWithClassMethod.dummy_function, 288 | ], 289 | ) 290 | def test_convert_to_gigachat_function( 291 | func: Any, request: pytest.FixtureRequest 292 | ) -> None: 293 | expected = { 294 | "name": "dummy_function", 295 | "description": "dummy function", 296 | "return_parameters": None, 297 | "few_shot_examples": None, 298 | "parameters": { 299 | "type": "object", 300 | "properties": { 301 | "arg1": {"description": "foo", "type": "integer"}, 302 | "arg2": { 303 | "description": "one of 'bar', 'baz'", 304 | "enum": ["bar", "baz"], 305 | "type": "string", 306 | }, 307 | }, 308 | "required": ["arg2"], 309 | }, 310 | } 311 | if isinstance(func, str): 312 | func = request.getfixturevalue(func) 313 | 314 | actual = convert_to_gigachat_function(func) # type: ignore 315 | assert actual == expected 316 | 317 | 318 | def test_runnable(runnable: Runnable) -> None: 319 | expected = { 320 | "name": "dummy_function", 321 | "description": "dummy function", 322 | "return_parameters": None, 323 | "few_shot_examples": None, 324 | "parameters": { 325 | "type": "object", 326 | "properties": { 327 | "arg1": {"type": "integer", "description": ""}, 328 | "arg2": {"enum": ["bar", "baz"], "type": "string", "description": ""}, 329 | }, 330 | "required": ["arg2"], 331 | }, 332 | } 333 | actual = convert_to_gigachat_function( 334 | runnable.as_tool(description="dummy function") 335 | ) 336 | 
assert actual == expected 337 | 338 | 339 | def test_simple_tool() -> None: 340 | def my_function(input_string: str) -> str: # type: ignore 341 | pass 342 | 343 | tool = Tool(name="dummy_function", func=my_function, description="test description") 344 | actual = convert_to_gigachat_function(tool) 345 | expected = { 346 | "name": "dummy_function", 347 | "description": "test description", 348 | "return_parameters": None, 349 | "few_shot_examples": None, 350 | "parameters": {"properties": {}, "type": "object"}, 351 | } 352 | assert actual == expected 353 | 354 | 355 | @pytest.mark.xfail(reason="Direct pydantic v2 models not yet supported") 356 | def test_convert_to_openai_function_nested_v2() -> None: 357 | class NestedV2(BaseModelV2Maybe): 358 | nested_v2_arg1: int = FieldV2Maybe(..., description="foo") 359 | nested_v2_arg2: Literal["bar", "baz"] = FieldV2Maybe( 360 | ..., description="one of 'bar', 'baz'" 361 | ) 362 | 363 | def my_function(arg1: NestedV2) -> None: 364 | """dummy function""" 365 | 366 | convert_to_gigachat_function(my_function) 367 | 368 | 369 | def test_convert_to_gigachat_function_nested() -> None: 370 | class Nested(BaseModel): 371 | nested_arg1: int = Field(..., description="foo") 372 | nested_arg2: Literal["bar", "baz"] = Field( 373 | ..., description="one of 'bar', 'baz'" 374 | ) 375 | 376 | def my_function(arg1: Nested) -> None: 377 | """dummy function""" 378 | 379 | expected = { 380 | "name": "my_function", 381 | "description": "dummy function", 382 | "parameters": { 383 | "type": "object", 384 | "properties": { 385 | "arg1": { 386 | "type": "object", 387 | "description": "", 388 | "properties": { 389 | "nested_arg1": {"type": "integer", "description": "foo"}, 390 | "nested_arg2": { 391 | "type": "string", 392 | "enum": ["bar", "baz"], 393 | "description": "one of 'bar', 'baz'", 394 | }, 395 | }, 396 | "required": ["nested_arg1", "nested_arg2"], 397 | } 398 | }, 399 | "required": ["arg1"], 400 | }, 401 | "return_parameters": None, 402 | 
"few_shot_examples": None, 403 | } 404 | 405 | actual = convert_to_gigachat_function(my_function) 406 | assert actual == expected 407 | 408 | 409 | @pytest.mark.xfail( 410 | reason="Pydantic converts Optional[str] to str in .model_json_schema()" 411 | ) 412 | def test_function_optional_param() -> None: 413 | @tool 414 | def func5(a: Optional[str], b: str, c: Optional[list[Optional[str]]]) -> None: 415 | """A test function""" 416 | 417 | func = convert_to_gigachat_function(func5) 418 | req = func["parameters"]["required"] 419 | assert set(req) == {"b"} 420 | 421 | 422 | def test_function_no_params() -> None: 423 | def nullary_function() -> None: 424 | """nullary function""" 425 | 426 | func = convert_to_gigachat_function(nullary_function) 427 | req = func["parameters"].get("required") 428 | assert not req 429 | 430 | 431 | def test_convert_union_fail() -> None: 432 | @tool 433 | def magic_function(input: Union[int, float]) -> str: # type: ignore 434 | """Compute a magic function.""" 435 | 436 | with pytest.raises(IncorrectSchemaException): 437 | convert_to_gigachat_function(magic_function) 438 | 439 | 440 | def test_function_with_title_parameters( 441 | function_with_title_parameters: type[BaseModel], 442 | ) -> None: 443 | expected = { 444 | "name": "ExtractResources", 445 | "description": "Extract the 3-5 most relevant resources from a search result.", 446 | "parameters": { 447 | # noqa 448 | "properties": { 449 | "resources": { 450 | "description": "массив ресурсов", 451 | "items": { 452 | "description": "Represents a resource. 
Give it a good title and a short description.", # noqa 453 | # noqa 454 | "properties": { 455 | "url": {"type": "string"}, 456 | "title": {"type": "string"}, 457 | "description": {"type": "string"}, 458 | }, 459 | "required": ["url", "title", "description"], 460 | "type": "object", 461 | }, 462 | "type": "array", 463 | } 464 | }, 465 | "required": ["resources"], 466 | "type": "object", 467 | }, 468 | "return_parameters": None, 469 | "few_shot_examples": None, 470 | } 471 | actual = convert_to_gigachat_function(function_with_title_parameters) 472 | assert actual == expected 473 | 474 | 475 | def test_convert_to_function_no_args() -> None: 476 | @tool 477 | def empty_tool() -> str: 478 | """No args""" 479 | return "foo" 480 | 481 | actual = convert_to_gigachat_function(empty_tool) 482 | assert actual == { 483 | "name": "empty_tool", 484 | "description": "No args", 485 | "few_shot_examples": None, 486 | "return_parameters": None, 487 | "parameters": {"properties": {}, "type": "object"}, 488 | } 489 | 490 | 491 | # Test for return parameters and few shot examples 492 | 493 | 494 | class ReturnParameters(BaseModel): 495 | """dummy function""" 496 | 497 | arg1: Optional[int] = Field(..., description="foo") 498 | arg2: Literal["bar", "baz"] = Field(..., description="one of 'bar', 'baz'") 499 | 500 | 501 | @pytest.fixture() 502 | def annotated_function_return_parameters() -> Callable: 503 | def dummy_function( # type: ignore 504 | arg1: ExtensionsAnnotated[int, "foo"], 505 | arg2: ExtensionsAnnotated[Literal["bar", "baz"], "one of 'bar', 'baz'"], 506 | ) -> ReturnParameters: 507 | """dummy function""" 508 | 509 | return dummy_function 510 | 511 | 512 | @pytest.fixture() 513 | def function_return_parameters() -> Callable: 514 | def dummy_function( # type: ignore 515 | arg1: int, arg2: Literal["bar", "baz"] 516 | ) -> ReturnParameters: 517 | """dummy function 518 | 519 | Args: 520 | arg1: foo 521 | arg2: one of 'bar', 'baz' 522 | """ 523 | 524 | return dummy_function 525 | 
526 | 527 | @pytest.fixture() 528 | def dummy_return_parameters_with_fews_tool() -> GigaBaseTool: 529 | class Schema(BaseModel): 530 | arg1: int = Field(..., description="foo") 531 | arg2: Literal["bar", "baz"] = Field(..., description="one of 'bar', 'baz'") 532 | 533 | class DummyFunction(GigaBaseTool): # type: ignore[override] 534 | args_schema: type[BaseModel] = Schema 535 | name: str = "dummy_function" 536 | description: str = "dummy function" 537 | return_schema: type[BaseModel] = ReturnParameters 538 | few_shot_examples: FewShotExamples = [ 539 | {"arg1": 1, "arg2": "bar"}, 540 | {"arg1": 2, "arg2": "baz"}, 541 | ] 542 | 543 | def _run(self, *args: Any, **kwargs: Any) -> Any: 544 | pass 545 | 546 | return DummyFunction() 547 | 548 | 549 | @pytest.fixture() 550 | def dummy_return_parameters_with_fews_decorator() -> Callable: 551 | @giga_tool( 552 | few_shot_examples=[{"arg1": 1, "arg2": "bar"}, {"arg1": 2, "arg2": "baz"}] 553 | ) 554 | def dummy_function( # type: ignore 555 | arg1: Optional[int] = Field(..., description="foo"), 556 | arg2: Literal["bar", "baz"] = Field(..., description="one of 'bar', 'baz'"), 557 | ) -> ReturnParameters: 558 | """dummy function""" 559 | pass 560 | 561 | return dummy_function 562 | 563 | 564 | @pytest.fixture() 565 | def dummy_return_parameters_through_arg_with_fews_decorator() -> Callable: 566 | @giga_tool( 567 | few_shot_examples=[{"arg1": 1, "arg2": "bar"}, {"arg1": 2, "arg2": "baz"}], 568 | return_schema=ReturnParameters, 569 | ) 570 | def dummy_function( 571 | arg1: Optional[int] = Field(..., description="foo"), 572 | arg2: Literal["bar", "baz"] = Field(..., description="one of 'bar', 'baz'"), 573 | ) -> None: 574 | """dummy function""" 575 | pass 576 | 577 | return dummy_function 578 | 579 | 580 | @pytest.fixture() 581 | def json_schema_return_parameters_with_fews() -> dict: 582 | return { 583 | "name": "dummy_function", 584 | "description": "dummy function", 585 | "return_parameters": { 586 | "type": "object", 587 | 
"description": "dummy function", 588 | "properties": { 589 | "arg1": {"description": "foo", "type": "integer"}, 590 | "arg2": { 591 | "description": "one of 'bar', 'baz'", 592 | "enum": ["bar", "baz"], 593 | "type": "string", 594 | }, 595 | }, 596 | "required": ["arg2"], 597 | }, 598 | "few_shot_examples": [{"arg1": 1, "arg2": "bar"}, {"arg1": 2, "arg2": "baz"}], 599 | "parameters": { 600 | "type": "object", 601 | "description": "dummy function", 602 | "properties": { 603 | "arg1": {"description": "foo", "type": "integer"}, 604 | "arg2": { 605 | "description": "one of 'bar', 'baz'", 606 | "enum": ["bar", "baz"], 607 | "type": "string", 608 | }, 609 | }, 610 | "required": ["arg2"], 611 | }, 612 | } 613 | 614 | 615 | class DummyReturnParameters: 616 | def dummy_function( # type: ignore 617 | self, arg1: Optional[int], arg2: Literal["bar", "baz"] 618 | ) -> ReturnParameters: 619 | """dummy function 620 | 621 | Args: 622 | arg1: foo 623 | arg2: one of 'bar', 'baz' 624 | """ 625 | 626 | 627 | class DummyReturnParametersWithClassMethod: 628 | @classmethod 629 | def dummy_function( # type: ignore 630 | cls, arg1: Optional[int], arg2: Literal["bar", "baz"] 631 | ) -> ReturnParameters: 632 | """dummy function 633 | 634 | Args: 635 | arg1: foo 636 | arg2: one of 'bar', 'baz' 637 | """ 638 | 639 | 640 | @pytest.mark.parametrize( 641 | "func", 642 | [ 643 | "annotated_function_return_parameters", 644 | "function_return_parameters", 645 | "dummy_return_parameters_with_fews_tool", 646 | "dummy_return_parameters_with_fews_decorator", 647 | "dummy_return_parameters_through_arg_with_fews_decorator", 648 | "json_schema_return_parameters_with_fews", 649 | DummyReturnParameters.dummy_function, 650 | DummyReturnParametersWithClassMethod.dummy_function, 651 | ], 652 | ) 653 | def test_function_with_return_parameters( 654 | func: Any, request: pytest.FixtureRequest 655 | ) -> None: 656 | return_params_expected = { 657 | "type": "object", 658 | "description": "dummy function", 659 | 
"properties": { 660 | "arg1": {"description": "foo", "type": "integer"}, 661 | "arg2": { 662 | "description": "one of 'bar', 'baz'", 663 | "enum": ["bar", "baz"], 664 | "type": "string", 665 | }, 666 | }, 667 | "required": ["arg2"], 668 | } 669 | if isinstance(func, str): 670 | func = request.getfixturevalue(func) 671 | 672 | actual_func = convert_to_gigachat_function(func) 673 | assert actual_func["return_parameters"] == return_params_expected 674 | 675 | 676 | @pytest.mark.parametrize( 677 | "func", 678 | [ 679 | "dummy_return_parameters_with_fews_tool", 680 | "dummy_return_parameters_with_fews_decorator", 681 | "dummy_return_parameters_through_arg_with_fews_decorator", 682 | "json_schema_return_parameters_with_fews", 683 | ], 684 | ) 685 | def test_function_with_few_shots(func: Any, request: pytest.FixtureRequest) -> None: 686 | few_shots_expected = [{"arg1": 1, "arg2": "bar"}, {"arg1": 2, "arg2": "baz"}] 687 | if isinstance(func, str): 688 | func = request.getfixturevalue(func) 689 | 690 | actual_func = convert_to_gigachat_function(func) 691 | assert actual_func["few_shot_examples"] == few_shots_expected 692 | -------------------------------------------------------------------------------- /libs/gigachat/langchain_gigachat/chat_models/gigachat.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import base64 4 | import copy 5 | import hashlib 6 | import json 7 | import logging 8 | import re 9 | from mimetypes import guess_extension 10 | from operator import itemgetter 11 | from typing import ( 12 | TYPE_CHECKING, 13 | Any, 14 | AsyncIterator, 15 | Callable, 16 | Dict, 17 | Iterator, 18 | List, 19 | Literal, 20 | Mapping, 21 | Optional, 22 | Sequence, 23 | Tuple, 24 | Type, 25 | TypedDict, 26 | TypeVar, 27 | Union, 28 | overload, 29 | ) 30 | from uuid import uuid4 31 | 32 | from langchain_core.callbacks import ( 33 | AsyncCallbackManagerForLLMRun, 34 | CallbackManagerForLLMRun, 35 | 
) 36 | from langchain_core.language_models import LanguageModelInput 37 | from langchain_core.language_models.chat_models import ( 38 | BaseChatModel, 39 | agenerate_from_stream, 40 | generate_from_stream, 41 | ) 42 | from langchain_core.messages import ( 43 | AIMessage, 44 | AIMessageChunk, 45 | BaseMessage, 46 | BaseMessageChunk, 47 | ChatMessage, 48 | ChatMessageChunk, 49 | FunctionMessage, 50 | FunctionMessageChunk, 51 | HumanMessage, 52 | HumanMessageChunk, 53 | SystemMessage, 54 | SystemMessageChunk, 55 | ToolCall, 56 | ToolCallChunk, 57 | ToolMessage, 58 | ) 59 | from langchain_core.messages.ai import UsageMetadata 60 | from langchain_core.output_parsers import ( 61 | JsonOutputKeyToolsParser, 62 | JsonOutputParser, 63 | PydanticOutputParser, 64 | PydanticToolsParser, 65 | ) 66 | from langchain_core.output_parsers.base import OutputParserLike 67 | from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult 68 | from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough 69 | from langchain_core.tools import BaseTool 70 | from langchain_core.utils.pydantic import is_basemodel_subclass, pre_init 71 | from pydantic import BaseModel 72 | 73 | from langchain_gigachat.chat_models.base_gigachat import _BaseGigaChat 74 | from langchain_gigachat.utils.function_calling import ( 75 | convert_to_gigachat_function, 76 | convert_to_gigachat_tool, 77 | ) 78 | 79 | if TYPE_CHECKING: 80 | import gigachat.models as gm 81 | 82 | logger = logging.getLogger(__name__) 83 | 84 | IMAGE_SEARCH_REGEX = re.compile( 85 | r'.+?)"\sfuse=".+?"/>(?P.+)?' 86 | ) 87 | VIDEO_SEARCH_REGEX = re.compile( 88 | r'.+?)"\ssrc="(?P.+?)"\sfuse="true"/>(?P.+)?' 
# noqa 89 | ) 90 | 91 | 92 | def _validate_content(content: Any) -> Any: 93 | """If content is string, but not JSON - convert string to json-string""" 94 | if isinstance(content, str): 95 | try: 96 | json.loads(content) 97 | except ValueError: 98 | content = json.dumps(content, ensure_ascii=False) 99 | return content 100 | 101 | 102 | def _convert_dict_to_message(message: gm.Messages) -> BaseMessage: 103 | from gigachat.models import FunctionCall, MessagesRole 104 | 105 | additional_kwargs: Dict = {} 106 | tool_calls = [] 107 | if function_call := message.function_call: 108 | if isinstance(function_call, FunctionCall): 109 | additional_kwargs["function_call"] = dict(function_call) 110 | elif isinstance(function_call, dict): 111 | additional_kwargs["function_call"] = function_call 112 | if additional_kwargs.get("function_call") is not None: 113 | tool_calls = [ 114 | ToolCall( 115 | name=additional_kwargs["function_call"]["name"], 116 | args=additional_kwargs["function_call"]["arguments"], 117 | id=str(uuid4()), 118 | ) 119 | ] 120 | if message.functions_state_id: 121 | additional_kwargs["functions_state_id"] = message.functions_state_id 122 | match = IMAGE_SEARCH_REGEX.search(message.content) 123 | if match: 124 | additional_kwargs["image_uuid"] = match.group("UUID") 125 | additional_kwargs["postfix_message"] = match.group("postfix") 126 | match = VIDEO_SEARCH_REGEX.search(message.content) 127 | if match: 128 | additional_kwargs["cover_uuid"] = match.group("cover_UUID") 129 | additional_kwargs["video_uuid"] = match.group("UUID") 130 | additional_kwargs["postfix_message"] = match.group("postfix") 131 | if message.role == MessagesRole.SYSTEM: 132 | return SystemMessage(content=message.content) 133 | elif message.role == MessagesRole.USER: 134 | return HumanMessage(content=message.content) 135 | elif message.role == MessagesRole.ASSISTANT: 136 | return AIMessage( 137 | content=message.content, 138 | additional_kwargs=additional_kwargs, 139 | tool_calls=tool_calls, 140 
| ) 141 | elif message.role == MessagesRole.FUNCTION: 142 | return FunctionMessage( 143 | name=message.name or "", content=_validate_content(message.content) 144 | ) 145 | else: 146 | raise TypeError(f"Got unknown role {message.role} {message}") 147 | 148 | 149 | def get_text_and_images_from_content( 150 | content: list[Union[str, dict]], cached_images: Dict[str, str] 151 | ) -> Tuple[str, List[str]]: 152 | text_parts = [] 153 | attachments = [] 154 | for content_part in content: 155 | if isinstance(content_part, str): 156 | text_parts.append(content_part) 157 | elif isinstance(content_part, dict): 158 | if content_part.get("type") == "text": 159 | text_parts.append(content_part["text"]) 160 | elif content_part.get("type") == "image_url": 161 | image_data = content_part["image_url"] 162 | if not isinstance(image_data, dict): 163 | continue 164 | if "giga_id" in content_part["image_url"]: 165 | attachments.append(content_part["image_url"].get("giga_id")) 166 | image_url = content_part["image_url"].get("url") 167 | hashed = hashlib.sha256(image_url.encode()).hexdigest() 168 | if hashed in cached_images: 169 | attachments.append(cached_images[hashed]) 170 | return " ".join(text_parts), attachments 171 | 172 | 173 | def _convert_message_to_dict( 174 | message: BaseMessage, cached_images: Optional[Dict[str, str]] = None 175 | ) -> gm.Messages: 176 | from gigachat.models import Messages, MessagesRole 177 | 178 | kwargs = {} 179 | if cached_images is None: 180 | cached_images = {} 181 | 182 | if isinstance(message.content, list): 183 | content, attachments = get_text_and_images_from_content( 184 | message.content, cached_images 185 | ) 186 | else: 187 | content, attachments = message.content, [] 188 | 189 | attachments += message.additional_kwargs.get("attachments", []) 190 | functions_state_id = message.additional_kwargs.get("functions_state_id", None) 191 | if functions_state_id: 192 | kwargs["functions_state_id"] = functions_state_id 193 | 194 | if isinstance(message, 
SystemMessage): 195 | kwargs["role"] = MessagesRole.SYSTEM 196 | kwargs["content"] = content 197 | elif isinstance(message, HumanMessage): 198 | kwargs["role"] = MessagesRole.USER 199 | if attachments: 200 | kwargs["attachments"] = attachments 201 | kwargs["content"] = content 202 | elif isinstance(message, AIMessage): 203 | if tool_calls := getattr(message, "tool_calls", None): 204 | function_call = copy.deepcopy(tool_calls[0]) 205 | 206 | if "args" in function_call: 207 | function_call["arguments"] = function_call.pop("args") 208 | else: 209 | function_call = message.additional_kwargs.get("function_call", None) 210 | kwargs["role"] = MessagesRole.ASSISTANT 211 | kwargs["content"] = content 212 | kwargs["function_call"] = function_call 213 | elif isinstance(message, ChatMessage): 214 | kwargs["role"] = message.role 215 | kwargs["content"] = content 216 | elif isinstance(message, FunctionMessage): 217 | kwargs["role"] = MessagesRole.FUNCTION 218 | # TODO Switch to using 'result' field in future GigaChat models 219 | kwargs["content"] = _validate_content(content) 220 | elif isinstance(message, ToolMessage): 221 | kwargs["role"] = MessagesRole.FUNCTION 222 | kwargs["content"] = _validate_content(content) 223 | else: 224 | raise TypeError(f"Got unknown type {message}") 225 | return Messages(**kwargs) 226 | 227 | 228 | def _convert_delta_to_message_chunk( 229 | _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk] 230 | ) -> BaseMessageChunk: 231 | role = _dict.get("role") 232 | content = _dict.get("content") or "" 233 | additional_kwargs: Dict = {} 234 | tool_call_chunks = [] 235 | if _dict.get("function_call"): 236 | function_call = dict(_dict["function_call"]) 237 | if "name" in function_call and function_call["name"] is None: 238 | function_call["name"] = "" 239 | additional_kwargs["function_call"] = function_call 240 | if additional_kwargs.get("function_call") is not None: 241 | tool_call_chunks = [ 242 | ToolCallChunk( 243 | 
name=additional_kwargs["function_call"]["name"], 244 | args=json.dumps(additional_kwargs["function_call"]["arguments"]), 245 | id=str(uuid4()), 246 | index=0, 247 | ) 248 | ] 249 | if _dict.get("functions_state_id"): 250 | additional_kwargs["functions_state_id"] = _dict["functions_state_id"] 251 | match = IMAGE_SEARCH_REGEX.search(content) 252 | if match: 253 | additional_kwargs["image_uuid"] = match.group("UUID") 254 | additional_kwargs["postfix_message"] = match.group("postfix") 255 | match = VIDEO_SEARCH_REGEX.search(content) 256 | if match: 257 | additional_kwargs["cover_uuid"] = match.group("cover_UUID") 258 | additional_kwargs["video_uuid"] = match.group("UUID") 259 | additional_kwargs["postfix_message"] = match.group("postfix") 260 | 261 | # if ( 262 | # role == "function_in_progress" 263 | # or default_class == FunctionInProgressMessageChunk 264 | # ): 265 | # return FunctionInProgressMessageChunk(content=content) 266 | 267 | if role == "user" or default_class == HumanMessageChunk: 268 | return HumanMessageChunk(content=content) 269 | elif ( 270 | role == "assistant" 271 | or default_class == AIMessageChunk 272 | or "functions_state_id" in _dict 273 | ): 274 | return AIMessageChunk( 275 | content=content, 276 | additional_kwargs=additional_kwargs, 277 | tool_call_chunks=tool_call_chunks, 278 | ) 279 | elif role == "system" or default_class == SystemMessageChunk: 280 | return SystemMessageChunk(content=content) 281 | elif role == "function" or default_class == FunctionMessageChunk: 282 | return FunctionMessageChunk( 283 | content=_validate_content(content), name=_dict["name"] 284 | ) 285 | elif role or default_class == ChatMessageChunk: 286 | return ChatMessageChunk(content=content, role=role) # type: ignore[arg-type] 287 | else: 288 | return default_class(content=content) # type: ignore[call-arg] 289 | 290 | 291 | def _convert_function_to_dict(function: Dict) -> Any: 292 | from gigachat.models import Function, FunctionParameters 293 | 294 | res = 
Function(name=function["name"], description=function["description"]) 295 | 296 | if "parameters" in function: 297 | if isinstance(function["parameters"], dict): 298 | if "properties" in function["parameters"]: 299 | props = function["parameters"]["properties"] 300 | properties = {} 301 | 302 | for k, v in props.items(): 303 | properties[k] = {"type": v["type"], "description": v["description"]} 304 | 305 | res.parameters = FunctionParameters( 306 | type="object", 307 | properties=properties, # type: ignore[arg-type] 308 | required=props.get("required", []), 309 | ) 310 | else: 311 | raise TypeError( 312 | f"No properties in parameters in {function['parameters']}" 313 | ) 314 | else: 315 | raise TypeError(f"Got unknown type {function['parameters']}") 316 | 317 | return res 318 | 319 | 320 | class _FunctionCall(TypedDict): 321 | name: str 322 | 323 | 324 | _BM = TypeVar("_BM", bound=BaseModel) 325 | _DictOrPydanticClass = Union[Dict[str, Any], Type[_BM], Type] 326 | _DictOrPydantic = Union[Dict, _BM] 327 | 328 | 329 | class _AllReturnType(TypedDict): 330 | raw: BaseMessage 331 | parsed: Optional[_DictOrPydantic] 332 | parsing_error: Optional[BaseException] 333 | 334 | 335 | def trim_content_to_stop_sequence( 336 | content: str, stop_sequence: Optional[List[str]] 337 | ) -> Union[str, bool]: 338 | """ 339 | Обрезаем строку к стоп слову. 340 | Если стоп слово нашлось в строке возвращаем обрезанную строку. 341 | Если нет, то возвращаем False 342 | """ 343 | if stop_sequence is None: 344 | return False 345 | for stop_w in stop_sequence: 346 | try: 347 | index = content.index(stop_w) 348 | return content[:index] 349 | except ValueError: 350 | pass 351 | return False 352 | 353 | 354 | class GigaChat(_BaseGigaChat, BaseChatModel): 355 | """`GigaChat` large language models API. 356 | 357 | To use, provide credentials via token, login and password, 358 | or mTLS for secure access to the GigaChat API. 359 | 360 | Example Usage: 361 | .. 
class GigaChat(_BaseGigaChat, BaseChatModel):
    """`GigaChat` large language models API.

    To use, provide credentials via token, login and password,
    or mTLS for secure access to the GigaChat API.

    Example Usage:
        .. code-block:: python

            from langchain_community.chat_models import GigaChat

            # Authorization with Token
            # (obtainable in the personal cabinet under Authorization Data):
            giga = GigaChat(credentials="YOUR_TOKEN")

            # Personal Space:
            giga = GigaChat(credentials="YOUR_TOKEN", scope="GIGACHAT_API_PERS")

            # Corporate Space:
            giga = GigaChat(credentials="YOUR_TOKEN", scope="GIGACHAT_API_CORP")

            # Authorization with Login and Password:
            giga = GigaChat(
                base_url="https://gigachat.devices.sberbank.ru/api/v1",
                user="YOUR_USERNAME",
                password="YOUR_PASSWORD",
            )

            # Mutual Authentication via TLS (mTLS):
            giga = GigaChat(
                base_url="https://gigachat.devices.sberbank.ru/api/v1",
                ca_bundle_file="certs/ca.pem",  # chain_pem.txt
                cert_file="certs/tls.pem",  # published_pem.txt
                key_file="certs/tls.key",
                key_file_password="YOUR_KEY_PASSWORD",
            )

            # Authorization with Temporary Token:
            giga = GigaChat(access_token="YOUR_TEMPORARY_TOKEN")

    """

    # Experimental: auto-upload base-64 data-URI images found in message
    # content to the GigaChat file API. Not for production usage!
    auto_upload_images: bool = False
    # Cache mapping sha256 of a base-64 image data URI to its uploaded
    # GigaChat file id, so each image is uploaded at most once per instance.
    _cached_images: Dict[str, str] = {}

    @pre_init
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate configuration and warn when the experimental
        ``auto_upload_images`` flag is enabled."""
        values = super(GigaChat, cls).validate_environment(values)
        if values["auto_upload_images"]:
            logger.warning(
                "`auto_upload_images` is experiment option. "
                "Please, don't use it on production. "
                "Use instead GigaChat.upload_file method for uploading images"
            )
        return values

    def _iter_base64_images(
        self, messages: List[BaseMessage]
    ) -> Iterator[Tuple[str, str, bytes]]:
        """Yield ``(hash, filename, raw_bytes)`` for each base-64 data-URI
        image in *messages* that is not yet in the upload cache.

        Shared scanning logic for :meth:`_upload_images` and
        :meth:`_aupload_images` (previously duplicated verbatim in both).
        Emits a warning and yields nothing for base-64 images when
        ``auto_upload_images`` is disabled.
        """
        for message in messages:
            if not isinstance(message.content, list):
                continue
            for content_part in message.content:
                if not isinstance(content_part, dict):
                    continue
                if content_part.get("type") != "image_url":
                    continue
                image_url = content_part["image_url"]["url"]
                matches = re.search(r"data:(.+);(.+),(.+)", image_url)
                if matches and not self.auto_upload_images:
                    logger.warning(
                        "You trying to send base-64 images, "
                        "but parameter `auto_upload_images` is not True. "
                        "Set it to True. "
                    )
                if not matches or not self.auto_upload_images:
                    continue
                hashed = hashlib.sha256(image_url.encode()).hexdigest()
                # Lazy check: the cache may have been filled by an earlier
                # upload of the same image within this very batch.
                if hashed in self._cached_images:
                    continue
                extension, type_, image_str = matches.groups()
                if type_ != "base64":
                    continue
                yield (
                    hashed,
                    f"{uuid4()}{guess_extension(extension)}",
                    base64.b64decode(image_str),
                )

    async def _aupload_images(self, messages: List[BaseMessage]) -> None:
        """Asynchronously upload pending base-64 images, caching file ids."""
        for hashed, filename, data in self._iter_base64_images(messages):
            file = await self.aupload_file((filename, data))
            self._cached_images[hashed] = file.id_

    def _upload_images(self, messages: List[BaseMessage]) -> None:
        """Upload pending base-64 images, caching file ids."""
        for hashed, filename, data in self._iter_base64_images(messages):
            file = self.upload_file((filename, data))
            self._cached_images[hashed] = file.id_

    def _build_payload(self, messages: List[BaseMessage], **kwargs: Any) -> gm.Chat:
        """Assemble the ``gigachat.models.Chat`` request payload.

        OpenAI-style ``tools`` entries of type "function" are folded into the
        GigaChat ``functions`` list; remaining kwargs are passed through.
        """
        from gigachat.models import Chat

        messages_dicts = [
            _convert_message_to_dict(m, self._cached_images) for m in messages
        ]
        kwargs.pop("messages", None)

        functions = kwargs.pop("functions", [])
        for tool in kwargs.pop("tools", []):
            if tool.get("type", None) == "function" and isinstance(functions, List):
                functions.append(tool["function"])

        function_call = kwargs.pop("function_call", None)

        payload_dict = {
            "messages": messages_dicts,
            "functions": functions,
            "function_call": function_call,
            "profanity_check": self.profanity_check,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "max_tokens": self.max_tokens,
            "repetition_penalty": self.repetition_penalty,
            "update_interval": self.update_interval,
            **kwargs,
        }

        payload = Chat.parse_obj(payload_dict)

        if self.verbose:
            logger.warning(
                "Giga request: %s",
                json.dumps(
                    payload.dict(exclude_none=True, by_alias=True), ensure_ascii=False
                ),
            )

        return payload

    def _check_finish_reason(self, finish_reason: str | None) -> None:
        """Warn about abnormal generation stops (anything but a normal stop
        or a function call)."""
        if finish_reason and finish_reason not in {"stop", "function_call"}:
            logger.warning("Giga generation stopped with reason: %s", finish_reason)

    @staticmethod
    def _trim_choices_to_stop(
        response: gm.ChatCompletion, stop: Optional[List[str]]
    ) -> None:
        """Truncate, in place, the first response choice containing a stop
        word. Shared by :meth:`_generate` and :meth:`_agenerate` (previously
        duplicated in both)."""
        for choice in response.choices:
            trimmed_content = trim_content_to_stop_sequence(
                choice.message.content, stop
            )
            if isinstance(trimmed_content, str):
                choice.message.content = trimmed_content
                break

    def _create_chat_result(self, response: gm.ChatCompletion) -> ChatResult:
        """Convert a GigaChat completion response into a ``ChatResult``.

        Propagates token usage, model name and response headers; the
        ``x-request-id`` header (when present) becomes the message id.
        """
        generations = []
        x_headers = None
        for res in response.choices:
            message = _convert_dict_to_message(res.message)
            x_headers = response.x_headers if response.x_headers else {}
            if x_headers.get("x-request-id") is not None:
                message.id = x_headers["x-request-id"]
            if isinstance(message, AIMessage):
                message.usage_metadata = UsageMetadata(
                    output_tokens=response.usage.completion_tokens,
                    input_tokens=response.usage.prompt_tokens,
                    total_tokens=response.usage.total_tokens,
                    input_token_details={
                        "cache_read": response.usage.precached_prompt_tokens or 0
                    },
                )
            finish_reason = res.finish_reason
            self._check_finish_reason(finish_reason)
            gen = ChatGeneration(
                message=message,
                generation_info={
                    "finish_reason": finish_reason,
                    "model_name": response.model,
                },
            )
            generations.append(gen)
            if self.verbose:
                logger.warning("Giga response: %s", message.content)
        llm_output = {
            "token_usage": response.usage.dict(),
            "model_name": response.model,
            "x_headers": x_headers,
        }
        return ChatResult(generations=generations, llm_output=llm_output)

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Synchronously generate a chat completion.

        Delegates to :meth:`_stream` when streaming is requested (explicitly
        or via the instance ``streaming`` flag).
        """
        should_stream = stream if stream is not None else self.streaming
        if should_stream:
            stream_iter = self._stream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)

        self._upload_images(messages)
        payload = self._build_payload(messages, **kwargs)
        response = self._client.chat(payload)
        self._trim_choices_to_stop(response, stop)
        return self._create_chat_result(response)

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Asynchronous counterpart of :meth:`_generate`."""
        should_stream = stream if stream is not None else self.streaming
        if should_stream:
            stream_iter = self._astream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return await agenerate_from_stream(stream_iter)

        await self._aupload_images(messages)
        payload = self._build_payload(messages, **kwargs)
        response = await self._client.achat(payload)
        self._trim_choices_to_stop(response, stop)
        return self._create_chat_result(response)
_convert_delta_to_message_chunk(choice["delta"], AIMessageChunk) 639 | usage_metadata = None 640 | if chunk.get("usage"): 641 | usage_metadata = UsageMetadata( 642 | output_tokens=chunk["usage"]["completion_tokens"], 643 | input_tokens=chunk["usage"]["prompt_tokens"], 644 | total_tokens=chunk["usage"]["total_tokens"], 645 | input_token_details={ 646 | "cache_read": chunk["usage"].get("precached_prompt_tokens", 0) 647 | }, 648 | ) 649 | if isinstance(chunk_m, AIMessageChunk): 650 | chunk_m.usage_metadata = usage_metadata 651 | x_headers = chunk.get("x_headers") 652 | x_headers = x_headers if isinstance(x_headers, dict) else {} 653 | if "x-request-id" in x_headers: 654 | chunk_m.id = x_headers["x-request-id"] 655 | 656 | generation_info = {} 657 | if finish_reason := choice.get("finish_reason"): 658 | self._check_finish_reason(finish_reason) 659 | generation_info["model_name"] = chunk.get("model") 660 | generation_info["finish_reason"] = finish_reason 661 | if first_chunk: 662 | generation_info["x_headers"] = x_headers 663 | first_chunk = False 664 | if run_manager: 665 | run_manager.on_llm_new_token(content) 666 | 667 | yield ChatGenerationChunk(message=chunk_m, generation_info=generation_info) 668 | 669 | async def _astream( 670 | self, 671 | messages: List[BaseMessage], 672 | stop: Optional[List[str]] = None, 673 | run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, 674 | **kwargs: Any, 675 | ) -> AsyncIterator[ChatGenerationChunk]: 676 | await self._aupload_images(messages) 677 | payload = self._build_payload(messages, **kwargs) 678 | message_content = "" 679 | first_chunk = True 680 | 681 | async for chunk_d in self._client.astream(payload): 682 | chunk = {} 683 | if not isinstance(chunk_d, dict): 684 | chunk = chunk_d.dict() 685 | else: 686 | chunk = chunk_d 687 | if len(chunk["choices"]) == 0: 688 | continue 689 | 690 | choice = chunk["choices"][0] 691 | content = choice.get("delta", {}).get("content", {}) 692 | message_content += content 693 | if 
trim_content_to_stop_sequence(message_content, stop): 694 | return 695 | chunk_m = _convert_delta_to_message_chunk(choice["delta"], AIMessageChunk) 696 | usage_metadata = None 697 | if chunk.get("usage"): 698 | usage_metadata = UsageMetadata( 699 | output_tokens=chunk["usage"]["completion_tokens"], 700 | input_tokens=chunk["usage"]["prompt_tokens"], 701 | total_tokens=chunk["usage"]["total_tokens"], 702 | input_token_details={ 703 | "cache_read": chunk["usage"].get("precached_prompt_tokens", 0) 704 | }, 705 | ) 706 | if isinstance(chunk_m, AIMessageChunk): 707 | chunk_m.usage_metadata = usage_metadata 708 | x_headers = chunk.get("x_headers") 709 | x_headers = x_headers if isinstance(x_headers, dict) else {} 710 | if isinstance(x_headers, dict) and "x-request-id" in x_headers: 711 | chunk_m.id = x_headers["x-request-id"] 712 | 713 | generation_info = {} 714 | if finish_reason := choice.get("finish_reason"): 715 | self._check_finish_reason(finish_reason) 716 | generation_info["model_name"] = chunk.get("model") 717 | generation_info["finish_reason"] = finish_reason 718 | if first_chunk: 719 | generation_info["x_headers"] = x_headers 720 | first_chunk = False 721 | if run_manager: 722 | await run_manager.on_llm_new_token(content) 723 | 724 | yield ChatGenerationChunk(message=chunk_m, generation_info=generation_info) 725 | 726 | def bind_functions( 727 | self, 728 | functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, type]], 729 | function_call: Optional[str] = None, 730 | **kwargs: Any, 731 | ) -> Runnable[LanguageModelInput, BaseMessage]: 732 | """Bind functions (and other objects) to this chat model. 733 | 734 | Args: 735 | functions: A list of function definitions to bind to this chat model. 736 | Can be a dictionary, pydantic model, or callable. Pydantic 737 | models and callables will be automatically converted to 738 | their schema dictionary representation. 739 | function_call: Which function to require the model to call. 
740 | Must be the name of the single provided function or 741 | "auto" to automatically determine which function to call 742 | (if any). 743 | kwargs: Any additional parameters to pass to the 744 | :class:`~langchain.runnable.Runnable` constructor. 745 | """ 746 | formatted_functions = [convert_to_gigachat_function(fn) for fn in functions] 747 | if function_call is not None: 748 | if len(formatted_functions) != 1: 749 | raise ValueError( 750 | "When specifying `function_call`, you must provide exactly one " 751 | "function." 752 | ) 753 | if formatted_functions[0]["name"] != function_call: 754 | raise ValueError( 755 | f"Function call {function_call} was specified, but the only " 756 | f"provided function was {formatted_functions[0]['name']}." 757 | ) 758 | function_call_ = {"name": function_call} 759 | kwargs = {**kwargs, "function_call": function_call_} 760 | return super().bind(functions=formatted_functions, **kwargs) 761 | 762 | # TODO: Fix typing. 763 | @overload # type: ignore[override] 764 | def with_structured_output( 765 | self, 766 | schema: Optional[_DictOrPydanticClass] = None, 767 | *, 768 | method: Literal[ 769 | "function_calling", "json_mode", "format_instructions" 770 | ] = "function_calling", 771 | include_raw: Literal[True] = True, 772 | **kwargs: Any, 773 | ) -> Runnable[LanguageModelInput, _AllReturnType]: ... 774 | 775 | @overload 776 | def with_structured_output( 777 | self, 778 | schema: Optional[_DictOrPydanticClass] = None, 779 | *, 780 | method: Literal[ 781 | "function_calling", "json_mode", "format_instructions" 782 | ] = "function_calling", 783 | include_raw: Literal[False] = False, 784 | **kwargs: Any, 785 | ) -> Runnable[LanguageModelInput, _DictOrPydantic]: ... 

    def with_structured_output(
        self,
        schema: Optional[_DictOrPydanticClass] = None,
        *,
        method: Literal[
            "function_calling", "json_mode", "format_instructions"
        ] = "function_calling",
        include_raw: bool = False,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, _DictOrPydantic]:
        """Model wrapper that returns outputs formatted to match the given schema.

        With ``method="function_calling"`` the schema is bound as a GigaChat
        tool and the tool call is parsed back (into the pydantic model when
        ``schema`` is a pydantic class, otherwise into a dict).  With
        ``"json_mode"`` / ``"format_instructions"`` the raw model output is
        parsed as JSON; ``"format_instructions"`` additionally appends the
        parser's format instructions to the prompt as an extra human message.
        When ``include_raw`` is True the chain returns ``{"raw", "parsed",
        "parsing_error"}`` with parse failures captured instead of raised.
        """
        if kwargs:
            raise ValueError(f"Received unsupported arguments {kwargs}")
        is_pydantic_schema = _is_pydantic_class(schema)
        if method == "function_calling":
            if schema is None:
                raise ValueError(
                    "schema must be specified when method is 'function_calling'. "
                    "Received None."
                )
            func = convert_to_gigachat_tool(schema)["function"]
            key_name = func.get(
                "name", func.get("title")
            )  # In case of pydantic from JSON (For openai capability)
            if is_pydantic_schema:
                output_parser: OutputParserLike = PydanticToolsParser(
                    tools=[schema],  # type: ignore
                    first_tool_only=True,
                )
            else:
                output_parser = JsonOutputKeyToolsParser(
                    key_name=key_name, first_tool_only=True
                )
            llm = self.bind_tools([schema], tool_choice=key_name)
        else:
            llm = self
            output_parser = (
                PydanticOutputParser(pydantic_object=schema)  # type: ignore[arg-type]
                if is_pydantic_schema
                else JsonOutputParser()
            )
        if method == "format_instructions":
            # Imported lazily; only this branch needs these names.
            from langchain_core.prompt_values import ChatPromptValue
            from langchain_core.runnables import RunnableLambda

            def add_format_instructions(
                _input: LanguageModelInput, format_instructions: str
            ) -> LanguageModelInput:
                # Append the parser's format instructions to whichever
                # prompt shape the caller supplied (prompt value, string,
                # or message sequence).
                if isinstance(_input, ChatPromptValue):
                    messages = _input.messages
                    return type(messages)(
                        list(messages) + [HumanMessage(format_instructions)]  # type: ignore[call-arg]
                    )
                elif isinstance(_input, str):
                    return _input + f"\n\n{format_instructions}"
                elif isinstance(_input, Sequence):
                    return type(_input)(
                        list(_input) + [HumanMessage(format_instructions)]  # type: ignore[call-arg]
                    )
                else:
                    msg = (
                        f"Invalid input type {type(_input)}. "
                        "Must be a PromptValue, str, or list of BaseMessages."
                    )
                    raise ValueError(msg)  # noqa: TRY004

            add_format_instructions_chain = RunnableLambda(
                lambda _input: add_format_instructions(
                    _input, output_parser.get_format_instructions()
                )
            )
            llm = add_format_instructions_chain | llm

        if include_raw:
            # Parse failures populate "parsing_error" instead of raising.
            parser_assign = RunnablePassthrough.assign(
                parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
            )
            parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
            parser_with_fallback = parser_assign.with_fallbacks(
                [parser_none], exception_key="parsing_error"
            )
            return RunnableMap(raw=llm) | parser_with_fallback
        else:
            return llm | output_parser

    def bind_tools(
        self,
        tools: Sequence[
            Union[Dict[str, Any], Type, Type[BaseModel], Callable, BaseTool]
        ],  # noqa
        *,
        tool_choice: Optional[
            Union[dict, str, Literal["auto", "any", "none"], bool]
        ] = None,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        """Bind tool-like objects to this chat model.
        Assumes model is compatible with GigaChat tool-calling API.

        ``tool_choice`` semantics: a string other than "auto"/"none" names a
        specific function; ``True`` forces the single provided tool; a dict
        is passed through as-is.  The normalized value is forwarded to the
        API as ``function_call``.
        """
        formatted_tools = [convert_to_gigachat_tool(tool) for tool in tools]
        if tool_choice is not None and tool_choice:
            if isinstance(tool_choice, str):
                # "auto"/"none" are passed verbatim; any other string is
                # treated as a function name.
                if tool_choice not in ("auto", "none"):
                    tool_choice = {"name": tool_choice}
            elif isinstance(tool_choice, bool) and tool_choice:
                if not formatted_tools:
                    raise ValueError("tool_choice can not be bool if tools are empty")
                tool_choice = {"name": formatted_tools[0]["name"]}
            elif isinstance(tool_choice, dict):
                pass
            else:
                raise ValueError(
                    f"Unrecognized tool_choice type. Expected str, bool or dict. "
                    f"Received: {tool_choice}"
                )
            kwargs["function_call"] = tool_choice
        return super().bind(tools=formatted_tools, **kwargs)


def _is_pydantic_class(obj: Any) -> bool:
    """Return True when *obj* is a class deriving from a pydantic BaseModel."""
    return isinstance(obj, type) and is_basemodel_subclass(obj)
--------------------------------------------------------------------------------