├── aidial_sdk ├── utils │ ├── __init__.py │ ├── json.py │ ├── errors.py │ ├── env.py │ ├── logging.py │ ├── _reflection.py │ ├── log_config.py │ ├── _content_stream.py │ ├── _attachment.py │ ├── _indexed_list.py │ ├── _cancel_scope.py │ ├── pydantic.py │ └── streaming.py ├── deployment │ ├── __init__.py │ ├── rate.py │ ├── configuration.py │ ├── truncate_prompt.py │ └── tokenize.py ├── telemetry │ ├── __init__.py │ ├── types.py │ └── init.py ├── __init__.py ├── embeddings │ ├── __init__.py │ ├── base.py │ ├── response.py │ └── request.py ├── chat_completion │ ├── enums.py │ ├── _types.py │ ├── choice_base.py │ ├── base.py │ ├── __init__.py │ ├── function_call.py │ ├── function_tool_call.py │ └── stage.py ├── pydantic_v1 │ └── __init__.py ├── _errors.py ├── _pydantic │ ├── __init__.py │ ├── _compat.py │ └── _model_config.py ├── pydantic │ └── v2.py └── header_propagator.py ├── examples ├── image_size │ ├── __init__.py │ ├── requirements.txt │ ├── Dockerfile │ ├── app │ │ ├── image.py │ │ └── main.py │ └── README.md ├── render_text │ ├── __init__.py │ ├── requirements.txt │ ├── Dockerfile │ ├── README.md │ └── app │ │ ├── image.py │ │ └── main.py ├── tic_tac_toe │ ├── __init__.py │ ├── requirements.txt │ ├── README.md │ ├── Dockerfile │ └── app │ │ ├── request.py │ │ └── game.py ├── echo │ ├── requirements.txt │ ├── Dockerfile │ ├── README.md │ └── app.py └── langchain_rag │ ├── requirements.txt │ ├── .env.example │ ├── Dockerfile │ ├── utils.py │ └── README.md ├── poetry.toml ├── .github ├── CODEOWNERS ├── ISSUE_TEMPLATE │ ├── config.yml │ ├── 02_feature_request.yml │ └── 01_bug_report.yml ├── workflows │ ├── pr-title-check.yml │ ├── pr.yml │ ├── release.yml │ ├── slash-command-dispatch.yml │ └── dependabot-automation.yml ├── pull_request_template.md └── dependabot.yml ├── tests ├── __init__.py ├── utils │ ├── __init__.py │ ├── text.py │ ├── constants.py │ ├── chat_completion_validation.py │ ├── client.py │ ├── uvicorn.py │ ├── json.py │ ├── sharing.py │ ├── 
errors.py │ ├── endpoint_test.py │ ├── pydantic.py │ ├── chunks.py │ └── tokenization.py ├── applications │ ├── noop.py │ ├── single_choice.py │ ├── simple_embeddings.py │ ├── validator.py │ ├── idle.py │ ├── custom_endpoints.py │ ├── echo.py │ └── broken.py ├── conftest.py ├── test_max_prompt_tokens.py ├── test_rate_response.py ├── examples │ ├── test_image_size.py │ ├── test_echo.py │ └── test_render_text.py ├── test_exception.py ├── test_pydantic.py ├── test_custom_endpoints.py ├── test_is_implemented.py ├── test_embeddings.py ├── test_single_choice.py ├── header_propagation │ └── client.py ├── test_function_calling.py ├── test_serialization.py ├── test_chat_completion_validation.py ├── test_tokenize.py ├── test_discarded_messages.py ├── test_request_indices.py ├── test_response_headers.py ├── test_tool_calling.py ├── test_bearer_token.py ├── test_disconnect.py ├── test_cancellation.py ├── test_truncate_prompt.py ├── test_extra_fields.py ├── test_request_tools_parsing.py ├── test_header_propagation.py └── benchmark │ └── benchmark_merge_chunks.py ├── .vscode ├── extensions.json └── settings.json ├── .gitignore ├── CONTRIBUTING.md ├── .flake8 ├── trivy.yaml ├── SECURITY.md ├── .ort.yml ├── Makefile ├── noxfile.py └── pyproject.toml /aidial_sdk/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /aidial_sdk/deployment/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /aidial_sdk/telemetry/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/image_size/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /examples/render_text/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/tic_tac_toe/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /poetry.toml: -------------------------------------------------------------------------------- 1 | [virtualenvs] 2 | in-project = true -------------------------------------------------------------------------------- /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @adubovik 2 | /.github/ @nepalevov @alexey-ban -------------------------------------------------------------------------------- /examples/echo/requirements.txt: -------------------------------------------------------------------------------- 1 | aidial-sdk>=0.10 2 | uvicorn==0.30.1 -------------------------------------------------------------------------------- /examples/tic_tac_toe/requirements.txt: -------------------------------------------------------------------------------- 1 | aidial-sdk>=0.19 2 | uvicorn==0.30.1 3 | pydantic>1,<3 -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | pytest.register_assert_rewrite("tests.utils.chunks") 4 | -------------------------------------------------------------------------------- /tests/utils/__init__.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | pytest.register_assert_rewrite("tests.utils.json") 4 | -------------------------------------------------------------------------------- 
/examples/image_size/requirements.txt: -------------------------------------------------------------------------------- 1 | aidial-sdk>=0.10 2 | pillow==10.3.0 3 | aiohttp==3.12.14 4 | uvicorn==0.30.1 -------------------------------------------------------------------------------- /examples/render_text/requirements.txt: -------------------------------------------------------------------------------- 1 | aidial-sdk>=0.10 2 | pillow==10.3.0 3 | aiohttp==3.12.14 4 | uvicorn==0.30.1 -------------------------------------------------------------------------------- /aidial_sdk/utils/json.py: -------------------------------------------------------------------------------- 1 | def remove_nones(d: dict) -> dict: 2 | return {k: v for k, v in d.items() if v is not None} 3 | -------------------------------------------------------------------------------- /.vscode/extensions.json: -------------------------------------------------------------------------------- 1 | { 2 | "recommendations": [ 3 | "ms-python.python", 4 | "ms-python.black-formatter", 5 | "ms-python.isort" 6 | ] 7 | } -------------------------------------------------------------------------------- /aidial_sdk/__init__.py: -------------------------------------------------------------------------------- 1 | from aidial_sdk.application import DIALApp 2 | from aidial_sdk.exceptions import HTTPException 3 | from aidial_sdk.utils.logging import logger 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .venv 3 | __pycache__ 4 | build/ 5 | *.egg-info/ 6 | dist/ 7 | .pytest_cache 8 | .nox 9 | .idea/ 10 | .env 11 | ~* 12 | .vscode/launch.json 13 | .python-version -------------------------------------------------------------------------------- /examples/tic_tac_toe/README.md: -------------------------------------------------------------------------------- 1 | # Overview 2 | 3 | An 
example of an application configurable by user. 4 | 5 | Upon start the Docker image exposes `openai/deployments/app/(chat/completions|configure)` endpoints at port `5000`. -------------------------------------------------------------------------------- /tests/utils/text.py: -------------------------------------------------------------------------------- 1 | # str.removeprefix method was introduced in Python 3.9 2 | def removeprefix(text: str, prefix: str) -> str: 3 | if text.startswith(prefix): 4 | return text[len(prefix) :] 5 | return text 6 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false 2 | contact_links: 3 | - name: 📝 Talk in Discord 4 | url: https://discord.com/channels/831132799535939614/1172686875299950683 5 | about: Discuss your question in Discord 6 | -------------------------------------------------------------------------------- /aidial_sdk/deployment/rate.py: -------------------------------------------------------------------------------- 1 | from aidial_sdk._pydantic import Field, StrictStr 2 | from aidial_sdk.deployment.from_request_mixin import FromRequestDeploymentMixin 3 | 4 | 5 | class RateRequest(FromRequestDeploymentMixin): 6 | response_id: StrictStr = Field(None, alias="responseId") 7 | rate: bool = False 8 | -------------------------------------------------------------------------------- /aidial_sdk/embeddings/__init__.py: -------------------------------------------------------------------------------- 1 | from aidial_sdk.embeddings.base import Embeddings 2 | from aidial_sdk.embeddings.request import ( 3 | Attachment, 4 | EmbeddingsMultiModalInput, 5 | EmbeddingsRequestCustomFields, 6 | Request, 7 | ) 8 | from aidial_sdk.embeddings.response import Embedding, Response, Usage 9 | -------------------------------------------------------------------------------- 
/examples/langchain_rag/requirements.txt: -------------------------------------------------------------------------------- 1 | aidial-sdk>=0.10 2 | langchain==0.3.26 3 | langchain-community==0.3.27 4 | langchain-openai==0.2.6 5 | langchain-text-splitters==0.3.9 6 | tiktoken==0.7.0 7 | openai==1.54.0 8 | httpx==0.27.2 9 | beautifulsoup4==4.12.3 10 | chromadb==0.5.4 11 | uvicorn==0.30.1 12 | pypdf==6.4.0 -------------------------------------------------------------------------------- /examples/langchain_rag/.env.example: -------------------------------------------------------------------------------- 1 | DIAL_URL=DIAL_URL 2 | 3 | CHAT_MODEL=gpt-4 4 | EMBEDDINGS_MODEL=text-embedding-ada-002 5 | API_VERSION=2024-02-01 6 | 7 | # Enabling debug logs for DIAL SDK 8 | DIAL_SDK_LOG=debug 9 | 10 | # Enabling debug logs for OpenAI and Langchain 11 | OPENAI_LOG=debug 12 | LANGCHAIN_DEBUG=true 13 | -------------------------------------------------------------------------------- /aidial_sdk/utils/errors.py: -------------------------------------------------------------------------------- 1 | from aidial_sdk.exceptions import RuntimeServerError 2 | from aidial_sdk.utils.logging import log_error 3 | 4 | RUNTIME_ERROR_MESSAGE = "Error during processing the request" 5 | 6 | 7 | def runtime_error(reason: str): 8 | log_error(reason) 9 | return RuntimeServerError(RUNTIME_ERROR_MESSAGE) 10 | -------------------------------------------------------------------------------- /tests/applications/noop.py: -------------------------------------------------------------------------------- 1 | from aidial_sdk.chat_completion import ChatCompletion, Request, Response 2 | 3 | 4 | class NoopApplication(ChatCompletion): 5 | async def chat_completion( 6 | self, request: Request, response: Response 7 | ) -> None: 8 | with response.create_single_choice(): 9 | pass 10 | -------------------------------------------------------------------------------- /aidial_sdk/embeddings/base.py: 
-------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | from aidial_sdk.embeddings.request import Request 4 | from aidial_sdk.embeddings.response import Response 5 | 6 | 7 | class Embeddings(ABC): 8 | @abstractmethod 9 | async def embeddings(self, request: Request) -> Response: 10 | """Implement embeddings logic""" 11 | -------------------------------------------------------------------------------- /aidial_sdk/chat_completion/enums.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class FinishReason(str, Enum): 5 | STOP = "stop" 6 | LENGTH = "length" 7 | FUNCTION_CALL = "function_call" 8 | TOOL_CALLS = "tool_calls" 9 | CONTENT_FILTER = "content_filter" 10 | 11 | 12 | class Status(str, Enum): 13 | COMPLETED = "completed" 14 | FAILED = "failed" 15 | -------------------------------------------------------------------------------- /aidial_sdk/utils/env.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List 3 | 4 | 5 | def env_bool(name: str, default: bool) -> bool: 6 | return os.getenv(name, str(default)).lower() in ["true", "1"] 7 | 8 | 9 | def env_var_list(name: str) -> List[str]: 10 | value = os.getenv(name) 11 | if value is None: 12 | return [] 13 | return value.split(",") 14 | -------------------------------------------------------------------------------- /.github/workflows/pr-title-check.yml: -------------------------------------------------------------------------------- 1 | name: "Validate PR title" 2 | 3 | on: 4 | pull_request_target: 5 | types: 6 | - opened 7 | - edited 8 | - synchronize 9 | 10 | jobs: 11 | pr-title-check: 12 | uses: epam/ai-dial-ci/.github/workflows/pr-title-check.yml@2.8.1 13 | secrets: 14 | ACTIONS_BOT_TOKEN: ${{ secrets.ACTIONS_BOT_TOKEN }} 15 | 
-------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to contribute 2 | 3 | As an open-source project in a rapidly developing field, we are open to contributions, whether it be in the form of a new feature, improved infrastructure, or better documentation. 4 | 5 | For detailed information on how to contribute, see the full [contributing documentation](https://github.com/epam/ai-dial/blob/main/CONTRIBUTING.md). 6 | -------------------------------------------------------------------------------- /aidial_sdk/chat_completion/_types.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from typing import TYPE_CHECKING, Union 3 | 4 | from aidial_sdk.chat_completion.chunks import ( 5 | BaseChunk, 6 | EndChunk, 7 | ExceptionChunk, 8 | ) 9 | 10 | if TYPE_CHECKING: 11 | ChunkQueue = asyncio.Queue[Union[BaseChunk, ExceptionChunk, EndChunk]] 12 | else: 13 | ChunkQueue = asyncio.Queue 14 | -------------------------------------------------------------------------------- /.github/workflows/pr.yml: -------------------------------------------------------------------------------- 1 | name: PR Workflow 2 | 3 | on: 4 | pull_request: 5 | branches: [development, release-*] 6 | 7 | jobs: 8 | run_tests: 9 | uses: epam/ai-dial-ci/.github/workflows/python_package_pr.yml@2.8.1 10 | secrets: inherit 11 | with: 12 | python-version: 3.9 13 | code-checks-python-versions: '["3.9", "3.10", "3.11", "3.12", "3.13"]' 14 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release Workflow 2 | 3 | on: 4 | push: 5 | branches: [development, release-*] 6 | 7 | jobs: 8 | release: 9 | uses: epam/ai-dial-ci/.github/workflows/python_package_release.yml@2.8.1 10 | secrets: 
inherit 11 | with: 12 | python-version: 3.9 13 | code-checks-python-versions: '["3.9", "3.10", "3.11", "3.12", "3.13"]' 14 | -------------------------------------------------------------------------------- /tests/utils/constants.py: -------------------------------------------------------------------------------- 1 | import fastapi 2 | 3 | from aidial_sdk._pydantic import SecretStr 4 | from aidial_sdk.chat_completion import Request 5 | 6 | _DUMMY_FASTAPI_REQUEST = fastapi.Request({"type": "http"}) 7 | 8 | DUMMY_DIAL_REQUEST = Request( 9 | headers={}, 10 | original_request=_DUMMY_FASTAPI_REQUEST, 11 | api_key_secret=SecretStr("dummy_key"), 12 | deployment_id="", 13 | messages=[], 14 | ) 15 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | # E501 string literal is too long 3 | # W503 line break before binary operator 4 | # E203 whitespace before ':' (triggered on list slices like xs[i : i + 5]) 5 | # E704 multiple statements on one line (def) 6 | ignore = E501, W503, E203, E704 7 | exclude = 8 | .git, 9 | .tmp, 10 | .venv, 11 | .conda, 12 | .nox, 13 | .pytest_cache 14 | __pycache__, 15 | _pydantic.py, 16 | __init__.py 17 | -------------------------------------------------------------------------------- /examples/echo/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.11-alpine 2 | 3 | ENV PYTHONDONTWRITEBYTECODE=1 4 | ENV PYTHONUNBUFFERED=1 5 | 6 | COPY requirements.txt . 7 | RUN python -m pip install -r requirements.txt 8 | 9 | WORKDIR /app 10 | COPY . 
/app 11 | 12 | RUN adduser -u 1001 --disabled-password --gecos "" appuser && chown -R appuser /app 13 | USER appuser 14 | 15 | EXPOSE 5000 16 | CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "5000"] 17 | -------------------------------------------------------------------------------- /examples/image_size/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.11-alpine 2 | 3 | ENV PYTHONDONTWRITEBYTECODE=1 4 | ENV PYTHONUNBUFFERED=1 5 | 6 | COPY requirements.txt . 7 | RUN python -m pip install -r requirements.txt 8 | 9 | WORKDIR /app 10 | COPY . /app 11 | 12 | RUN adduser -u 1001 --disabled-password --gecos "" appuser && chown -R appuser /app 13 | USER appuser 14 | 15 | EXPOSE 5000 16 | CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "5000"] 17 | -------------------------------------------------------------------------------- /examples/langchain_rag/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.11-slim 2 | 3 | ENV PYTHONDONTWRITEBYTECODE=1 4 | ENV PYTHONUNBUFFERED=1 5 | 6 | COPY requirements.txt . 7 | RUN python -m pip install -r requirements.txt 8 | 9 | WORKDIR /app 10 | COPY . 
/app 11 | 12 | RUN adduser -u 1001 --disabled-password --gecos "" appuser && chown -R appuser /app 13 | USER appuser 14 | 15 | EXPOSE 5000 16 | CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "5000"] 17 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "files.trimTrailingWhitespace": true, 3 | "[python]": { 4 | "editor.defaultFormatter": "ms-python.black-formatter", 5 | "editor.formatOnSave": true, 6 | "editor.codeActionsOnSave": { 7 | "source.organizeImports": "explicit" 8 | }, 9 | "editor.tabSize": 4 10 | }, 11 | "python.testing.pytestArgs": ["."], 12 | "python.testing.unittestEnabled": false, 13 | "python.testing.pytestEnabled": true, 14 | } -------------------------------------------------------------------------------- /examples/render_text/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.11-alpine 2 | 3 | ENV PYTHONDONTWRITEBYTECODE=1 4 | ENV PYTHONUNBUFFERED=1 5 | 6 | COPY requirements.txt . 7 | RUN python -m pip install -r requirements.txt 8 | 9 | WORKDIR /app 10 | COPY . 
/app 11 | 12 | RUN adduser -u 1001 --disabled-password --gecos "" appuser && chown -R appuser /app 13 | USER appuser 14 | 15 | EXPOSE 5000 16 | CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "5000"] 17 | -------------------------------------------------------------------------------- /aidial_sdk/deployment/configuration.py: -------------------------------------------------------------------------------- 1 | import fastapi 2 | 3 | from aidial_sdk.deployment.from_request_mixin import FromRequestDeploymentMixin 4 | from aidial_sdk.utils.pydantic import ExtraAllowModel 5 | 6 | 7 | class ConfigurationRequest(FromRequestDeploymentMixin): 8 | @staticmethod 9 | async def get_request_body(request: fastapi.Request) -> dict: 10 | return {} 11 | 12 | 13 | class ConfigurationResponse(ExtraAllowModel): 14 | pass 15 | -------------------------------------------------------------------------------- /tests/utils/chat_completion_validation.py: -------------------------------------------------------------------------------- 1 | from tests.applications.validator import RequestValidator, ValidatorApplication 2 | from tests.utils.client import create_app_client 3 | 4 | 5 | def validate_chat_completion( 6 | request: dict, request_validator: RequestValidator 7 | ) -> None: 8 | client = create_app_client( 9 | ValidatorApplication(request_validator=request_validator) 10 | ) 11 | 12 | client.post("chat/completions", json=request) 13 | -------------------------------------------------------------------------------- /examples/tic_tac_toe/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.11-alpine 2 | 3 | RUN apk add --no-cache git 4 | 5 | ENV PYTHONDONTWRITEBYTECODE=1 6 | ENV PYTHONUNBUFFERED=1 7 | 8 | COPY requirements.txt . 9 | RUN python -m pip install -r requirements.txt 10 | 11 | WORKDIR /app 12 | COPY . 
/app 13 | 14 | RUN adduser -u 1001 --disabled-password --gecos "" appuser && chown -R appuser /app 15 | USER appuser 16 | 17 | EXPOSE 5000 18 | CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "5000"] 19 | -------------------------------------------------------------------------------- /tests/applications/single_choice.py: -------------------------------------------------------------------------------- 1 | from aidial_sdk.chat_completion import ChatCompletion, Request, Response 2 | 3 | 4 | class SingleChoiceApplication(ChatCompletion): 5 | async def chat_completion( 6 | self, request: Request, response: Response 7 | ) -> None: 8 | response.set_response_id("test_id") 9 | response.set_created(0) 10 | 11 | with response.create_single_choice() as choice: 12 | choice.append_content("Test response content") 13 | -------------------------------------------------------------------------------- /examples/echo/README.md: -------------------------------------------------------------------------------- 1 | # Overview 2 | 3 | An example of a simple text-to-text DIAL application. 4 | 5 | It returns back the content and attachments from the last user message. 6 | 7 | Upon start the Docker image exposes `openai/deployments/echo/chat/completions` endpoint at port `5000`. 8 | 9 | ## Usage 10 | 11 | Find how to integrate the application into the DIAL Core and call it using DIAL API in the [cookbook](https://github.com/epam/ai-dial/blob/main/dial-cookbook/examples/how_to_call_text_to_text_applications.ipynb). 
-------------------------------------------------------------------------------- /trivy.yaml: -------------------------------------------------------------------------------- 1 | # Trivy configuration file 2 | # https://aquasecurity.github.io/trivy/latest/docs/references/configuration/config-file/ 3 | # Can be deleted after public ecr mirror will be added by default 4 | db: 5 | no-progress: true 6 | repository: 7 | - ghcr.io/aquasecurity/trivy-db:2 8 | - public.ecr.aws/aquasecurity/trivy-db:2 9 | java-repository: 10 | - ghcr.io/aquasecurity/trivy-java-db:1 11 | - public.ecr.aws/aquasecurity/trivy-java-db:1 12 | misconfiguration: 13 | checks-bundle-repository: public.ecr.aws/aquasecurity/trivy-checks -------------------------------------------------------------------------------- /aidial_sdk/embeddings/response.py: -------------------------------------------------------------------------------- 1 | from typing import List, Literal, Union 2 | 3 | from aidial_sdk.utils.pydantic import ExtraAllowModel 4 | 5 | 6 | class Embedding(ExtraAllowModel): 7 | embedding: Union[str, List[float]] 8 | index: int 9 | object: Literal["embedding"] = "embedding" 10 | 11 | 12 | class Usage(ExtraAllowModel): 13 | prompt_tokens: int 14 | total_tokens: int 15 | 16 | 17 | class EmbeddingResponse(ExtraAllowModel): 18 | data: List[Embedding] 19 | model: str 20 | object: Literal["list"] = "list" 21 | usage: Usage 22 | 23 | 24 | Response = EmbeddingResponse 25 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import httpx 2 | import pytest 3 | 4 | from aidial_sdk import DIALApp 5 | from aidial_sdk.chat_completion.base import ChatCompletion 6 | from tests.utils.uvicorn import run_uvicorn_in_thread 7 | 8 | 9 | @pytest.fixture 10 | async def test_http_client(chat_completion: ChatCompletion): 11 | app = DIALApp().add_chat_completion("app", chat_completion) 12 | 
13 | async with run_uvicorn_in_thread(app) as base_url: 14 | async with httpx.AsyncClient( 15 | base_url=f"{base_url}/openai/deployments/app", 16 | headers={"api-key": "test-api-key"}, 17 | ) as client: 18 | yield client 19 | -------------------------------------------------------------------------------- /tests/applications/simple_embeddings.py: -------------------------------------------------------------------------------- 1 | from aidial_sdk.embeddings.base import Embeddings 2 | from aidial_sdk.embeddings.request import Request 3 | from aidial_sdk.embeddings.response import Embedding, Response, Usage 4 | 5 | 6 | class SimpleEmbeddings(Embeddings): 7 | async def embeddings(self, request: Request) -> Response: 8 | n = 1 9 | if isinstance(request.input, list): 10 | n = len(request.input) 11 | return Response( 12 | data=[Embedding(embedding=[float(i)], index=i) for i in range(n)], 13 | model="dummy", 14 | usage=Usage(prompt_tokens=n, total_tokens=n), 15 | ) 16 | -------------------------------------------------------------------------------- /aidial_sdk/chat_completion/choice_base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | from aidial_sdk.chat_completion.chunks import BaseChunk 4 | 5 | 6 | class ChoiceBase(ABC): 7 | @property 8 | @abstractmethod 9 | def index(self) -> int: 10 | pass 11 | 12 | @property 13 | @abstractmethod 14 | def opened(self) -> bool: 15 | pass 16 | 17 | @property 18 | @abstractmethod 19 | def closed(self) -> bool: 20 | pass 21 | 22 | @property 23 | @abstractmethod 24 | def has_function_call(self) -> bool: 25 | pass 26 | 27 | @abstractmethod 28 | def send_chunk(self, chunk: BaseChunk) -> None: 29 | pass 30 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | ### Applicable issues 2 | 3 | 4 | - fixes # 5 | 6 | 
### Description of changes 7 | 8 | 9 | 10 | ### Checklist 11 | 12 | 13 | 14 | - [ ] Title of the pull request follows [Conventional Commits specification](https://www.conventionalcommits.org/en/v1.0.0/) 15 | 16 | By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license. 17 | -------------------------------------------------------------------------------- /examples/tic_tac_toe/app/request.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from aidial_sdk.chat_completion import Message, Request 4 | 5 | 6 | def get_message_form_value(message: Message) -> Optional[dict]: 7 | cc = message.custom_content 8 | if cc is None: 9 | return None 10 | return cc.form_value 11 | 12 | 13 | def get_message_state(message: Message) -> Optional[dict]: 14 | cc = message.custom_content 15 | if cc is None: 16 | return None 17 | return cc.state 18 | 19 | 20 | def get_configuration(request: Request) -> dict: 21 | cf = request.custom_fields 22 | assert cf is not None and cf.configuration is not None 23 | return cf.configuration 24 | -------------------------------------------------------------------------------- /tests/test_max_prompt_tokens.py: -------------------------------------------------------------------------------- 1 | from tests.utils.chat_completion_validation import validate_chat_completion 2 | 3 | 4 | def test_max_prompt_tokens_is_set(): 5 | validate_chat_completion( 6 | request={ 7 | "messages": [{"role": "user", "content": "Test content"}], 8 | "max_prompt_tokens": 15, 9 | }, 10 | request_validator=lambda r: r.max_prompt_tokens == 15, 11 | ) 12 | 13 | 14 | def test_max_prompt_tokens_is_unset(): 15 | validate_chat_completion( 16 | request={ 17 | "messages": [{"role": "user", "content": "Test content"}], 18 | }, 19 | request_validator=lambda r: not r.max_prompt_tokens, 20 | ) 21 | 
-------------------------------------------------------------------------------- /tests/applications/validator.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Optional 2 | 3 | from aidial_sdk.chat_completion import ChatCompletion, Request, Response 4 | 5 | # It can be either function that raises AssertionError, or lambda that returns 6 | # boolean, that will be used for assertion 7 | RequestValidator = Callable[[Request], Optional[bool]] 8 | 9 | 10 | class ValidatorApplication(ChatCompletion): 11 | def __init__(self, request_validator: RequestValidator): 12 | self.request_validator = request_validator 13 | 14 | async def chat_completion( 15 | self, request: Request, response: Response 16 | ) -> None: 17 | result = self.request_validator(request) 18 | if result is not None: 19 | assert result 20 | with response.create_single_choice(): 21 | pass 22 | -------------------------------------------------------------------------------- /aidial_sdk/deployment/truncate_prompt.py: -------------------------------------------------------------------------------- 1 | from typing import List, Literal, Union 2 | 3 | from aidial_sdk._pydantic._compat import BaseModel 4 | from aidial_sdk.chat_completion.request import ChatCompletionRequest 5 | from aidial_sdk.deployment.from_request_mixin import FromRequestDeploymentMixin 6 | 7 | 8 | class TruncatePromptRequest(FromRequestDeploymentMixin): 9 | inputs: List[ChatCompletionRequest] 10 | 11 | 12 | class TruncatePromptSuccess(BaseModel): 13 | status: Literal["success"] = "success" 14 | discarded_messages: List[int] 15 | 16 | 17 | class TruncatePromptError(BaseModel): 18 | status: Literal["error"] = "error" 19 | error: str 20 | 21 | 22 | TruncatePromptResult = Union[TruncatePromptSuccess, TruncatePromptError] 23 | 24 | 25 | class TruncatePromptResponse(BaseModel): 26 | outputs: List[TruncatePromptResult] 27 | 
-------------------------------------------------------------------------------- /examples/image_size/app/image.py: -------------------------------------------------------------------------------- 1 | import base64 2 | from io import BytesIO 3 | from typing import Tuple 4 | 5 | import aiohttp 6 | from PIL import Image 7 | 8 | 9 | def get_image_base64_size(image_base64: str) -> Tuple[int, int]: 10 | image_binary = base64.b64decode(image_base64) 11 | img = Image.open(BytesIO(image_binary)) 12 | return img.size 13 | 14 | 15 | def bytes_to_base64(data: bytes) -> str: 16 | return base64.b64encode(data).decode() 17 | 18 | 19 | async def download_image_as_bytes(url: str) -> bytes: 20 | async with aiohttp.ClientSession() as session: 21 | async with session.get(url) as response: 22 | response.raise_for_status() 23 | return await response.content.read() 24 | 25 | 26 | async def download_image_as_base64(url: str) -> str: 27 | return bytes_to_base64(await download_image_as_bytes(url)) 28 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | # Reporting Security Issues 2 | 3 | We take all security reports seriously. We appreciate your efforts to responsibly disclose your findings and will make every effort to acknowledge your contributions. 4 | 5 | ⚠️ Please do *not* file GitHub issues for security vulnerabilities as they are public! ⚠️ 6 | 7 | To report a security issue, please use the GitHub Security Advisory ["Report a Vulnerability"](https://github.com/epam/ai-dial-sdk/security/advisories/new) tab. Tip: In this form, only the title and description are mandatory. 8 | 9 | We will send a response indicating the next steps in handling your report. After the initial reply to your report, we will keep you informed of the progress toward a fix and full announcement and may ask for additional information or guidance. 
10 | 11 | When we receive such reports, we will investigate and subsequently address any potential vulnerabilities as quickly as possible. 12 | -------------------------------------------------------------------------------- /.github/workflows/slash-command-dispatch.yml: -------------------------------------------------------------------------------- 1 | name: Slash Command Dispatch 2 | on: 3 | issue_comment: 4 | types: [created] 5 | jobs: 6 | slashCommandDispatch: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - name: Slash Command Dispatch 10 | id: scd 11 | uses: peter-evans/slash-command-dispatch@5c11dc7efead556e3bdabf664302212f79eb26fa # v5.0.1 12 | with: 13 | token: ${{ secrets.ACTIONS_BOT_TOKEN }} 14 | reaction-token: ${{ secrets.ACTIONS_BOT_TOKEN }} 15 | config: > 16 | [ 17 | { 18 | "command": "deploy-review", 19 | "permission": "write", 20 | "issue_type": "pull-request", 21 | "repository": "epam/ai-dial-ci", 22 | "static_args": [ 23 | "application=${{ github.event.repository.name }}" 24 | ] 25 | } 26 | ] 27 | -------------------------------------------------------------------------------- /tests/utils/client.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import httpx 4 | from fastapi import FastAPI 5 | from starlette.testclient import TestClient 6 | 7 | from aidial_sdk import DIALApp 8 | from aidial_sdk.chat_completion.base import ChatCompletion 9 | 10 | 11 | def create_app_client( 12 | chat_completion: ChatCompletion, 13 | *, 14 | name: str = "test-deployment-name", 15 | headers: Dict[str, str] = {"api-key": "TEST_API_KEY"}, 16 | ) -> httpx.Client: 17 | app = DIALApp().add_chat_completion(name, chat_completion) 18 | return create_test_client(app, name=name, headers=headers) 19 | 20 | 21 | def create_test_client( 22 | app: FastAPI, 23 | *, 24 | name: str = "test-deployment-name", 25 | headers: Dict[str, str] = {"api-key": "TEST_API_KEY"}, 26 | ) -> httpx.Client: 27 | return TestClient( 28 
| app=app, 29 | headers=headers, 30 | base_url=f"http://testserver/openai/deployments/{name}", 31 | ) 32 | -------------------------------------------------------------------------------- /tests/utils/uvicorn.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import socket 3 | import threading 4 | from contextlib import asynccontextmanager 5 | 6 | import uvicorn 7 | 8 | 9 | def get_free_port(): 10 | with socket.socket() as s: 11 | s.bind(("127.0.0.1", 0)) 12 | return s.getsockname()[1] 13 | 14 | 15 | @asynccontextmanager 16 | async def run_uvicorn_in_thread(app, host="127.0.0.1", port=None): 17 | port = port or get_free_port() 18 | config = uvicorn.Config( 19 | app, host=host, port=port, log_level="warning", loop="asyncio" 20 | ) 21 | server = uvicorn.Server(config) 22 | 23 | thread = threading.Thread(target=server.run, daemon=True) 24 | thread.start() 25 | 26 | # Wait for server to actually start 27 | while not server.started: 28 | await asyncio.sleep(0.1) 29 | 30 | try: 31 | yield f"http://{host}:{port}" 32 | finally: 33 | server.should_exit = True 34 | thread.join(timeout=3) 35 | -------------------------------------------------------------------------------- /tests/test_rate_response.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import pytest 4 | 5 | from aidial_sdk import DIALApp 6 | from tests.applications.noop import NoopApplication 7 | from tests.utils.endpoint_test import TestCase, run_endpoint_test 8 | from tests.utils.errors import extra_fields_error 9 | 10 | RATE_REQUEST_OK1 = {} 11 | RATE_REQUEST_OK2 = {"responseId": "123", "rate": True} 12 | RATE_REQUEST_FAIL = {"foo": "bar"} 13 | 14 | 15 | deployment = "test-app" 16 | app = DIALApp().add_chat_completion(deployment, NoopApplication()) 17 | 18 | 19 | testcases: List[TestCase] = [ 20 | TestCase(app, deployment, "rate", RATE_REQUEST_OK2, None), 21 | TestCase(app, deployment, 
"rate", RATE_REQUEST_OK1, None), 22 | TestCase( 23 | app, deployment, "rate", RATE_REQUEST_FAIL, extra_fields_error("foo") 24 | ), 25 | ] 26 | 27 | 28 | @pytest.mark.parametrize("testcase", testcases) 29 | def test_rate_endpoint(testcase: TestCase): 30 | run_endpoint_test(testcase) 31 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "pip" 4 | directory: "/" 5 | schedule: 6 | interval: "weekly" 7 | day: "wednesday" 8 | time: "09:00" 9 | # Disable version updates, keep security updates only 10 | open-pull-requests-limit: 0 11 | commit-message: 12 | # Prefix all commit messages with "chore: " 13 | prefix: "chore" 14 | - package-ecosystem: "github-actions" 15 | directory: "/" 16 | schedule: 17 | interval: "weekly" 18 | day: "wednesday" 19 | time: "09:00" 20 | commit-message: 21 | # Prefix all commit messages with "chore: " 22 | prefix: "chore" 23 | groups: 24 | ai-dial-ci: 25 | applies-to: version-updates 26 | patterns: 27 | - "epam/ai-dial-ci/*" 28 | github-actions: 29 | applies-to: version-updates 30 | patterns: 31 | - "*" 32 | exclude-patterns: 33 | - "epam/ai-dial-ci/*" 34 | open-pull-requests-limit: 10 35 | -------------------------------------------------------------------------------- /tests/examples/test_image_size.py: -------------------------------------------------------------------------------- 1 | from examples.image_size.app.main import app 2 | from tests.utils.client import create_test_client 3 | 4 | 5 | def test_app(): 6 | client = create_test_client(app, name="image-size") 7 | 8 | attachment = { 9 | "type": "image/png", 10 | "data": "iVBORw0KGgoAAAANSUhEUgAAAAUAAAAFCAYAAACNbyblAAAAHElEQVQI12P4//8/w38GIAXDIBKE0DHxgljNBAAO9TXL0Y4OHwAAAABJRU5ErkJggg==", 11 | "title": "Image", 12 | } 13 | 14 | response = client.post( 15 | "chat/completions", 16 | json={ 17 | 
"messages": [ 18 | { 19 | "role": "user", 20 | "content": "", 21 | "custom_content": {"attachments": [attachment]}, 22 | } 23 | ] 24 | }, 25 | ) 26 | 27 | body = response.json() 28 | response_message = body["choices"][0]["message"] 29 | response_content = response_message["content"] 30 | 31 | assert response_content == "Size: 5x5px" 32 | -------------------------------------------------------------------------------- /tests/test_exception.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | import pytest 4 | 5 | from aidial_sdk.exceptions import HTTPException, TruncatePromptSystemError 6 | 7 | test_cases: List[Tuple[HTTPException, str]] = [ 8 | ( 9 | HTTPException( 10 | message="message", 11 | status_code=400, 12 | type="type", 13 | param="param", 14 | code="code", 15 | display_message="display_message", 16 | headers={"header": "value"}, 17 | ), 18 | "message", 19 | ), 20 | ( 21 | TruncatePromptSystemError(1, 20), 22 | "The requested maximum prompt tokens is 1. " 23 | "However, the system messages resulted in 20 tokens. 
" 24 | "Please reduce the length of the system messages or increase the maximum prompt tokens.", 25 | ), 26 | ] 27 | 28 | 29 | @pytest.mark.parametrize("exc, expected", test_cases) 30 | def test_str_exception(exc: HTTPException, expected: str): 31 | assert str(exc) == expected 32 | -------------------------------------------------------------------------------- /tests/utils/json.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import Any 3 | 4 | 5 | def match_objects(expected: Any, actual: Any) -> bool: 6 | if isinstance(expected, dict): 7 | assert list(sorted(expected.keys())) == list(sorted(actual.keys())) 8 | for k, v in expected.items(): 9 | match_objects(v, actual[k]) 10 | elif isinstance(expected, tuple): 11 | assert len(expected) == len(actual) 12 | for i in range(len(expected)): 13 | match_objects(expected[i], actual[i]) 14 | elif isinstance(expected, list): 15 | assert len(expected) == len(actual) 16 | for i in range(len(expected)): 17 | match_objects(expected[i], actual[i]) 18 | elif callable(expected): 19 | assert expected(actual) 20 | elif isinstance(expected, re.Pattern): 21 | assert expected.match( 22 | actual 23 | ), f"{actual!r} doesn't match the regex {expected.pattern!r}" 24 | else: 25 | assert expected == actual 26 | 27 | return True 28 | -------------------------------------------------------------------------------- /aidial_sdk/utils/logging.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from contextvars import ContextVar 3 | from typing import Optional 4 | 5 | logger = logging.getLogger("aidial_sdk") 6 | 7 | deployment_id: ContextVar[Optional[str]] = ContextVar( 8 | "deployment_id", default=None 9 | ) 10 | 11 | 12 | def set_log_deployment(new_deployment_id: str): 13 | deployment_id.set(new_deployment_id) 14 | 15 | 16 | def log_info(message: str, *args, **kwargs): 17 | logger.info(f"[{deployment_id.get()}] {message}", 
*args, **kwargs) 18 | 19 | 20 | def log_debug(message: str, *args, **kwargs): 21 | logger.debug(f"[{deployment_id.get()}] {message}", *args, **kwargs) 22 | 23 | 24 | def log_warning(message: str, *args, **kwargs): 25 | logger.warning(f"[{deployment_id.get()}] {message}", *args, **kwargs) 26 | 27 | 28 | def log_error(message: str, *args, **kwargs): 29 | logger.error(f"[{deployment_id.get()}] {message}", *args, **kwargs) 30 | 31 | 32 | def log_exception(message: str, *args, **kwargs): 33 | logger.exception(message, *args, **kwargs) 34 | -------------------------------------------------------------------------------- /tests/examples/test_echo.py: -------------------------------------------------------------------------------- 1 | from examples.echo.app import app 2 | from tests.utils.client import create_test_client 3 | 4 | 5 | def test_app(): 6 | client = create_test_client(app, name="echo") 7 | 8 | content = "Hello world!" 9 | attachment = { 10 | "type": "image/png", 11 | "url": "image-url", 12 | "title": "Image", 13 | } 14 | 15 | response = client.post( 16 | "chat/completions", 17 | json={ 18 | "messages": [ 19 | { 20 | "role": "user", 21 | "content": content, 22 | "custom_content": {"attachments": [attachment]}, 23 | } 24 | ] 25 | }, 26 | ) 27 | 28 | body = response.json() 29 | response_message = body["choices"][0]["message"] 30 | 31 | response_content = response_message["content"] 32 | assert response_content == content 33 | 34 | response_attachment = response_message["custom_content"]["attachments"][0] 35 | assert response_attachment == attachment 36 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/02_feature_request.yml: -------------------------------------------------------------------------------- 1 | name: "🚀 Feature request" 2 | description: Suggest an idea for this project 3 | labels: ["enhancement"] 4 | body: 5 | - type: markdown 6 | attributes: 7 | value: | 8 | Thank you for suggesting an idea to 
11 | The application supports image attachments provided in one of the following formats:
Used to access images stored in the DIAL file storage| 21 | 22 | ## Usage 23 | 24 | Find how to integrate the application into the DIAL Core and call it using DIAL API in the [cookbook](https://github.com/epam/ai-dial/blob/main/dial-cookbook/examples/how_to_call_image_to_text_applications.ipynb). -------------------------------------------------------------------------------- /aidial_sdk/deployment/tokenize.py: -------------------------------------------------------------------------------- 1 | from typing import List, Literal, Union 2 | 3 | from aidial_sdk._pydantic._compat import BaseModel 4 | from aidial_sdk.chat_completion.request import ChatCompletionRequest 5 | from aidial_sdk.deployment.from_request_mixin import FromRequestDeploymentMixin 6 | 7 | 8 | class TokenizeInputRequest(BaseModel): 9 | type: Literal["request"] = "request" 10 | value: ChatCompletionRequest 11 | 12 | 13 | class TokenizeInputString(BaseModel): 14 | type: Literal["string"] = "string" 15 | value: str 16 | 17 | 18 | TokenizeInput = Union[TokenizeInputRequest, TokenizeInputString] 19 | 20 | 21 | class TokenizeRequest(FromRequestDeploymentMixin): 22 | inputs: List[TokenizeInput] 23 | 24 | 25 | class TokenizeSuccess(BaseModel): 26 | status: Literal["success"] = "success" 27 | token_count: int 28 | 29 | 30 | class TokenizeError(BaseModel): 31 | status: Literal["error"] = "error" 32 | error: str 33 | 34 | 35 | TokenizeOutput = Union[TokenizeSuccess, TokenizeError] 36 | 37 | 38 | class TokenizeResponse(BaseModel): 39 | outputs: List[TokenizeOutput] 40 | -------------------------------------------------------------------------------- /aidial_sdk/utils/_reflection.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional 2 | 3 | 4 | def has_method_implemented(obj: Any, method_name: str) -> bool: 5 | """ 6 | Determine if a method is overridden in an object instance or 7 | if it is inherited from its class. 
8 | """ 9 | 10 | base_method = None 11 | for cls in type(obj).__mro__[1:]: 12 | base_method = getattr(cls, method_name, None) 13 | if base_method is not None: 14 | break 15 | 16 | this_method = getattr(obj, method_name, None) 17 | 18 | if base_method is None or this_method is None: 19 | return False 20 | 21 | if hasattr(base_method, "__code__") and hasattr(this_method, "__code__"): 22 | return base_method.__code__ != this_method.__code__ 23 | 24 | return base_method != this_method 25 | 26 | 27 | def get_method_implementation(obj: Any, method_name: str) -> Optional[Any]: 28 | """ 29 | Get the method implementation of an object instance. 30 | """ 31 | 32 | if has_method_implemented(obj, method_name): 33 | return getattr(obj, method_name) 34 | return None 35 | -------------------------------------------------------------------------------- /aidial_sdk/utils/log_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from aidial_sdk._pydantic._compat import BaseModel 4 | 5 | DIAL_SDK_LOG = os.environ.get("DIAL_SDK_LOG", "WARNING").upper() 6 | 7 | 8 | class LogConfig(BaseModel): 9 | """Logging configuration to be set for the server""" 10 | 11 | version: int = 1 12 | disable_existing_loggers: bool = False 13 | formatters: dict = { 14 | "default": { 15 | "()": "uvicorn.logging.DefaultFormatter", 16 | "fmt": "%(levelprefix)s | %(asctime)s | %(name)s | %(process)d | %(message)s", 17 | "datefmt": "%Y-%m-%d %H:%M:%S", 18 | "use_colors": True, 19 | }, 20 | } 21 | handlers: dict = { 22 | "default": { 23 | "formatter": "default", 24 | "class": "logging.StreamHandler", 25 | "stream": "ext://sys.stderr", 26 | }, 27 | } 28 | loggers: dict = { 29 | "aidial_sdk": {"handlers": ["default"], "level": DIAL_SDK_LOG}, 30 | "uvicorn": { 31 | "handlers": ["default"], 32 | "propagate": False, 33 | }, 34 | } 35 | -------------------------------------------------------------------------------- 
/.github/workflows/dependabot-automation.yml: -------------------------------------------------------------------------------- 1 | name: Dependabot Automation 2 | 3 | on: pull_request_target 4 | 5 | permissions: {} 6 | 7 | jobs: 8 | dependabot: 9 | runs-on: ubuntu-latest 10 | if: | 11 | github.event.pull_request.user.login == 'dependabot[bot]' && 12 | github.repository_owner == 'epam' 13 | steps: 14 | - name: Dependabot metadata 15 | id: metadata 16 | uses: dependabot/fetch-metadata@08eff52bf64351f401fb50d4972fa95b9f2c2d1b # v2.4.0 17 | - name: Approve PR 18 | run: gh pr review --approve "$PR_URL" 19 | env: 20 | PR_URL: ${{ github.event.pull_request.html_url }} 21 | GH_TOKEN: ${{ secrets.ACTIONS_BOT_TOKEN }} 22 | - name: Merge PR 23 | if: | 24 | steps.metadata.outputs.dependency-group == 'ai-dial-ci' && 25 | steps.metadata.outputs.update-type != 'version-update:semver-major' 26 | run: gh pr merge --auto --squash "$PR_URL" 27 | env: 28 | PR_URL: ${{ github.event.pull_request.html_url }} 29 | GH_TOKEN: ${{ secrets.ACTIONS_BOT_TOKEN }} 30 | -------------------------------------------------------------------------------- /tests/utils/sharing.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Any, Set 3 | 4 | 5 | class IdHashable: 6 | obj: Any 7 | 8 | def __init__(self, obj: Any) -> None: 9 | self.obj = obj 10 | 11 | def __hash__(self): 12 | return id(self.obj) 13 | 14 | def __eq__(self, other: object) -> bool: 15 | return (type(self), hash(self)) == (type(other), hash(other)) 16 | 17 | def __repr__(self): 18 | return json.dumps(self.obj) 19 | 20 | 21 | def collect_mutable_objects(a: Any) -> Set[IdHashable]: 22 | ret: set[IdHashable] = set() 23 | 24 | def _register(obj: Any): 25 | if isinstance(obj, (dict, list)): 26 | ret.add(IdHashable(obj)) 27 | 28 | def _rec(obj: Any): 29 | _register(obj) 30 | if isinstance(obj, dict): 31 | list(map(_rec, obj.values())) 32 | elif isinstance(obj, (list, tuple)): 
33 | list(map(_rec, obj)) 34 | 35 | _rec(a) 36 | 37 | return ret 38 | 39 | 40 | def collect_shared_mutable_objects(a: Any, b: Any) -> Set[IdHashable]: 41 | return collect_mutable_objects(a) & collect_mutable_objects(b) 42 | -------------------------------------------------------------------------------- /aidial_sdk/utils/_content_stream.py: -------------------------------------------------------------------------------- 1 | from typing import Protocol 2 | 3 | 4 | class ContentReceiver(Protocol): 5 | def append_content(self, content: str) -> None: ... 6 | 7 | 8 | class ContentStream: 9 | """ 10 | The ContentStream class allows using the receiver in contexts where typing.SupportsWrite[str] is expected. 11 | For example: 12 | 13 | 1. Redirecting print statements: 14 | 15 | print("Hello, world", file=content_stream) 16 | 17 | 2. Using with tqdm for progress bars: 18 | 19 | import tqdm 20 | for item in tqdm(items, file=content_stream): 21 | process(item) 22 | 23 | 3. Redirecting logs to the content stream: 24 | 25 | import logging 26 | logging_handler = logging.StreamHandler(stream=content_stream) 27 | 28 | 4. Writing CSV data: 29 | 30 | import csv 31 | csv.writer(content_stream).writerows(data) 32 | """ 33 | 34 | _receiver: ContentReceiver 35 | 36 | def __init__(self, receiver: ContentReceiver) -> None: 37 | self._receiver = receiver 38 | 39 | def write(self, s: str) -> None: 40 | self._receiver.append_content(s) 41 | -------------------------------------------------------------------------------- /aidial_sdk/pydantic_v1/__init__.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | from aidial_sdk._pydantic import PYDANTIC_V2 4 | 5 | _WARN_MESSAGE_V1 = """" 6 | The usage of `aidial_sdk.pydantic_v1` module is deprecated. 7 | 8 | To migrate your code simply replace all `aidial_sdk.pydantic_v1` imports with `pydantic` imports. 
9 | """.strip() 10 | 11 | _WARN_MESSAGE_V2 = """" 12 | The usage of `aidial_sdk.pydantic_v1` module is deprecated. 13 | This module provides Pydantic v1 model even when Pydantic v2 is installed. We recommend using Pydantic v2 models. 14 | 15 | To migrate your code to Pydantic v2 models, you can follow these steps: 16 | 17 | 1. Replace all `aidial_sdk.pydantic_v1` imports with `pydantic` imports. 18 | 2. Migrate usages of Pydantic v1 models originating in DIAL SDK to Pydantic v2 API. 19 | 3. Set environment variable `PYDANTIC_V2=True` to make the DIAL SDK use Pydantic v2 models instead of Pydantic v1. 20 | """.strip() 21 | 22 | warnings.warn( 23 | _WARN_MESSAGE_V2 if PYDANTIC_V2 else _WARN_MESSAGE_V1, DeprecationWarning 24 | ) 25 | 26 | try: 27 | from pydantic.v1 import * # type: ignore 28 | except ImportError: 29 | from pydantic import * # type: ignore 30 | -------------------------------------------------------------------------------- /.ort.yml: -------------------------------------------------------------------------------- 1 | --- 2 | excludes: 3 | paths: 4 | - pattern: "examples/**" 5 | reason: "EXAMPLE_OF" 6 | scopes: 7 | - pattern: "lint" 8 | reason: "DEV_DEPENDENCY_OF" 9 | comment: "Packages for static code analysis only." 10 | - pattern: "test" 11 | reason: "TEST_DEPENDENCY_OF" 12 | comment: "Packages for testing only." 
13 | resolutions: 14 | rule_violations: 15 | - message: ".*PyPI::httpcore:1\\.0\\.9.*" 16 | reason: "CANT_FIX_EXCEPTION" 17 | comment: "BSD 3-Clause New or Revised License: https://github.com/encode/httpcore/blob/1.0.9/LICENSE.md" 18 | - message: ".*PyPI::httpx:0\\.25\\.2.*" 19 | reason: "CANT_FIX_EXCEPTION" 20 | comment: "BSD 3-Clause New or Revised License: https://github.com/encode/httpx/blob/0.25.2/LICENSE.md" 21 | - message: ".*PyPI::fastapi:0\\.120\\.2.*" 22 | reason: "CANT_FIX_EXCEPTION" 23 | comment: "MIT License: https://github.com/fastapi/fastapi/blob/0.120.2/LICENSE" 24 | - message: ".*PyPI::starlette:0\\.49\\.1.*" 25 | reason: "CANT_FIX_EXCEPTION" 26 | comment: "BSD 3-Clause New or Revised License: https://github.com/Kludex/starlette/blob/0.49.1/LICENSE.md" -------------------------------------------------------------------------------- /aidial_sdk/_errors.py: -------------------------------------------------------------------------------- 1 | from fastapi import HTTPException, Request 2 | from fastapi.responses import JSONResponse 3 | 4 | from aidial_sdk._pydantic import ValidationError 5 | from aidial_sdk.exceptions import HTTPException as DIALException 6 | from aidial_sdk.exceptions import InvalidRequestError 7 | 8 | 9 | def pydantic_validation_exception_handler( 10 | request: Request, exc: Exception 11 | ) -> JSONResponse: 12 | assert isinstance(exc, ValidationError) 13 | 14 | error = exc.errors()[0] 15 | path = ".".join(map(str, error["loc"])) 16 | message = f"Your request contained invalid structure on path {path}. 
{error['msg']}" 17 | 18 | return InvalidRequestError(message).to_fastapi_response() 19 | 20 | 21 | def fastapi_exception_handler(request: Request, exc: Exception) -> JSONResponse: 22 | assert isinstance(exc, HTTPException) 23 | return JSONResponse( 24 | status_code=exc.status_code, 25 | content=exc.detail, 26 | headers=exc.headers, 27 | ) 28 | 29 | 30 | def dial_exception_handler(request: Request, exc: Exception) -> JSONResponse: 31 | assert isinstance(exc, DIALException) 32 | return exc.to_fastapi_response() 33 | -------------------------------------------------------------------------------- /tests/applications/idle.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from typing import List 3 | 4 | from aidial_sdk.chat_completion import ChatCompletion, Request, Response 5 | 6 | 7 | class IdleApplication(ChatCompletion): 8 | """ 9 | Application that waits the given intervals before producing chunks. 10 | """ 11 | 12 | intervals: List[float] 13 | throw_exception: bool 14 | 15 | def __init__(self, intervals: List[float], throw_exception: bool): 16 | self.intervals = intervals 17 | self.throw_exception = throw_exception 18 | 19 | async def chat_completion( 20 | self, request: Request, response: Response 21 | ) -> None: 22 | # sleep before the first chunk is generated 23 | await asyncio.sleep(self.intervals[0]) 24 | 25 | response.set_response_id("test_id") 26 | response.set_created(0) 27 | 28 | with response.create_single_choice() as choice: 29 | choice.append_content("1") 30 | for idx, interval in enumerate(self.intervals[1:], 2): 31 | await asyncio.sleep(interval) 32 | choice.append_content(str(idx)) 33 | 34 | if self.throw_exception: 35 | raise RuntimeError("Something went wrong") 36 | -------------------------------------------------------------------------------- /tests/applications/custom_endpoints.py: -------------------------------------------------------------------------------- 1 | import fastapi 2 | 3 
| from aidial_sdk import DIALApp 4 | from aidial_sdk.chat_completion import ChatCompletion, Request, Response 5 | from aidial_sdk.deployment.tokenize import TokenizeRequest, TokenizeResponse 6 | 7 | 8 | class NoopApplication(ChatCompletion): 9 | async def chat_completion( 10 | self, request: Request, response: Response 11 | ) -> None: 12 | with response.create_single_choice(): 13 | pass 14 | 15 | async def tokenize(self, request: TokenizeRequest) -> TokenizeResponse: 16 | return TokenizeResponse(outputs=[]) 17 | 18 | 19 | app = DIALApp().add_chat_completion("test-app1", NoopApplication()) 20 | 21 | 22 | @app.post("/openai/deployments/test-app1/tokenize") 23 | async def tokenize(request: fastapi.Request): 24 | return {"result": "custom_tokenize_result"} 25 | 26 | 27 | @app.post("/openai/deployments/test-app1/truncate_prompt") 28 | async def truncate_prompt(request: fastapi.Request): 29 | return {"result": "custom_truncate_prompt_result"} 30 | 31 | 32 | @app.post("/openai/deployments/test-app2/chat/completions") 33 | async def chat_completion(request: fastapi.Request): 34 | return {"result": "custom_chat_completion_result"} 35 | -------------------------------------------------------------------------------- /aidial_sdk/embeddings/request.py: -------------------------------------------------------------------------------- 1 | from typing import List, Literal, Optional, Union 2 | 3 | from aidial_sdk._pydantic import StrictInt, StrictStr 4 | from aidial_sdk.chat_completion.request import Attachment 5 | from aidial_sdk.deployment.from_request_mixin import FromRequestDeploymentMixin 6 | from aidial_sdk.utils.pydantic import ExtraAllowModel 7 | 8 | 9 | class AzureEmbeddingsRequest(ExtraAllowModel): 10 | model: Optional[StrictStr] = None 11 | input: Union[ 12 | StrictStr, List[StrictStr], List[StrictInt], List[List[StrictInt]] 13 | ] 14 | encoding_format: Literal["float", "base64"] = "float" 15 | dimensions: Optional[StrictInt] = None 16 | user: Optional[StrictStr] = None 
5 | It takes a text from the last user message and returns back the image with the rasterized text.
Used to upload generated images to the DIAL file storage
USE_PYDANTIC_V2 5 | from aidial_sdk._pydantic._compat import BaseModel as DialBaseModel 6 | from tests.utils.pydantic import model_dump 7 | 8 | pytestmark = pytest.mark.skipif( 9 | not USE_PYDANTIC_V2 and INSTALLED_PYDANTIC_V2, 10 | reason="Pydantic v1 and v2 models aren't compatible", 11 | ) 12 | 13 | 14 | class DialSubStruct(DialBaseModel): 15 | x: int 16 | y: str 17 | 18 | 19 | class NativeSubStruct(NativeBaseModel): 20 | x: int 21 | y: str 22 | 23 | 24 | def test_native_pydantic(): 25 | assert model_dump(NativeSubStruct(x=1, y="2")) == {"x": 1, "y": "2"} 26 | 27 | 28 | def test_dial_pydantic(): 29 | assert DialSubStruct(x=1, y="2").model_dump() == {"x": 1, "y": "2"} 30 | 31 | 32 | def test_pydantic_dial_in_native_compatibility(): 33 | class NativeStruct(NativeBaseModel): 34 | z: DialSubStruct 35 | 36 | assert model_dump(NativeStruct(z=DialSubStruct(x=1, y="2"))) == { 37 | "z": {"x": 1, "y": "2"} 38 | } 39 | 40 | 41 | def test_pydantic_native_in_dial_compatibility(): 42 | class DialStruct(DialBaseModel): 43 | z: NativeSubStruct 44 | 45 | assert DialStruct(z=NativeSubStruct(x=1, y="2")).model_dump() == { 46 | "z": {"x": 1, "y": "2"} 47 | } 48 | -------------------------------------------------------------------------------- /tests/test_custom_endpoints.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import pytest 4 | 5 | from tests.applications.custom_endpoints import app 6 | from tests.utils.endpoint_test import TestCase, run_endpoint_test 7 | from tests.utils.errors import route_not_found_error 8 | 9 | CHAT_COMPLETION_REQUEST = {"messages": [{"role": "user", "content": "ping"}]} 10 | 11 | deployment = "test-app" 12 | 13 | testcases: List[TestCase] = [ 14 | TestCase( 15 | app, 16 | "test-app1", 17 | "tokenize", 18 | {"inputs": []}, 19 | {"outputs": []}, 20 | ), 21 | TestCase( 22 | app, 23 | "test-app1", 24 | "tokenizer", 25 | {}, 26 | route_not_found_error, 27 | ), 28 | TestCase( 29 | app, 
class WithTokenize(ChatCompletion):
    # Overrides only the optional `tokenize` endpoint.

    async def chat_completion(
        self, request: Request, response: Response
    ) -> None:
        pass

    async def tokenize(self, request: TokenizeRequest) -> TokenizeResponse:
        return TokenizeResponse(outputs=[])


class WithTruncatePrompt(ChatCompletion):
    # Overrides only the optional `truncate_prompt` endpoint.

    async def chat_completion(
        self, request: Request, response: Response
    ) -> None:
        pass

    async def truncate_prompt(
        self, request: TruncatePromptRequest
    ) -> TruncatePromptResponse:
        return TruncatePromptResponse(outputs=[])


def test_has_tokenize_implemented():
    # Only the class that overrides `tokenize` should report it implemented.
    assert has_method_implemented(WithTokenize(), "tokenize")
    assert not has_method_implemented(WithTruncatePrompt(), "tokenize")


def test_has_truncate_prompt_implemented():
    # Symmetric check for `truncate_prompt`.
    assert has_method_implemented(WithTruncatePrompt(), "truncate_prompt")
    assert not has_method_implemented(WithTokenize(), "truncate_prompt")
@overload
def create_attachment(attachment: Attachment) -> Attachment: ...


@overload
def create_attachment(
    type: Optional[str] = None,
    title: Optional[str] = None,
    data: Optional[str] = None,
    url: Optional[str] = None,
    reference_url: Optional[str] = None,
    reference_type: Optional[str] = None,
) -> Attachment: ...


def create_attachment(*args, **kwargs) -> Attachment:
    """Build an ``Attachment`` either from a ready-made instance (passed
    positionally or as the ``attachment`` keyword) or from individual
    attachment fields."""
    if args and isinstance(args[0], Attachment):
        return cast(Attachment, args[0])

    keyword_attachment = kwargs.get("attachment")
    if isinstance(keyword_attachment, Attachment):
        return cast(Attachment, keyword_attachment)

    return _attachment_from_fields(*args, **kwargs)


def _attachment_from_fields(
    type: Optional[str] = None,
    title: Optional[str] = None,
    data: Optional[str] = None,
    url: Optional[str] = None,
    reference_url: Optional[str] = None,
    reference_type: Optional[str] = None,
) -> Attachment:
    """Construct an ``Attachment`` from its individual optional fields."""
    return Attachment(
        type=type,
        title=title,
        data=data,
        url=url,
        reference_url=reference_url,
        reference_type=reference_type,
    )
10 | - type: input 11 | attributes: 12 | label: Name and Version 13 | description: Application name and version 14 | placeholder: dial 1.2.3 15 | validations: 16 | required: true 17 | - type: textarea 18 | attributes: 19 | label: What steps will reproduce the bug? 20 | description: Enter details about your bug. 21 | placeholder: | 22 | 1. In this environment... 23 | 2. With this config... 24 | 3. Run '...' 25 | 4. See error... 26 | validations: 27 | required: true 28 | - type: textarea 29 | attributes: 30 | label: What is the expected behavior? 31 | description: If possible please provide textual output instead of screenshots. 32 | - type: textarea 33 | attributes: 34 | label: What do you see instead? 35 | description: If possible please provide textual output instead of screenshots. 36 | validations: 37 | required: true 38 | - type: textarea 39 | attributes: 40 | label: Additional information 41 | description: Tell us anything else you think we should know. 42 | -------------------------------------------------------------------------------- /examples/echo/app.py: -------------------------------------------------------------------------------- 1 | """ 2 | A DIAL application which returns back the content and attachments 3 | from the last user message. 
# ChatCompletion is an abstract class for applications and model adapters
class EchoApplication(ChatCompletion):
    """Echoes the content and attachments of the newest user message."""

    async def chat_completion(
        self, request: Request, response: Response
    ) -> None:
        # The newest message is the last one in the history.
        message = request.messages[-1]

        with response.create_single_choice() as choice:
            # Mirror the text of the incoming message.
            choice.append_content(message.text())

            custom_content = message.custom_content
            if custom_content is not None:
                # Mirror every attachment of the incoming message.
                for attachment in custom_content.attachments or []:
                    choice.add_attachment(**attachment.model_dump())


# DIALApp extends FastAPI to provide a user-friendly interface
# for routing requests to your applications.
app = DIALApp()
app.add_chat_completion("echo", EchoApplication())

# Run built app
if __name__ == "__main__":
    uvicorn.run(app, port=5000)
def sanitize_namespace(namespace: str) -> str:
    """Replace every character that is neither alphanumeric nor one of
    ``._-/`` with a dash."""
    sanitized = []
    for char in namespace:
        is_allowed = char.isalnum() or char in "._-/"
        sanitized.append(char if is_allowed else "-")
    return "".join(sanitized)


def get_last_attachment_url(messages: List[Message]) -> str:
    """Return the URL of the single attachment of the newest message
    that carries attachments.

    Raises a 422 ``DIALException`` when a message has more than one
    attachment, when the attachment has no URL, or when no message
    carries attachments at all.
    """

    def _invalid_input(msg: str) -> DIALException:
        return DIALException(
            status_code=422,
            message=msg,
            display_message=msg,
        )

    for message in reversed(messages):
        custom_content = message.custom_content
        if custom_content is None or custom_content.attachments is None:
            continue

        attachments = custom_content.attachments
        if not attachments:
            # Empty attachment list — keep searching older messages.
            continue

        if len(attachments) != 1:
            raise _invalid_input(
                "Only one attachment per message is supported"
            )

        url = attachments[0].url
        if url is None:
            raise _invalid_input(
                "Attachment is expected to be provided via a URL"
            )

        return url

    raise _invalid_input("No attachment was found")
def missing_fields_error(path: str) -> Error:
    """Expected error for a required field absent at `path`; the exact
    wording differs between pydantic major versions."""
    message = "Field required" if PYDANTIC_V2 else "field required"
    return invalid_request_error(path, message)


def extra_fields_error(path: str) -> Error:
    """Expected error for a forbidden extra field at `path`; the exact
    wording differs between pydantic major versions."""
    message = (
        "Extra inputs are not permitted"
        if PYDANTIC_V2
        else "extra fields not permitted"
    )
    return invalid_request_error(path, message)


# Expected error for a request to an unregistered route.
route_not_found_error: Error = Error(code=404, error={"detail": "Not Found"})
class EchoApplication(ChatCompletion):
    """Test application that echoes the last message back and implements
    tokenization/truncation based on word counts."""

    # Prompt-token limit emulated by the truncation endpoint.
    model_max_prompt_tokens: int

    def __init__(self, model_max_prompt_tokens: int):
        self.model_max_prompt_tokens = model_max_prompt_tokens

    async def chat_completion(
        self, request: Request, response: Response
    ) -> None:
        # Pin id/created so tests get deterministic responses.
        response.set_response_id("test_id")
        response.set_created(0)

        content = request.messages[-1].text()

        with response.create_single_choice() as choice:
            choice.append_content(content)

    @override
    async def tokenize(self, request: TokenizeRequest) -> TokenizeResponse:
        return make_batched_tokenize(word_count_tokenize)(request)

    @override
    async def truncate_prompt(
        self, request: TruncatePromptRequest
    ) -> TruncatePromptResponse:
        def _truncate(req):
            return default_truncate_prompt(
                req, word_count_request, self.model_max_prompt_tokens
            )

        return make_batched_truncate_prompt(_truncate)(request)
ARGS ?=
VENV_DIR ?= .venv
POETRY ?= $(VENV_DIR)/bin/poetry
POETRY_VERSION ?= 2.1.1

# None of these targets produce a file with the target's name.
.PHONY: all init_env install build clean publish lint format test test_fast benchmark help

all: build

init_env:
	python -m venv $(VENV_DIR)
	$(VENV_DIR)/bin/pip install poetry==$(POETRY_VERSION) --quiet

install: init_env
	$(POETRY) install --all-extras

build: install
	$(POETRY) build

clean:
	rm -rf $$($(POETRY) env info --path)
	rm -rf .nox
	rm -rf .pytest_cache
	rm -rf dist
	find . -type d -name __pycache__ | xargs rm -r

publish: build
	$(POETRY) publish -u __token__ -p $(PYPI_TOKEN) --skip-existing

lint: install
	$(POETRY) run nox -s lint

format: install
	$(POETRY) run nox -s format

test: install
	$(POETRY) run -- nox -s test $(if $(PYTHON),--python=$(PYTHON),) -- $(ARGS)

test_fast: install
	$(POETRY) run -- nox -s test $(if $(PYTHON),--python=$(PYTHON),) -- -m 'not slow' $(ARGS)

# Run via poetry so the benchmark uses the project virtualenv; the module
# lives at tests/benchmark/benchmark_merge_chunks.py (a bare
# `python -m benchmark....` would not resolve from the repo root).
benchmark: install
	$(POETRY) run python -m tests.benchmark.benchmark_merge_chunks

help:
	@echo '===================='
	@echo 'build                        - build the library'
	@echo 'clean                        - clean virtual env and build artifacts'
	@echo 'publish                      - publish the library to Pypi'
	@echo '-- LINTING --'
	@echo 'format                       - run code formatters'
	@echo 'lint                         - run linters'
	@echo '-- TESTS --'
	@echo 'test                         - run unit tests'
	@echo 'test_fast                    - run unit tests without slow tests'
	@echo 'test PYTHON=<python_version> - run unit tests with the specific python version'
	@echo 'benchmark                    - run benchmarks'
deployment = "test-app"
app = DIALApp().add_embeddings(deployment, SimpleEmbeddings())


def _expected_response(count: int) -> dict:
    """Expected embeddings response for `count` one-dimensional vectors:
    the i-th input maps to the vector [float(i)], one prompt token each."""
    return {
        "data": [
            {
                "embedding": [float(index)],
                "index": index,
                "object": "embedding",
            }
            for index in range(count)
        ],
        "model": "dummy",
        "object": "list",
        "usage": {"prompt_tokens": count, "total_tokens": count},
    }


expected_response_1 = _expected_response(1)
expected_response_2 = _expected_response(2)

testcases: List[TestCase] = [
    # A single string input along with embeddings-specific custom fields.
    TestCase(
        app,
        deployment,
        "embeddings",
        {
            "input": "a",
            "custom_fields": {
                "type": "query",
                "instruction": "instruction",
            },
        },
        expected_response_1,
    ),
    # A single input given as a list of token ids.
    TestCase(
        app,
        deployment,
        "embeddings",
        {"input": [15339]},
        expected_response_1,
    ),
    # A batch of two string inputs.
    TestCase(
        app,
        deployment,
        "embeddings",
        {"input": ["a", "b"]},
        expected_response_2,
    ),
]


@pytest.mark.parametrize("testcase", testcases)
def test_embeddings(testcase: TestCase):
    run_endpoint_test(testcase)
class FunctionCall:
    """Streams the function call of a single choice to the client."""

    _choice: ChoiceBase

    def __init__(self, choice: ChoiceBase):
        self._choice = choice

    @classmethod
    def create_and_send(
        cls, choice: ChoiceBase, name: str, arguments: Optional[str]
    ) -> "FunctionCall":
        """Open a new function call and stream its name along with an
        optional first portion of its arguments."""
        return cls(choice)._send_function_call(
            create=True, name=name, arguments=arguments
        )

    def append_arguments(self, arguments: str) -> "FunctionCall":
        """Stream an additional portion of the function call arguments."""
        return self._send_function_call(
            create=False, name=None, arguments=arguments
        )

    def _send_function_call(
        self, *, create: bool, name: Optional[str], arguments: Optional[str]
    ) -> "FunctionCall":
        """Validate the choice state and emit a function call chunk."""
        choice = self._choice

        # A function call may only be streamed into a choice that is
        # currently open and doesn't already have a function call.
        if not choice.opened:
            raise runtime_error(
                "Trying to add function call to an unopened choice"
            )
        if choice.closed:
            raise runtime_error(
                "Trying to add function call to a closed choice"
            )
        if create and choice.has_function_call:
            raise runtime_error(
                "Trying to add function call to a choice which already has a function call"
            )

        choice.send_chunk(
            FunctionCallChunk(choice.index, name=name, arguments=arguments)
        )

        return self
# NOTE: no docstrings are added to these pydantic models on purpose —
# a model docstring would leak into the generated schema description.


class LogsConfig(BaseModel):
    # Export logs via OTLP when the env var lists the "otlp" exporter.
    otlp_export: bool = "otlp" in OTEL_LOGS_EXPORTER
    level: int = logging.INFO


class TracingConfig(BaseModel):
    # Export traces via OTLP when the env var lists the "otlp" exporter.
    otlp_export: bool = "otlp" in OTEL_TRACES_EXPORTER

    # Configure logging to include tracing context into console log messages.
    logging: bool = OTEL_PYTHON_LOG_CORRELATION


class MetricsConfig(BaseModel):
    # Export metrics via OTLP and/or expose a Prometheus scrape endpoint,
    # depending on what the env var lists.
    otlp_export: bool = "otlp" in OTEL_METRICS_EXPORTER
    prometheus_export: bool = "prometheus" in OTEL_METRICS_EXPORTER
    port: int = OTEL_EXPORTER_PROMETHEUS_PORT


class TelemetryConfig(BaseModel):
    # Each telemetry section is enabled by default only when the
    # corresponding OTEL_*_EXPORTER env var is non-empty.
    service_name: Optional[str] = None

    logs: Optional[LogsConfig] = LogsConfig() if OTEL_LOGS_EXPORTER else None
    tracing: Optional[TracingConfig] = (
        TracingConfig() if OTEL_TRACES_EXPORTER else None
    )
    metrics: Optional[MetricsConfig] = (
        MetricsConfig() if OTEL_METRICS_EXPORTER else None
    )
"test_id", 30 | "created": 0, 31 | "object": "chat.completion", 32 | } 33 | 34 | 35 | def test_single_choice_streaming(): 36 | client = create_app_client(SingleChoiceApplication()) 37 | 38 | response = client.post( 39 | "chat/completions", 40 | json={ 41 | "messages": [{"role": "user", "content": "Test content"}], 42 | "stream": True, 43 | }, 44 | ) 45 | 46 | check_sse_stream( 47 | response.iter_lines(), 48 | [ 49 | create_single_choice_chunk(delta={"role": "assistant"}), 50 | create_single_choice_chunk( 51 | delta={"content": "Test response content"} 52 | ), 53 | create_single_choice_chunk(delta={}, finish_reason="stop"), 54 | ], 55 | ) 56 | -------------------------------------------------------------------------------- /examples/render_text/app/image.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import os 3 | import textwrap 4 | from io import BytesIO 5 | 6 | import aiohttp 7 | from PIL import Image, ImageDraw, ImageFont 8 | 9 | 10 | def text_to_image_base64(text: str, img_size=(200, 100), font_size=20) -> str: 11 | img = Image.new("RGB", img_size, color="yellow") # type: ignore 12 | d = ImageDraw.Draw(img) 13 | 14 | try: 15 | font = ImageFont.truetype("Monaco.ttf", font_size) 16 | except IOError: 17 | font = ImageFont.load_default(font_size) # type: ignore 18 | 19 | wrapped_text = textwrap.fill(text, width=15) 20 | 21 | d.text((10, 10), wrapped_text, fill=(0, 0, 0), font=font) 22 | 23 | img_buffer = BytesIO() 24 | img.save(img_buffer, format="PNG") 25 | img_buffer.seek(0) 26 | 27 | img_base64 = base64.b64encode(img_buffer.getvalue()).decode() 28 | 29 | return img_base64 30 | 31 | 32 | async def upload_png_image( 33 | dial_url: str, filepath: str, image_base64: str 34 | ) -> str: 35 | async with aiohttp.ClientSession() as session: 36 | async with session.get(f"{dial_url}/v1/bucket") as response: 37 | response.raise_for_status() 38 | appdata = (await response.json())["appdata"] 39 | 40 | image_bytes = 
app = FastAPI()


class Library(str, Enum):
    """HTTP client libraries exercised by the header propagation tests."""

    requests = "requests"
    httpx_sync = "httpx_sync"
    httpx_async = "httpx_async"
    aiohttp = "aiohttp"


class Request(BaseModel):
    # Target URL, client library to use, and headers to send with the GET.
    url: str
    lib: Library
    headers: dict


@app.post("/")
async def handle(request: Request):
    """Perform a GET to `request.url` with `request.headers` using the
    requested client library and mirror back the status code and JSON body."""
    url = request.url
    lib = request.lib
    headers = request.headers

    if lib == Library.requests:
        response = requests.get(url, headers=headers)
        status_code = response.status_code
        content = response.json()

    elif lib == Library.httpx_async:
        async with httpx.AsyncClient() as client:
            response = await client.get(url, headers=headers)
            status_code = response.status_code
            content = response.json()

    elif lib == Library.httpx_sync:
        with httpx.Client() as client:
            response = client.get(url, headers=headers)
            status_code = response.status_code
            content = response.json()

    elif lib == Library.aiohttp:
        async with aiohttp.ClientSession() as session:
            async with session.get(url, headers=headers) as response:
                status_code = response.status
                content = await response.json()

    else:
        # Exhaustiveness check: fails type checking if a Library member
        # is added without a matching branch.
        assert_never(lib)

    return JSONResponse(status_code=status_code, content=content)
class FunctionToolCall:
    """Streams a single function tool call of a choice to the client."""

    _choice: ChoiceBase
    _index: int

    def __init__(self, choice: ChoiceBase, index: int):
        self._choice = choice
        self._index = index

    @classmethod
    def create_and_send(
        cls,
        choice: ChoiceBase,
        index: int,
        id: str,
        name: str,
        arguments: Optional[str],
    ) -> "FunctionToolCall":
        """Open a new tool call and stream its id and name along with an
        optional first portion of its arguments."""
        return cls(choice, index)._send_tool_call(
            id=id, type="function", name=name, arguments=arguments
        )

    def append_arguments(self, arguments: str) -> "FunctionToolCall":
        """Stream an additional portion of the tool call arguments."""
        return self._send_tool_call(
            id=None, type=None, name=None, arguments=arguments
        )

    def _send_tool_call(
        self,
        *,
        id: Optional[str],
        type: Optional[Literal["function"]],
        name: Optional[str],
        arguments: Optional[str],
    ) -> "FunctionToolCall":
        """Validate the choice state and emit a tool call chunk."""
        choice = self._choice

        # Tool calls may only be streamed into a currently open choice.
        if not choice.opened:
            raise runtime_error("Trying to add tool call to an unopened choice")
        if choice.closed:
            raise runtime_error("Trying to add tool call to a closed choice")

        choice.send_chunk(
            FunctionToolCallChunk(
                choice.index,
                self._index,
                type=type,
                id=id,
                name=name,
                arguments=arguments,
            )
        )

        return self
nox.options.reuse_existing_virtualenvs = True

SRC = "."


def format_with_args(session: nox.Session, *args):
    """Run the code fixers (autoflake, isort, black) with the given args."""
    session.run("autoflake", *args)
    session.run("isort", *args)
    session.run("black", *args)


@nox.session
def lint(session: nox.Session):
    """Runs linters and checks formatting"""
    try:
        session.run("poetry", "install", "--all-extras", external=True)
        session.run("poetry", "check", "--lock", external=True)
        session.run("pyright", SRC)
        session.run("flake8", SRC)
        # `--check` makes the formatters report violations without fixing.
        format_with_args(session, SRC, "--check")
    except Exception:
        session.error(
            "linting has failed. Run 'make format' to fix formatting and fix other errors manually"
        )


@nox.session
def format(session: nox.Session):
    """Runs code formatters and fixers"""
    # Fixed: the docstring used to be a copy-paste of `lint`'s ("Runs
    # linters and fixers"), which is what `nox --list` displays.
    session.run("poetry", "install", external=True)
    format_with_args(session, SRC)


class UsePydanticV2(Enum):
    # Value of the PYDANTIC_V2 env var passed to the test run.
    YES = "1"
    NO = "0"


@nox.session(python=["3.9", "3.10", "3.11", "3.12", "3.13"])
# Testing against earliest and latest supported versions of the dependencies
@nox.parametrize(
    "pydantic",
    [
        ("1.10.17", UsePydanticV2.NO),
        ("2.8.2", UsePydanticV2.NO),
        ("2.8.2", UsePydanticV2.YES),
    ],
)
@nox.parametrize("httpx", ["0.25.0", "0.27.0"])
def test(
    session: nox.Session, pydantic: Tuple[str, UsePydanticV2], httpx: str
) -> None:
    """Runs tests"""
    session.run("poetry", "install", external=True)
    session.install(f"pydantic=={pydantic[0]}", f"httpx=={httpx}")
    session.run(
        "pytest", *session.posargs, env={"PYDANTIC_V2": str(pydantic[1].value)}
    )
class TestCase:
    """A single endpoint test: a request to POST to a deployment endpoint
    and the response expected back."""

    # Prevent pytest from collecting this helper as a test class.
    __test__ = False

    app: FastAPI

    deployment: str
    endpoint: str

    request_body: dict
    request_headers: Dict[str, str]
    response: Union[Error, dict, None]

    def __init__(
        self,
        app: FastAPI,
        deployment: str,
        endpoint: str,
        request_body: dict,
        response: Union[Error, dict, None],
        request_headers: Union[Dict[str, str], None] = None,
    ):
        self.app = app
        self.deployment = deployment
        self.endpoint = endpoint
        self.request_body = request_body
        self.response = response
        # Fixed: the default used to be a mutable `{}` evaluated once and
        # shared between every TestCase instance; use a None sentinel so
        # each instance gets its own fresh dict.
        self.request_headers = (
            {} if request_headers is None else request_headers
        )


def run_endpoint_test(testcase: TestCase):
    """POST the testcase request to its deployment endpoint and assert
    that both the response body and the status code match the expectation
    (an `Error` expectation carries its own status code; a dict/None
    expectation implies 200)."""

    client = TestClient(testcase.app)

    actual_response = client.post(
        f"/openai/deployments/{testcase.deployment}/{testcase.endpoint}",
        json=testcase.request_body,
        headers={"Api-Key": "TEST_API_KEY", **testcase.request_headers},
    )

    # An empty body (e.g. a response with no content) is represented as None.
    actual_response_body = (
        None if actual_response.text == "" else actual_response.json()
    )

    expected_response = testcase.response
    if isinstance(expected_response, Error):
        expected_response_code = expected_response.code
        expected_response_body = expected_response.error
    else:
        expected_response_code = 200
        expected_response_body = expected_response

    assert match_objects(expected_response_body, actual_response_body)
    assert actual_response.status_code == expected_response_code
from aidial_sdk.chat_completion import ChatCompletion, Request, Response
from tests.utils.chunks import (
    check_sse_stream,
    create_function_call_chunk,
    create_single_choice_chunk,
)
from tests.utils.client import create_app_client


class FunctionCaller(ChatCompletion):
    """Test application that emits some content and then streams a single
    function call whose arguments arrive in three fragments."""

    async def chat_completion(
        self, request: Request, response: Response
    ) -> None:
        response.set_response_id("test_id")
        response.set_created(0)

        with response.create_single_choice() as choice:
            choice.append_content("Test content")

            function_call = choice.create_function_call("function_name")
            # Stream the JSON arguments in pieces to exercise chunk merging.
            for fragment in ('{"key', '":"', 'val"}'):
                function_call.append_arguments(fragment)


def test_function_call_non_streaming():
    client = create_app_client(FunctionCaller())
    http_response = client.post(
        "chat/completions", json={"messages": [], "stream": False}
    )

    # In the non-streaming mode the argument fragments must be concatenated
    # into a single function_call object.
    message = http_response.json()["choices"][0]["message"]
    assert message["function_call"] == {
        "name": "function_name",
        "arguments": '{"key":"val"}',
    }


def test_function_call_streaming():
    client = create_app_client(FunctionCaller())
    http_response = client.post(
        "chat/completions", json={"messages": [], "stream": True}
    )

    # Each argument fragment must be delivered as its own SSE chunk.
    expected_chunks = [
        create_single_choice_chunk(delta={"role": "assistant"}),
        create_single_choice_chunk(delta={"content": "Test content"}),
        create_function_call_chunk(name="function_name"),
        create_function_call_chunk(arguments='{"key'),
        create_function_call_chunk(arguments='":"'),
        create_function_call_chunk(arguments='val"}'),
        create_single_choice_chunk(delta={}, finish_reason="function_call"),
    ]

    check_sse_stream(http_response.iter_lines(), expected_chunks)
55 | -------------------------------------------------------------------------------- /aidial_sdk/utils/_indexed_list.py: -------------------------------------------------------------------------------- 1 | INDEX_ERROR_MESSAGE = "A list element must have 'index' field to identify position of the element in the list" 2 | 3 | INDEX_INTEGER_ERROR_MESSAGE = ( 4 | "A list element must have 'index' field of a integer type, but got {ty}" 5 | ) 6 | 7 | INDEX_NON_NEGATIVE_ERROR_MESSAGE = "A list element must have 'index' field which a non-negative integer, but got {index}" 8 | 9 | INCONSISTENT_INDEXED_LIST_ERROR_MESSAGE = ( 10 | "All elements of a list must be either indexed or not indexed" 11 | ) 12 | 13 | 14 | def try_parse_indexed_list( 15 | xs: list, *, normalize_inplace: bool = False 16 | ) -> bool: 17 | if len(xs) == 0: 18 | return False 19 | 20 | all_indexed = True 21 | max_index = None 22 | normalized = True 23 | 24 | for idx, elem in enumerate(xs): 25 | if isinstance(elem, dict) and (index := elem.get("index")) is not None: 26 | if not isinstance(index, int): 27 | raise AssertionError( 28 | INDEX_INTEGER_ERROR_MESSAGE.format(ty=type(index).__name__) 29 | ) 30 | 31 | if index < 0: 32 | raise AssertionError( 33 | INDEX_NON_NEGATIVE_ERROR_MESSAGE.format(index=index) 34 | ) 35 | 36 | normalized = normalized and idx == index 37 | 38 | max_index = index if max_index is None else max(max_index, index) 39 | else: 40 | all_indexed = False 41 | 42 | if max_index is not None and not all_indexed: 43 | raise AssertionError(INCONSISTENT_INDEXED_LIST_ERROR_MESSAGE) 44 | 45 | if max_index is None: 46 | return False 47 | 48 | if not normalized and normalize_inplace: 49 | _normalize_indexed_list(xs, max_index + 1) 50 | 51 | return True 52 | 53 | 54 | def _normalize_indexed_list(xs: list, new_length: int) -> list: 55 | elems = {elem.get("index"): elem for elem in xs} 56 | 57 | xs.clear() 58 | for index in range(new_length): 59 | elem = elems.pop(index, None) 60 | if elem is not None: 
61 | xs.append(elem) 62 | else: 63 | xs.append({"index": index}) 64 | 65 | return xs 66 | -------------------------------------------------------------------------------- /tests/applications/broken.py: -------------------------------------------------------------------------------- 1 | from fastapi import HTTPException as FastAPIException 2 | 3 | from aidial_sdk import HTTPException as DIALException 4 | from aidial_sdk.chat_completion import ChatCompletion, Request, Response 5 | 6 | 7 | def _raise_exception(exception_type: str): 8 | if exception_type == "sdk_exception": 9 | raise DIALException("Test error", 503) 10 | elif exception_type == "fastapi_exception": 11 | raise FastAPIException(504, detail="Test detail") 12 | elif exception_type == "value_error_exception": 13 | raise ValueError("Test value error") 14 | elif exception_type == "zero_division_exception": 15 | return 1 / 0 16 | elif exception_type == "sdk_exception_with_display_message": 17 | raise DIALException("Test error", 503, display_message="I'm broken") 18 | elif exception_type == "sdk_exception_with_extra_fields": 19 | raise DIALException( 20 | "Test error", 21 | 503, 22 | status=503, 23 | details="error details", 24 | ) 25 | elif exception_type == "sdk_exception_with_headers": 26 | raise DIALException( 27 | "Too many requests", 429, headers={"Retry-After": "42"} 28 | ) 29 | else: 30 | raise DIALException("Unexpected error") 31 | 32 | 33 | class ImmediatelyBrokenApplication(ChatCompletion): 34 | """ 35 | Application which breaks immediately after receiving a request. 36 | """ 37 | 38 | async def chat_completion( 39 | self, request: Request, response: Response 40 | ) -> None: 41 | _raise_exception(request.messages[0].text()) 42 | 43 | 44 | class RuntimeBrokenApplication(ChatCompletion): 45 | """ 46 | Application which breaks after producing some output. 
47 | """ 48 | 49 | async def chat_completion( 50 | self, request: Request, response: Response 51 | ) -> None: 52 | response.set_response_id("test_id") 53 | response.set_created(0) 54 | 55 | with response.create_single_choice() as choice: 56 | choice.append_content("Test content") 57 | await response.aflush() 58 | 59 | _raise_exception(request.messages[0].text()) 60 | -------------------------------------------------------------------------------- /examples/langchain_rag/README.md: -------------------------------------------------------------------------------- 1 | # Overview 2 | 3 | An example of a simple DIAL RAG application based on Langchain utilizing Chroma vector database and RetrievalQA chain. 4 | 5 | The application processes chat completion request in the following way: 6 | 7 | 1. finds the last attachment in the conversation history and extracts URL from it, 8 | 2. downloads the document from the URL, 9 | 3. parses the document if it's a PDF or treats it as a plain text otherwise, 10 | 4. splits the text of the document into chunks, 11 | 5. computes the embeddings for the chunks, 12 | 6. saves the embeddings in the local cache, 13 | 7. run the RetrievalQA Langchain chain that consults the embeddings store and calls chat completion model to generate final answer. 14 | 15 | Upon start the Docker image exposes `openai/deployments/simple-rag/chat/completions` endpoint at port `5000`. 16 | 17 | ## Configuration 18 | 19 | |Variable|Default|Description| 20 | |---|---|---| 21 | |DIAL_URL||Required. URL of the DIAL server. Used to access embeddings and chat completion models| 22 | |EMBEDDINGS_MODEL|text-embedding-ada-002|Embeddings model| 23 | |CHAT_MODEL|gpt-4|Chat completion model| 24 | |API_VERSION|2024-02-01|Azure OpenAI API version| 25 | |LANGCHAIN_DEBUG|False|Flag to enable debug logs from Langchain| 26 | |OPENAI_LOG||Flag that controls openai library logging. 
Set to `debug` to enable debug logging| 27 | 28 | ## Usage 29 | 30 | The application could be tested by running it directly on your machine: 31 | 32 | ```sh 33 | python -m venv .venv 34 | source .venv/bin/activate 35 | pip install -r requirements.txt 36 | python -m app 37 | ``` 38 | 39 | Then you may call the application using DIAL API key: 40 | 41 | ```sh 42 | curl "http://localhost:5000/openai/deployments/simple-rag/chat/completions" \ 43 | -X POST \ 44 | -H "Content-Type: application:json" \ 45 | -H "api-key:${DIAL_API_KEY}" \ 46 | -d '{ 47 | "stream": true, 48 | "messages": [ 49 | { 50 | "role": "user", 51 | "content": "Who is Miss Meyers?", 52 | "custom_content": { 53 | "attachments": [ 54 | { 55 | "url": "https://en.wikipedia.org/wiki/Miss_Meyers" 56 | } 57 | ] 58 | } 59 | } 60 | ] 61 | }' 62 | ``` -------------------------------------------------------------------------------- /tests/test_serialization.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from aidial_sdk.chat_completion import Message, ResponseFormatJsonSchema, Role 4 | from aidial_sdk.chat_completion.request import ResponseFormatJsonSchemaObject 5 | from tests.utils.pydantic import model_dump, model_parse, model_parse_json 6 | 7 | 8 | def test_message_ser(): 9 | msg_obj = Message(role=Role.SYSTEM, content="test") 10 | actual_dict = model_dump(msg_obj, exclude_none=True) 11 | expected_dict = {"role": "system", "content": "test"} 12 | 13 | assert json.loads(json.dumps(actual_dict)) == expected_dict 14 | 15 | 16 | def test_message_deser(): 17 | msg_dict = {"role": "system", "content": "test"} 18 | actual_obj = model_parse_json(Message, json.dumps(msg_dict)) 19 | expected_obj = Message(role=Role.SYSTEM, content="test") 20 | 21 | assert actual_obj == expected_obj 22 | 23 | 24 | def test_response_format_serialization(): 25 | format_obj = ResponseFormatJsonSchema( 26 | type="json_schema", 27 | json_schema=ResponseFormatJsonSchemaObject( 28 | 
description="desc", 29 | name="name", 30 | schema={"key": "value"}, 31 | ), 32 | ) 33 | 34 | actual_dict = model_dump(format_obj) 35 | 36 | expected_dict = { 37 | "type": "json_schema", 38 | "json_schema": { 39 | "description": "desc", 40 | "name": "name", 41 | "schema": {"key": "value"}, 42 | "strict": False, 43 | }, 44 | } 45 | 46 | assert actual_dict == expected_dict 47 | 48 | 49 | def test_response_format_deserialization(): 50 | format_dict = { 51 | "type": "json_schema", 52 | "json_schema": { 53 | "description": "desc", 54 | "name": "name", 55 | "schema": {"key": "value"}, 56 | }, 57 | } 58 | 59 | actual_obj = model_parse(ResponseFormatJsonSchema, format_dict) 60 | 61 | expected_obj = ResponseFormatJsonSchema( 62 | type="json_schema", 63 | json_schema=ResponseFormatJsonSchemaObject( 64 | description="desc", 65 | name="name", 66 | schema={"key": "value"}, 67 | strict=False, 68 | ), 69 | ) 70 | 71 | assert actual_obj == expected_obj 72 | -------------------------------------------------------------------------------- /tests/utils/pydantic.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional, Type, TypeVar, Union 2 | 3 | from aidial_sdk._pydantic import PYDANTIC_V2, BaseModel 4 | from aidial_sdk._pydantic import Field as PydField 5 | from aidial_sdk.utils.pydantic import model_validate_extra_fields 6 | 7 | _ModelT = TypeVar("_ModelT", bound=BaseModel) 8 | 9 | 10 | def Field(*args, **kwargs) -> Any: 11 | if PYDANTIC_V2: 12 | from aidial_sdk.pydantic.v2 import Field as SDKField 13 | 14 | return SDKField(*args, **kwargs) 15 | else: 16 | return PydField(*args, **kwargs) 17 | 18 | 19 | def model_parse( 20 | model: Type[_ModelT], data: Any, *, allow_extra_fields=True 21 | ) -> _ModelT: 22 | if PYDANTIC_V2: 23 | obj = model.model_validate(data) 24 | else: 25 | obj = model.parse_obj(data) # pyright: ignore[reportDeprecated] 26 | if not allow_extra_fields: 27 | model_validate_extra_fields(obj) # 
type: ignore 28 | return obj 29 | 30 | 31 | def model_parse_json( 32 | model: Type[_ModelT], data: Union[str, bytes], *, allow_extra_fields=True 33 | ) -> _ModelT: 34 | if PYDANTIC_V2: 35 | obj = model.model_validate_json(data) 36 | else: 37 | obj = model.parse_raw(data) # pyright: ignore[reportDeprecated] 38 | if not allow_extra_fields: 39 | model_validate_extra_fields(obj) # type: ignore 40 | return obj 41 | 42 | 43 | def model_json_schema(model: Type[_ModelT]) -> Dict[str, Any]: 44 | if PYDANTIC_V2: 45 | return model.model_json_schema() 46 | return model.schema() # pyright: ignore[reportDeprecated] 47 | 48 | 49 | def model_copy( 50 | model: _ModelT, 51 | *, 52 | update: Optional[Dict[str, Any]] = None, 53 | deep: bool = False, 54 | ) -> _ModelT: 55 | if PYDANTIC_V2: 56 | return model.model_copy(update=update, deep=deep) 57 | return model.copy( # pyright: ignore[reportDeprecated] 58 | update=update, deep=deep 59 | ) 60 | 61 | 62 | def model_dump( 63 | model: BaseModel, *, exclude_none: bool = False 64 | ) -> Dict[str, Any]: 65 | if PYDANTIC_V2: 66 | return model.model_dump(exclude_none=exclude_none) 67 | return model.dict( # pyright: ignore[reportDeprecated] 68 | exclude_none=exclude_none 69 | ) 70 | -------------------------------------------------------------------------------- /aidial_sdk/utils/_cancel_scope.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from asyncio import exceptions 3 | from typing import Optional, Set 4 | 5 | 6 | class CancelScope: 7 | """ 8 | Async context manager that enforces cancellation of all tasks created within its scope when either: 9 | 1. the parent task has been cancelled or has thrown an exception or 10 | 2. any of the tasks created within the scope has thrown an exception. 
11 | """ 12 | 13 | def __init__(self): 14 | self._tasks: Set[asyncio.Task] = set() 15 | self._on_completed_fut: Optional[asyncio.Future] = None 16 | self._cancelling: bool = False 17 | 18 | async def __aenter__(self): 19 | return self 20 | 21 | async def __aexit__(self, exc_type, exc, tb): 22 | 23 | cancelled_error = ( 24 | exc if isinstance(exc, exceptions.CancelledError) else None 25 | ) 26 | 27 | # If the parent task has thrown an exception, cancel all the tasks 28 | if exc_type is not None: 29 | self._cancel_tasks() 30 | 31 | while self._tasks: 32 | if self._on_completed_fut is None: 33 | self._on_completed_fut = asyncio.Future() 34 | 35 | # If the parent task was cancelled, cancel all the tasks 36 | try: 37 | await self._on_completed_fut 38 | except exceptions.CancelledError as ex: 39 | cancelled_error = ex 40 | self._cancel_tasks() 41 | 42 | self._on_completed_fut = None 43 | 44 | if cancelled_error: 45 | raise cancelled_error 46 | 47 | def create_task(self, coro): 48 | task = asyncio.create_task(coro) 49 | task.add_done_callback(self._on_task_done) 50 | self._tasks.add(task) 51 | return task 52 | 53 | def _cancel_tasks(self): 54 | if not self._cancelling: 55 | self._cancelling = True 56 | for t in self._tasks: 57 | if not t.done(): 58 | t.cancel() 59 | 60 | def _on_task_done(self, task): 61 | self._tasks.discard(task) 62 | 63 | if ( 64 | self._on_completed_fut is not None 65 | and not self._on_completed_fut.done() 66 | and not self._tasks 67 | ): 68 | self._on_completed_fut.set_result(True) 69 | 70 | # If any of the tasks was cancelled, cancel all the tasks 71 | if task.exception() is not None: 72 | self._cancel_tasks() 73 | -------------------------------------------------------------------------------- /tests/test_chat_completion_validation.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import List 3 | 4 | import pytest 5 | 6 | from aidial_sdk import DIALApp 7 | from 
aidial_sdk.chat_completion import ChatCompletion, Request, Response 8 | from tests.utils.endpoint_test import TestCase, run_endpoint_test 9 | from tests.utils.errors import bad_request_error, internal_server_error 10 | 11 | 12 | class App(ChatCompletion): 13 | async def chat_completion( 14 | self, request: Request, response: Response 15 | ) -> None: 16 | with response.create_single_choice() as choice: 17 | choice.add_attachment(data="xxx", url="yyy") 18 | 19 | 20 | VALID_REQUEST = {"messages": [{"role": "user", "content": "test"}]} 21 | INVALID_ATTACHMENT_BOTH = { 22 | "messages": [ 23 | { 24 | "role": "user", 25 | "content": "test", 26 | "custom_content": {"attachments": [{"data": "xxx", "url": "yyy"}]}, 27 | } 28 | ] 29 | } 30 | INVALID_ATTACHMENT_NEITHER = { 31 | "messages": [ 32 | { 33 | "role": "user", 34 | "content": "test", 35 | "custom_content": {"attachments": [{"title": "title"}]}, 36 | } 37 | ] 38 | } 39 | 40 | 41 | deployment = "test-app" 42 | 43 | noop = DIALApp().add_chat_completion(deployment, App()) 44 | 45 | 46 | testcases: List[TestCase] = [ 47 | TestCase( 48 | noop, 49 | deployment, 50 | "chat/completions", 51 | VALID_REQUEST, 52 | internal_server_error("Error during processing the request"), 53 | ), 54 | TestCase( 55 | noop, 56 | deployment, 57 | "chat/completions", 58 | INVALID_ATTACHMENT_BOTH, 59 | bad_request_error( 60 | re.compile( 61 | r"Your request contained invalid structure on path messages.0.custom_content.attachments.0\..* Attachment must have either 'data' or 'url', but it has both" 62 | ) 63 | ), 64 | ), 65 | TestCase( 66 | noop, 67 | deployment, 68 | "chat/completions", 69 | INVALID_ATTACHMENT_NEITHER, 70 | bad_request_error( 71 | re.compile( 72 | r"Your request contained invalid structure on path messages.0.custom_content.attachments.0\..* Attachment must have either 'data' or 'url', but it's missing both" 73 | ) 74 | ), 75 | ), 76 | ] 77 | 78 | 79 | @pytest.mark.parametrize("testcase", testcases) 80 | def 
test_chat_completion_validation(testcase: TestCase): 81 | run_endpoint_test(testcase) 82 | -------------------------------------------------------------------------------- /tests/test_tokenize.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import pytest 4 | 5 | from aidial_sdk import DIALApp 6 | from aidial_sdk._pydantic import PYDANTIC_V2 7 | from tests.applications.echo import EchoApplication 8 | from tests.applications.noop import NoopApplication 9 | from tests.utils.endpoint_test import TestCase, run_endpoint_test 10 | from tests.utils.errors import missing_fields_error, route_not_found_error 11 | 12 | CHAT_COMPLETION_REQUEST = { 13 | "messages": [ 14 | {"role": "system", "content": "system"}, 15 | {"role": "user", "content": "ping"}, 16 | {"role": "assistant", "content": "pong"}, 17 | {"role": "user", "content": "hello"}, 18 | ], 19 | } 20 | 21 | TOKENIZE_REQUEST_OK1 = { 22 | "inputs": [ 23 | {"type": "request", "value": CHAT_COMPLETION_REQUEST}, 24 | {"type": "string", "value": "test string"}, 25 | ] 26 | } 27 | TOKENIZE_RESPONSE_OK1 = { 28 | "outputs": [ 29 | {"status": "success", "token_count": 4}, 30 | {"status": "success", "token_count": 2}, 31 | ] 32 | } 33 | 34 | TOKENIZE_REQUEST_OK2 = {"inputs": []} 35 | TOKENIZE_RESPONSE_OK2 = {"outputs": []} 36 | 37 | TOKENIZE_REQUEST_FAIL = {"inputs": [{}]} 38 | 39 | 40 | deployment = "test-app" 41 | 42 | noop = DIALApp().add_chat_completion(deployment, NoopApplication()) 43 | 44 | echo = DIALApp().add_chat_completion(deployment, EchoApplication(0)) 45 | 46 | 47 | testcases: List[TestCase] = [ 48 | TestCase( 49 | noop, 50 | deployment, 51 | "tokenize", 52 | TOKENIZE_REQUEST_OK1, 53 | route_not_found_error, 54 | ), 55 | TestCase( 56 | noop, 57 | deployment, 58 | "tokenizer", 59 | TOKENIZE_REQUEST_OK1, 60 | route_not_found_error, 61 | ), 62 | TestCase( 63 | echo, 64 | deployment, 65 | "tokenize", 66 | TOKENIZE_REQUEST_OK1, 67 | 
TOKENIZE_RESPONSE_OK1, 68 | ), 69 | TestCase( 70 | echo, 71 | deployment, 72 | "tokenize", 73 | TOKENIZE_REQUEST_OK2, 74 | TOKENIZE_RESPONSE_OK2, 75 | ), 76 | TestCase( 77 | echo, 78 | deployment, 79 | "tokenize", 80 | TOKENIZE_REQUEST_FAIL, 81 | ( 82 | # NOTE: https://github.com/pydantic/pydantic/issues/7261 83 | missing_fields_error("inputs.0.TokenizeInputRequest.value") 84 | if PYDANTIC_V2 85 | else missing_fields_error("inputs.0.value") 86 | ), 87 | ), 88 | ] 89 | 90 | 91 | @pytest.mark.parametrize("testcase", testcases) 92 | def test_tokenize(testcase: TestCase): 93 | run_endpoint_test(testcase) 94 | -------------------------------------------------------------------------------- /aidial_sdk/_pydantic/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | The entry point for all Pydantic definition that unifies v1 and v2 APIs for SDK internals. 3 | 4 | It's private, since its's expected that the client of the SDK 5 | will either import `aidial_sdk.pydantic_v1` or `pydantic`. 6 | 7 | This is the only place where `pydantic` imports 8 | are allowed in the DIAL SDK package. 
9 | """ 10 | 11 | from typing import TYPE_CHECKING 12 | 13 | from pydantic import VERSION 14 | 15 | from aidial_sdk.utils.env import env_bool 16 | 17 | INSTALLED_PYDANTIC_V2 = VERSION.startswith("2.") 18 | USE_PYDANTIC_V2 = env_bool("PYDANTIC_V2", False) 19 | PYDANTIC_V2 = INSTALLED_PYDANTIC_V2 and USE_PYDANTIC_V2 20 | 21 | if TYPE_CHECKING: 22 | from pydantic import BaseModel 23 | from pydantic import ConfigDict as ConfigDict 24 | from pydantic import ( 25 | Field, 26 | PositiveInt, 27 | SecretStr, 28 | StrictBool, 29 | StrictInt, 30 | StrictStr, 31 | ValidationError, 32 | ) 33 | from pydantic import field_validator as validator 34 | from pydantic import model_validator 35 | from pydantic._internal._model_construction import ModelMetaclass 36 | from pydantic.fields import FieldInfo 37 | from pydantic.v1.validators import make_literal_validator 38 | else: 39 | 40 | if PYDANTIC_V2: 41 | from pydantic import ( 42 | BaseModel, 43 | ConfigDict, 44 | Field, 45 | PositiveInt, 46 | SecretStr, 47 | StrictBool, 48 | StrictInt, 49 | StrictStr, 50 | ValidationError, 51 | ) 52 | from pydantic import field_validator as validator 53 | from pydantic import model_validator 54 | from pydantic._internal._model_construction import ModelMetaclass 55 | from pydantic.fields import FieldInfo 56 | from pydantic.v1.validators import make_literal_validator 57 | else: 58 | from pydantic.v1 import ( 59 | BaseModel, 60 | Field, 61 | PositiveInt, 62 | SecretStr, 63 | StrictBool, 64 | StrictInt, 65 | StrictStr, 66 | validator, 67 | ) 68 | 69 | try: 70 | from pydantic.v1.main import ModelMetaclass 71 | except ImportError: 72 | from pydantic.main import ModelMetaclass 73 | from pydantic.v1 import ValidationError, root_validator 74 | from pydantic.v1.fields import FieldInfo 75 | from pydantic.v1.validators import make_literal_validator 76 | 77 | def _fail(*args, **kwargs): 78 | raise ImportError("ConfigDict is only supported in Pydantic v2") 79 | 80 | ConfigDict = _fail 81 | 
-------------------------------------------------------------------------------- /tests/test_discarded_messages.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from aidial_sdk import HTTPException 4 | from aidial_sdk.chat_completion import ChatCompletion, Request, Response 5 | from tests.utils.chunks import check_sse_stream, create_single_choice_chunk 6 | from tests.utils.client import create_app_client 7 | from tests.utils.constants import DUMMY_DIAL_REQUEST 8 | 9 | DISCARDED_MESSAGES = list(range(0, 12)) 10 | 11 | 12 | def test_discarded_messages_returned(): 13 | class _Impl(ChatCompletion): 14 | async def chat_completion( 15 | self, request: Request, response: Response 16 | ) -> None: 17 | with response.create_single_choice(): 18 | pass 19 | response.set_discarded_messages(DISCARDED_MESSAGES) 20 | 21 | client = create_app_client(_Impl()) 22 | 23 | response = client.post( 24 | "chat/completions", 25 | json={"messages": [{"role": "user", "content": "Test"}]}, 26 | ) 27 | 28 | assert ( 29 | response.json()["statistics"]["discarded_messages"] 30 | == DISCARDED_MESSAGES 31 | ) 32 | 33 | 34 | def test_discarded_messages_returned_as_last_chunk_in_stream(): 35 | class _Impl(ChatCompletion): 36 | async def chat_completion( 37 | self, request: Request, response: Response 38 | ) -> None: 39 | response.set_response_id("test_id") 40 | response.set_created(0) 41 | 42 | with response.create_single_choice(): 43 | pass 44 | 45 | response.set_discarded_messages(DISCARDED_MESSAGES) 46 | 47 | client = create_app_client(_Impl()) 48 | 49 | response = client.post( 50 | "chat/completions", 51 | json={ 52 | "messages": [{"role": "user", "content": "Test content"}], 53 | "stream": True, 54 | }, 55 | ) 56 | 57 | check_sse_stream( 58 | response.iter_lines(), 59 | [ 60 | create_single_choice_chunk(delta={"role": "assistant"}), 61 | create_single_choice_chunk( 62 | delta={}, 63 | finish_reason="stop", 64 | 
statistics={"discarded_messages": DISCARDED_MESSAGES}, 65 | ), 66 | ], 67 | ) 68 | 69 | 70 | def test_discarded_messages_is_set_twice(): 71 | response = Response(DUMMY_DIAL_REQUEST) 72 | 73 | with response.create_single_choice(): 74 | pass 75 | 76 | response.set_discarded_messages(DISCARDED_MESSAGES) 77 | 78 | with pytest.raises(HTTPException): 79 | response.set_discarded_messages(DISCARDED_MESSAGES) 80 | 81 | 82 | def test_discarded_messages_is_set_before_choice(): 83 | response = Response(DUMMY_DIAL_REQUEST) 84 | 85 | with pytest.raises(HTTPException): 86 | response.set_discarded_messages(DISCARDED_MESSAGES) 87 | -------------------------------------------------------------------------------- /tests/test_request_indices.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from typing import List 3 | 4 | import pytest 5 | 6 | from aidial_sdk._pydantic import BaseModel, ValidationError 7 | from aidial_sdk.chat_completion import ( 8 | Attachment, 9 | FunctionCall, 10 | Status, 11 | ToolCall, 12 | ) 13 | from aidial_sdk.chat_completion.request import Stage 14 | from tests.utils.pydantic import model_dump, model_parse 15 | 16 | 17 | @dataclasses.dataclass 18 | class TestCase: 19 | __test__ = False 20 | obj: BaseModel 21 | dct: dict 22 | 23 | def get_id(self) -> str: 24 | return type(self.obj).__name__ 25 | 26 | 27 | _test_cases: List[TestCase] = [ 28 | TestCase( 29 | ToolCall( 30 | id="tool-call-id", 31 | type="function", 32 | function=FunctionCall(name="func-name", arguments="{}"), 33 | ), 34 | { 35 | "id": "tool-call-id", 36 | "type": "function", 37 | "function": {"name": "func-name", "arguments": "{}"}, 38 | }, 39 | ), 40 | TestCase( 41 | Attachment(type="text/plain", data="test"), 42 | {"type": "text/plain", "data": "test"}, 43 | ), 44 | TestCase( 45 | Stage(name="Testing", status=Status.COMPLETED, content="test"), 46 | {"name": "Testing", "status": "completed", "content": "test"}, 47 | ), 48 | ] 49 | 50 | 
51 | @pytest.fixture(params=_test_cases, ids=lambda x: x.get_id()) 52 | def test_case(request) -> TestCase: 53 | return request.param 54 | 55 | 56 | def _check_ser_deser(obj: BaseModel): 57 | dct = model_dump(obj) 58 | obj2 = model_parse(type(obj), dct, allow_extra_fields=False) 59 | assert obj == obj2 60 | 61 | 62 | def test_index_field_ser_deser(test_case: TestCase): 63 | _check_ser_deser(test_case.obj) 64 | 65 | 66 | def test_index_field_ignore_int(test_case: TestCase): 67 | obj = model_parse( 68 | type(test_case.obj), 69 | {**test_case.dct, **{"index": 101}}, 70 | allow_extra_fields=False, 71 | ) 72 | _check_ser_deser(obj) 73 | 74 | 75 | def test_index_field_fail_on_str(test_case: TestCase): 76 | with pytest.raises( 77 | ValidationError, 78 | match=r"(Extra inputs are not permitted|extra fields not permitted)", 79 | ): 80 | model_parse( 81 | type(test_case.obj), 82 | {**test_case.dct, **{"index": "value"}}, 83 | allow_extra_fields=False, 84 | ) 85 | 86 | 87 | def test_index_field_fail_on_extra_fields(test_case: TestCase): 88 | with pytest.raises( 89 | ValidationError, 90 | match=r"(Extra inputs are not permitted|extra fields not permitted)", 91 | ): 92 | model_parse( 93 | type(test_case.obj), 94 | {**test_case.dct, **{"index2": "whatever"}}, 95 | allow_extra_fields=False, 96 | ) 97 | -------------------------------------------------------------------------------- /tests/test_response_headers.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from aidial_sdk.chat_completion import ChatCompletion, Request, Response 4 | from tests.utils.client import create_app_client 5 | 6 | 7 | class _BeforeGeneration(ChatCompletion): 8 | async def chat_completion( 9 | self, request: Request, response: Response 10 | ) -> None: 11 | response.append_header("header1", "value1") 12 | response.append_header("header2", "value2-1") 13 | response.append_header("header2", "value2-2") 14 | 15 | with 
response.create_single_choice() as choice: 16 | await response.aflush() 17 | choice.append_content("hello world") 18 | 19 | 20 | class _AfterGeneration(ChatCompletion): 21 | async def chat_completion( 22 | self, request: Request, response: Response 23 | ) -> None: 24 | with response.create_single_choice() as choice: 25 | choice.append_content("hello world") 26 | await response.aflush() 27 | 28 | response.append_header("header1", "value1") 29 | response.append_header("header2", "value2-1") 30 | response.append_header("header2", "value2-2") 31 | 32 | 33 | @pytest.mark.parametrize("stream", [False, True]) 34 | def test_append_response_header_before_generation(stream: bool): 35 | client = create_app_client(_BeforeGeneration()) 36 | 37 | response = client.post( 38 | "chat/completions", 39 | json={ 40 | "messages": [{"role": "user", "content": "test"}], 41 | "stream": stream, 42 | }, 43 | ) 44 | 45 | assert response.status_code == 200 46 | assert response.headers.get("missing-header") is None 47 | 48 | assert response.headers.get("header1") == "value1" 49 | assert response.headers.get("header2") == "value2-1, value2-2" 50 | 51 | 52 | @pytest.mark.parametrize("stream", [False, True]) 53 | def test_append_response_header_after_generation(stream: bool, caplog): 54 | client = create_app_client(_AfterGeneration()) 55 | 56 | response = client.post( 57 | "chat/completions", 58 | json={ 59 | "messages": [{"role": "user", "content": "test"}], 60 | "stream": stream, 61 | }, 62 | ) 63 | 64 | assert response.headers.get("missing-header") is None 65 | assert response.status_code == 200 66 | 67 | if stream: 68 | assert response.headers.get("header1") is None 69 | assert response.headers.get("header2") is None 70 | 71 | assert "Trying to set a header after start of generation" in caplog.text 72 | assert ( 73 | 'data: {"error":{"message":"Error during processing the request","type":"runtime_error","code":"500"}}' 74 | in response.text 75 | ) 76 | else: 77 | assert 
response.headers.get("header1") == "value1" 78 | assert response.headers.get("header2") == "value2-1, value2-2" 79 | -------------------------------------------------------------------------------- /tests/test_tool_calling.py: -------------------------------------------------------------------------------- 1 | from aidial_sdk.chat_completion import ChatCompletion, Request, Response 2 | from tests.utils.chunks import ( 3 | check_sse_stream, 4 | create_single_choice_chunk, 5 | create_tool_call_chunk, 6 | ) 7 | from tests.utils.client import create_app_client 8 | 9 | 10 | class ToolCaller(ChatCompletion): 11 | async def chat_completion( 12 | self, request: Request, response: Response 13 | ) -> None: 14 | response.set_response_id("test_id") 15 | response.set_created(0) 16 | 17 | with response.create_single_choice() as choice: 18 | choice.append_content("Test content") 19 | 20 | tool_call1 = choice.create_function_tool_call( 21 | "tool_call_id1", "tool_name" 22 | ) 23 | tool_call1.append_arguments('{"key') 24 | tool_call1.append_arguments('":"') 25 | tool_call1.append_arguments('val"}') 26 | 27 | choice.create_function_tool_call( 28 | "tool_call_id2", "tool_name", '{"foo":"bar"}' 29 | ) 30 | 31 | 32 | def test_tool_call_non_streaming(): 33 | response = create_app_client(ToolCaller()).post( 34 | "chat/completions", json={"messages": [], "stream": False} 35 | ) 36 | 37 | body = response.json() 38 | assert body["choices"][0]["message"]["tool_calls"] == [ 39 | { 40 | "id": "tool_call_id1", 41 | "type": "function", 42 | "function": { 43 | "name": "tool_name", 44 | "arguments": '{"key":"val"}', 45 | }, 46 | }, 47 | { 48 | "id": "tool_call_id2", 49 | "type": "function", 50 | "function": { 51 | "name": "tool_name", 52 | "arguments": '{"foo":"bar"}', 53 | }, 54 | }, 55 | ] 56 | 57 | 58 | def test_tool_call_streaming(): 59 | response = create_app_client(ToolCaller()).post( 60 | "chat/completions", json={"messages": [], "stream": True} 61 | ) 62 | 63 | check_sse_stream( 64 | 
def create_client(app_instance: ChatCompletion):
    """Host *app_instance* in a DIALApp under the "echo" deployment and
    return a TestClient pointed at that deployment with a test API key.
    """
    dial_app = DIALApp()
    # add_chat_completion returns the app itself (fluent API).
    dial_app.add_chat_completion("echo", app_instance)
    return TestClient(
        dial_app,
        base_url="https://testserver/openai/deployments/echo",
        headers={"Api-Key": "test-api-key"},
    )
def test_bearer_token_is_removed_from_headers_forwarding():
    """Auth-related headers must be stripped from ``request.headers``
    before the request reaches the application.
    """

    class InspectHeadersApp(TokenEchoApp):
        async def chat_completion(self, request: Request, response: Response):
            # Neither auth header may be forwarded to the application.
            assert "Authorization" not in request.headers
            assert "Api-Key" not in request.headers
            await super().chat_completion(request, response)

    response = create_client(InspectHeadersApp()).post(
        "chat/completions",
        json={"messages": [{"role": "user", "content": "hi"}]},
        headers={"Authorization": "Bearer tok", "Api-Key": "ignored"},
    )

    assert response.status_code == 200
@pytest.mark.slow
@pytest.mark.parametrize(
    "stream", [False, True], ids=lambda b: "stream" if b else "block"
)
async def test_disconnect(
    stream: bool,
    non_empty_stream: bool,
    chat_completion: WaitingApp,
    test_http_client: httpx.AsyncClient,
):
    """Dropping the client connection must cancel the application task.

    Only a streaming request over a non-empty choice stream sends bytes
    early enough for the disconnect scenario to run; in every other
    combination the server stays silent and the client read times out.
    """
    if stream and non_empty_stream:
        await run_disconnect_test(stream, chat_completion, test_http_client)
        return

    # No early bytes are produced, so the 1s read timeout fires first.
    with pytest.raises(httpx.ReadTimeout):
        await run_disconnect_test(stream, chat_completion, test_http_client)
bool, 86 | chat_completion: WaitingApp, 87 | test_http_client: httpx.AsyncClient, 88 | ): 89 | async with test_http_client.stream( 90 | "POST", 91 | "/chat/completions", 92 | json={ 93 | "messages": [{"role": "user", "content": "hello"}], 94 | "stream": stream, 95 | }, 96 | timeout=1, 97 | ) as response: 98 | await asyncio.wait_for(chat_completion.started.wait(), timeout=5) 99 | 100 | # Emulate client disconnect by closing the socket 101 | await response.aclose() 102 | 103 | await asyncio.wait_for(chat_completion.cancelled.wait(), timeout=5) 104 | assert chat_completion.is_cancelled 105 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "aidial-sdk" 3 | version = "0.31.0rc" 4 | description = "Framework to create applications and model adapters for AI DIAL" 5 | authors = [{ name = "EPAM RAIL", email = "SpecialEPM-DIALDevTeam@epam.com" }] 6 | license = "Apache-2.0" 7 | readme = "README.md" 8 | requires-python = ">=3.9,<4.0" 9 | dependencies = [ 10 | "fastapi (>=0.51,<1.0)", 11 | "uvicorn (>=0.19,<1.0)", 12 | "pydantic (>=1.10.17,<3)", 13 | "wrapt (>=1.10,<2)", 14 | ] 15 | 16 | [project.urls] 17 | Homepage = "https://epam-rail.com" 18 | Documentation = "https://epam-rail.com/dial_api" 19 | Repository = "https://github.com/epam/ai-dial-sdk" 20 | 21 | [project.optional-dependencies] 22 | telemetry = [ 23 | "opentelemetry-sdk (>=1.22.0,<2.0)", 24 | "opentelemetry-api (>=1.22.0,<2.0)", 25 | "opentelemetry-exporter-otlp-proto-grpc (>=1.22.0,<2.0)", 26 | "opentelemetry-instrumentation-aiohttp-client (>=0.43b0)", 27 | "opentelemetry-instrumentation-fastapi (>=0.43b0)", 28 | "opentelemetry-instrumentation-httpx (>=0.43b0)", 29 | "opentelemetry-instrumentation-logging (>=0.43b0)", 30 | "opentelemetry-instrumentation-requests (>=0.43b0)", 31 | "opentelemetry-instrumentation-system-metrics (>=0.43b0)", 32 | 
"opentelemetry-instrumentation-urllib (>=0.43b0)", 33 | "opentelemetry-exporter-prometheus (>=0.43b0)", 34 | "prometheus-client (>=0.17.1,<=0.21)", 35 | ] 36 | httpx = ["httpx (>=0.25.0,<1)"] 37 | 38 | [tool.poetry.group.test.dependencies] 39 | pytest = "^8.2" 40 | pytest-asyncio = "^0.24.0" 41 | nox = "^2023.4.22" 42 | pillow = [ 43 | {version = "11.1.0", python = ">=3.12"}, 44 | {version = "10.3.0", python = "<3.12"} 45 | ] 46 | httpx = "^0.25.0" 47 | respx = "^0.21.1" 48 | aiohttp = "^3.12.14" 49 | aioresponses = "^0.7.6" 50 | requests = "^2.32" 51 | responses = "^0.25.3" 52 | 53 | [tool.poetry.group.lint.dependencies] 54 | flake8 = "^6.0.0" 55 | black = ">=23.3,<25.0" 56 | isort = "^5.12.0" 57 | pyright = "1.1.385" 58 | autoflake = "^2.2.0" 59 | 60 | [build-system] 61 | requires = ["poetry-core>=1.0.0"] 62 | build-backend = "poetry.core.masonry.api" 63 | 64 | [tool.pytest.ini_options] 65 | asyncio_default_fixture_loop_scope = "function" 66 | addopts = "--asyncio-mode=auto" 67 | testpaths = [ 68 | "tests" 69 | ] 70 | filterwarnings = [ 71 | "error", 72 | # muting deprecation warnings coming from old pydantic 73 | "ignore::DeprecationWarning:pydantic.typing", 74 | ] 75 | markers = [ 76 | "slow: marker for slow running tests", 77 | ] 78 | 79 | [tool.pyright] 80 | typeCheckingMode = "basic" 81 | reportDeprecated = "error" 82 | reportUnusedVariable = "error" 83 | reportIncompatibleMethodOverride = "error" 84 | exclude = [ 85 | ".git", 86 | "**/.venv", 87 | ".nox", 88 | ".pytest_cache", 89 | "**/__pycache__", 90 | "build", 91 | "examples/langchain_rag" 92 | ] 93 | 94 | [tool.black] 95 | line-length = 80 96 | 97 | [tool.isort] 98 | line_length = 80 99 | profile = "black" 100 | 101 | [tool.autoflake] 102 | ignore_init_module_imports = true 103 | remove_all_unused_imports = true 104 | in_place = true 105 | recursive = true 106 | quiet = true 107 | exclude = [ 108 | ".nox", 109 | ".pytest_cache", 110 | "\\.venv", 111 | "_pydantic.py" 112 | ] 113 | 
-------------------------------------------------------------------------------- /examples/render_text/app/main.py: -------------------------------------------------------------------------------- 1 | """ 2 | A simple text-to-image DIAL application. 3 | Takes the last message, rasterizes the text and 4 | sends the image back to the user in an attachment. 5 | """ 6 | 7 | import os 8 | 9 | import uvicorn 10 | 11 | from aidial_sdk import DIALApp 12 | from aidial_sdk import HTTPException as DIALException 13 | from aidial_sdk.chat_completion import ChatCompletion, Request, Response 14 | 15 | from .image import text_to_image_base64, upload_png_image 16 | 17 | DIAL_URL = os.environ.get("DIAL_URL") 18 | 19 | 20 | # ChatCompletion is an abstract class for applications and model adapters 21 | class RenderTextApplication(ChatCompletion): 22 | async def chat_completion(self, request: Request, response: Response): 23 | # Create a single choice 24 | with response.create_single_choice() as choice: 25 | # Get the last message content 26 | content = request.messages[-1].text() 27 | 28 | # The image may be returned either as base64 string or as URL 29 | # The content specifies the mode of return: 'base64' or 'url' 30 | try: 31 | command, text = content.split(",", 1) 32 | if command not in ["base64", "url"]: 33 | raise DIALException( 34 | message="The command must be either 'base64' or 'url'", 35 | status_code=422, 36 | ) 37 | except ValueError: 38 | raise DIALException( 39 | message="The content must be in the format '(base64|url),'", 40 | status_code=422, 41 | ) 42 | 43 | # Rasterize the user message to an image 44 | image_base64 = text_to_image_base64(text) 45 | image_type = "image/png" 46 | 47 | # Add the image as an attachment 48 | if command == "base64": 49 | # As base64 string 50 | choice.add_attachment( 51 | type=image_type, title="Image", data=image_base64 52 | ) 53 | else: 54 | # As URL to DIAL File storage 55 | if DIAL_URL is None: 56 | # DIAL SDK automatically converts 
def create_chunk(
    *,
    id: str = "test_id",
    model: Optional[str] = None,
    created: int = 0,
    choices: List[dict],
    usage: Optional[dict] = None,
    **kwargs,
):
    """Build a ``chat.completion.chunk`` dict.

    ``model`` is included only when provided, while ``usage`` is always
    present (``None`` by default). Extra keyword arguments are merged in
    last and may override any standard field's value.
    """
    chunk: dict = {"id": id}
    if model is not None:
        chunk["model"] = model
    chunk["created"] = created
    chunk["object"] = "chat.completion.chunk"
    chunk["choices"] = choices
    chunk["usage"] = usage
    chunk.update(kwargs)
    return chunk
create_chunk(choices=[choice], **kwargs) 42 | 43 | 44 | def create_tool_call_chunk( 45 | idx: int, 46 | *, 47 | type: Optional[Literal["function"]] = None, 48 | id: Optional[str] = None, 49 | name: Optional[str] = None, 50 | arguments: Optional[str] = None, 51 | ): 52 | return create_single_choice_chunk( 53 | delta={ 54 | "content": None, 55 | "tool_calls": [ 56 | remove_nones( 57 | { 58 | "index": idx, 59 | "id": id, 60 | "type": type, 61 | "function": remove_nones( 62 | {"name": name, "arguments": arguments} 63 | ), 64 | } 65 | ) 66 | ], 67 | } 68 | ) 69 | 70 | 71 | def create_function_call_chunk( 72 | *, 73 | name: Optional[str] = None, 74 | arguments: Optional[str] = None, 75 | ): 76 | return create_single_choice_chunk( 77 | delta={ 78 | "content": None, 79 | "function_call": remove_nones( 80 | {"name": name, "arguments": arguments} 81 | ), 82 | } 83 | ) 84 | 85 | 86 | def _check_sse_line(actual: str, expected: Union[str, dict]): 87 | if isinstance(expected, str): 88 | assert actual == expected 89 | return 90 | 91 | assert actual.startswith("data: "), f"Invalid data SSE entry: {actual!r}" 92 | actual = actual[len("data: ") :] 93 | 94 | try: 95 | actual_dict = json.loads(actual) 96 | except json.JSONDecodeError: 97 | raise AssertionError(f"Invalid JSON in data SSE entry: {actual!r}") 98 | 99 | assert actual_dict == expected 100 | 101 | 102 | ExpectedSSEStream = Iterable[Union[str, dict]] 103 | 104 | 105 | def check_sse_stream( 106 | actual: Iterable[str], expected: ExpectedSSEStream 107 | ) -> bool: 108 | expected = itertools.chain(expected, ["data: [DONE]"]) 109 | expected = itertools.chain.from_iterable((line, "") for line in expected) 110 | 111 | sentinel = object() 112 | for a_line, e_obj in itertools.zip_longest( 113 | actual, expected, fillvalue=sentinel 114 | ): 115 | assert ( 116 | a_line is not sentinel 117 | ), "The list of actual values is shorter than the list of expected values" 118 | assert ( 119 | e_obj is not sentinel 120 | ), "The list of 
class Counter:
    """Async-safe pair of counters: how many awaited tasks completed
    normally (``done``) vs. were cancelled (``cancelled``).
    """

    done: int = 0
    cancelled: int = 0
    _lock: asyncio.Lock

    def __init__(self) -> None:
        # One lock guards both counters.
        self._lock = asyncio.Lock()

    async def _bump(self, field: str) -> None:
        # Increment the named counter under the lock.
        async with self._lock:
            setattr(self, field, getattr(self, field) + 1)

    async def inc_done(self):
        await self._bump("done")

    async def inc_cancelled(self):
        await self._bump("cancelled")
@pytest.mark.slow
@pytest.mark.parametrize("with_heartbeat", [True, False])
@pytest.mark.parametrize(
    "chat_completion, expected_cancelled, expected_done",
    [
        (chat_completion_wait_forever, 1, 0),
        (chat_completion_gather, 10, 0),
        (chat_completion_create_task, 0, 10),
    ],
)
async def test_cancellation(
    with_heartbeat: bool, chat_completion, expected_cancelled, expected_done
):
    """Cancelling the consumption of the response stream must cancel the
    coroutines awaited by the chat completion handler (but not detached
    background tasks created via asyncio.create_task).
    """

    response = ChatCompletionResponse(DUMMY_DIAL_REQUEST)

    # Instantiate the handler factory with a shared counter that records
    # how many of its inner waits finished vs. got cancelled.
    counter = Counter()
    chat_completion = chat_completion(counter)

    async def _exhaust_stream(stream):
        # Drain the stream; we only care about side effects on the counter.
        async for _ in stream:
            pass

    try:
        stream = response._generate_stream(chat_completion)
        if with_heartbeat:
            # Heartbeats must not shield the handler from cancellation.
            stream = add_heartbeat(
                stream,
                heartbeat_interval=0.2,
                heartbeat_object=": heartbeat\n\n",
            )

        # The handlers wait (effectively) forever, so this must time out,
        # which cancels the stream-consuming task.
        await asyncio.wait_for(_exhaust_stream(stream), timeout=2)
    except asyncio.TimeoutError:
        pass
    else:
        assert False, "Stream should have timed out"

    # Give detached background tasks (create_task case) time to finish.
    await asyncio.sleep(2)

    assert (
        counter.cancelled == expected_cancelled
        and counter.done == expected_done
    ), "Stream should have been cancelled"
6 | 7 | It proves to be useful since 8 | 1. `model_dump` method is used extensively in the SDK, 9 | 2. the SDK client may call this method on SDK models even if the client uses Pydantic V1. 10 | """ 11 | 12 | from datetime import date, datetime 13 | from typing import Any, Dict, Iterable, Mapping, Optional, Set, Union, cast 14 | 15 | from typing_extensions import Literal 16 | 17 | import aidial_sdk._pydantic as pydantic 18 | from aidial_sdk._pydantic import PYDANTIC_V2 19 | 20 | _IncEx = Union[Set[int], Set[str], Dict[int, Any], Dict[str, Any], None] 21 | 22 | 23 | class BaseModel(pydantic.BaseModel): 24 | if not PYDANTIC_V2: 25 | # we define aliases for some of the new pydantic v2 methods so 26 | # that we can just document these methods without having to specify 27 | # a specific pydantic version as some users may not know which 28 | # pydantic version they are currently using 29 | 30 | def model_dump( 31 | self, 32 | *, 33 | mode: Union[Literal["json", "python"], str] = "python", 34 | include: _IncEx = None, 35 | exclude: _IncEx = None, 36 | by_alias: bool = False, 37 | exclude_unset: bool = False, 38 | exclude_defaults: bool = False, 39 | exclude_none: bool = False, 40 | round_trip: bool = False, 41 | warnings: Union[bool, Literal["none", "warn", "error"]] = True, 42 | context: Optional[Dict[str, Any]] = None, 43 | serialize_as_any: bool = False, 44 | ) -> Dict[str, Any]: 45 | if mode not in {"json", "python"}: 46 | raise ValueError("mode must be either 'json' or 'python'") 47 | if round_trip is not False: 48 | raise ValueError("round_trip is only supported in Pydantic v2") 49 | if warnings is not True: 50 | raise ValueError("warnings is only supported in Pydantic v2") 51 | if context is not None: 52 | raise ValueError("context is only supported in Pydantic v2") 53 | if serialize_as_any is not False: 54 | raise ValueError( 55 | "serialize_as_any is only supported in Pydantic v2" 56 | ) 57 | dumped = super().dict( # pyright: ignore[reportDeprecated] 58 | 
include=include, 59 | exclude=exclude, 60 | by_alias=by_alias, 61 | exclude_unset=exclude_unset, 62 | exclude_defaults=exclude_defaults, 63 | exclude_none=exclude_none, 64 | ) 65 | 66 | return ( 67 | cast(Dict[str, Any], _json_safe(dumped)) 68 | if mode == "json" 69 | else dumped 70 | ) 71 | 72 | 73 | def _json_safe(data: object) -> object: 74 | """Translates a mapping / sequence recursively in the same fashion 75 | as `pydantic` v2's `model_dump(mode="json")`. 76 | """ 77 | if isinstance(data, Mapping): 78 | return { 79 | _json_safe(key): _json_safe(value) for key, value in data.items() 80 | } 81 | 82 | if isinstance(data, Iterable) and not isinstance( 83 | data, (str, bytes, bytearray) 84 | ): 85 | return [_json_safe(item) for item in data] 86 | 87 | if isinstance(data, (datetime, date)): 88 | return data.isoformat() 89 | 90 | return data 91 | 92 | 93 | def model_validator(*, mode: Literal["before", "after"]) -> Any: 94 | if PYDANTIC_V2: 95 | return pydantic.model_validator(mode=mode) 96 | else: 97 | return pydantic.root_validator(pre=(mode == "before")) # type: ignore 98 | -------------------------------------------------------------------------------- /tests/test_truncate_prompt.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import pytest 4 | 5 | from aidial_sdk import DIALApp 6 | from tests.applications.echo import EchoApplication 7 | from tests.applications.noop import NoopApplication 8 | from tests.utils.endpoint_test import TestCase, run_endpoint_test 9 | from tests.utils.errors import route_not_found_error 10 | 11 | CHAT_COMPLETION_REQUEST = { 12 | "messages": [ 13 | {"role": "system", "content": "system"}, 14 | {"role": "user", "content": "ping"}, 15 | {"role": "assistant", "content": "pong"}, 16 | {"role": "user", "content": "hello"}, 17 | ], 18 | } 19 | 20 | 21 | def create_request(max_prompt_tokens: Optional[int]): 22 | return { 23 | "inputs": [ 24 | { 25 | 
def create_response(
    model_max_prompt_tokens: int, max_prompt_tokens: Optional[int]
):
    """Expected truncate_prompt response for the 4-message fixture request.

    The fixture costs 4 tokens total; the last user message plus the system
    message cost 2, so a limit of 1 is unsatisfiable, 2 discards messages
    1 and 2, 3 discards message 1, and 4+ discards nothing.
    """

    def _success(discarded):
        return {
            "outputs": [
                {"status": "success", "discarded_messages": discarded}
            ]
        }

    def _error(message):
        return {"outputs": [{"status": "error", "error": message}]}

    if max_prompt_tokens is None:
        # Without an explicit limit, the model's own limit applies.
        if model_max_prompt_tokens >= 4:
            return _success([])
        return _error(
            "Token count of all messages (4) exceeds "
            f"the model maximum prompt tokens ({model_max_prompt_tokens})."
        )

    if max_prompt_tokens == 1:
        return _error(
            "Token count of the last user message and all "
            "system messages (2) exceeds the maximum prompt tokens (1)."
        )
    if max_prompt_tokens == 2:
        return _success([1, 2])
    if max_prompt_tokens == 3:
        return _success([1])
    return _success([])
def _create_client(allow_extra, validator: RequestValidator):
    """Build a test client for a DIALApp hosting a ValidatorApplication.

    :param allow_extra: tri-state flag controlling extra-field handling:
        ``None`` — use the DIALApp default (constructor called without the
        ``allow_extra_request_fields`` argument), otherwise ``True``/``False``
        set ``allow_extra_request_fields`` explicitly. The previous ``bool``
        annotation was incorrect: the tests parametrize this argument with
        ``None`` as a distinct, meaningful value.
    :param validator: predicate run against the parsed chat-completion
        Request by the ValidatorApplication.
    :return: a TestClient pre-configured with a test API key and pointed
        at the ``test-app`` deployment.
    """
    dial_app = (
        DIALApp()
        if allow_extra is None
        else DIALApp(allow_extra_request_fields=allow_extra)
    ).add_chat_completion(
        "test-app", ValidatorApplication(request_validator=validator)
    )

    return TestClient(
        dial_app,
        headers={"Api-Key": "TEST_API_KEY"},
        base_url="http://testserver/openai/deployments/test-app",
    )
@pytest.mark.parametrize("allow_extra", [True, False, None])
@pytest.mark.parametrize("stream", [True, False])
def test_extra_field_message(allow_extra: bool, stream: bool):
    """An unknown field inside a message is rejected unless extra request
    fields are explicitly allowed.
    """
    client = _create_client(
        allow_extra, lambda r: r.messages[0].extra_field == "extra_value"  # type: ignore
    )

    payload = {
        "messages": [
            {
                "role": "user",
                "content": "Test content",
                "extra_field": "extra_value",
            }
        ],
        "stream": stream,
    }
    response = client.post("chat/completions", json=payload)

    if allow_extra:
        assert response.status_code == 200
    else:
        # Both False and the None default reject extra fields.
        expected = extra_fields_error("messages.0.extra_field")
        assert response.status_code == expected.code
        assert response.json() == expected.error
-------------------------------------------------------------------------------- 1 | from itertools import zip_longest 2 | 3 | import pytest 4 | 5 | from aidial_sdk.chat_completion.request import Request, StaticTool, Tool 6 | from tests.utils.chat_completion_validation import validate_chat_completion 7 | from tests.utils.pydantic import model_dump 8 | 9 | TEST_CASES = [ 10 | { 11 | "messages": [{"role": "user", "content": "Hello"}], 12 | "tools": [ 13 | { 14 | "type": "function", 15 | "function": { 16 | "name": "test_tool", 17 | "description": "Test tool", 18 | "parameters": {"type": "object", "properties": {}}, 19 | }, 20 | }, 21 | { 22 | "type": "function", 23 | "function": { 24 | "name": "test_tool_2", 25 | "description": "Test tool 2", 26 | "parameters": {"type": "object", "properties": {}}, 27 | }, 28 | }, 29 | ], 30 | "model": "gpt-3.5-turbo", 31 | }, 32 | { 33 | "model": "gpt-3.5-turbo", 34 | "messages": [{"role": "user", "content": "Hello"}], 35 | "tools": [ 36 | { 37 | "type": "function", 38 | "function": { 39 | "name": "test_tool", 40 | "description": "Test tool", 41 | "parameters": {"type": "object", "properties": {}}, 42 | }, 43 | }, 44 | { 45 | "type": "function", 46 | "function": { 47 | "name": "test_tool_2", 48 | "description": "Test tool 2", 49 | "parameters": {"type": "object", "properties": {}}, 50 | }, 51 | }, 52 | ], 53 | }, 54 | { 55 | "messages": [{"role": "user", "content": "Hello"}], 56 | "tools": [ 57 | { 58 | "type": "static_function", 59 | "static_function": { 60 | "name": "test_static_tool", 61 | "description": "Test static tool", 62 | "configuration": { 63 | "datastore": "test_datastore", 64 | "threshold": 0.5, 65 | }, 66 | }, 67 | }, 68 | ], 69 | "model": "gpt-3.5-turbo", 70 | }, 71 | { 72 | "messages": [{"role": "user", "content": "Hello"}], 73 | "tools": [ 74 | { 75 | "type": "function", 76 | "function": { 77 | "name": "test_tool", 78 | "description": "Test tool", 79 | "parameters": {"type": "object", "properties": {}}, 80 | }, 81 | 
@pytest.mark.parametrize("mock_data", TEST_CASES)
def test_tools_parsing(mock_data):
    """Tools must round-trip through request parsing and be typed as
    Tool / StaticTool according to their "type" discriminator.
    """
    tool_classes = {"function": Tool, "static_function": StaticTool}

    def _request_validator(r: Request):
        # The parsed request serializes back to the original payload.
        assert model_dump(r, exclude_none=True) == mock_data
        assert r.tools

        # zip_longest guards against silently dropping or inventing tools.
        for raw_tool, parsed_tool in zip_longest(
            mock_data["tools"], r.tools, fillvalue={}
        ):
            expected_cls = tool_classes.get(raw_tool["type"])
            if expected_cls is not None:
                assert isinstance(parsed_tool, expected_cls)

    validate_chat_completion(
        request=mock_data,
        request_validator=_request_validator,
    )
class Board(BaseModel):
    """State of a tic-tac-toe board plus derived game queries."""

    # FIX: the previous default `[[None] * 3] * 3` stored the *same* row
    # list object three times; deep-copying such a default preserves the
    # aliasing, so a direct write to one cell would show up in every row.
    # Build three independent rows instead.
    cells: List[List[Optional[Player]]] = [[None] * 3 for _ in range(3)]
    """Current state of the game board"""

    @property
    def x_moves(self) -> int:
        """Number of moves X has made so far."""
        return sum(row.count(X_PLAYER) for row in self.cells)

    @property
    def o_moves(self) -> int:
        """Number of moves O has made so far."""
        return sum(row.count(O_PLAYER) for row in self.cells)

    @property
    def player(self) -> Player:
        """Returns the player who should make the next move"""
        # X always starts, so the counts are equal exactly on X's turn.
        return X_PLAYER if self.x_moves == self.o_moves else O_PLAYER

    @property
    def finished(self) -> bool:
        """Returns True if the game is finished"""
        return self.status != "Unfinished"

    @property
    def status(self) -> Union[Player, Literal["Draw", "Unfinished"]]:
        """Returns the winner, ``"Draw"`` for a full board without a winner,
        or ``"Unfinished"`` while the game is still in progress."""
        for line in Move.col_lines() + Move.row_lines() + Move.diag_lines():
            cells: List[Optional[Player]] = [self.get(move) for move in line]
            # A line wins when all three cells are filled (all truthy)
            # by the same player (single distinct value).
            if (
                all(cells)
                and len(set(cells)) == 1
                and (winner := cells[0]) is not None
            ):
                return winner

        if self.x_moves + self.o_moves == 9:
            return "Draw"

        return "Unfinished"

    @property
    def possible_moves(self) -> List[Move]:
        """Returns the list of possible moves"""
        return [
            Move(row=row_idx, col=col_idx)
            for row_idx, row in enumerate(self.cells)
            for col_idx, cell in enumerate(row)
            if cell is None
        ]

    def get(self, move: Move) -> Optional[Player]:
        """Returns the occupant of the cell, or None if the cell is empty."""
        return self.cells[move.row][move.col]

    def make_move(self, move: Move) -> "Board":
        """Returns a new board after making the move"""
        assert self.get(move) is None, "Cell is already occupied"
        # Copy each row so the original board is left untouched.
        cells = [row.copy() for row in self.cells]
        cells[move.row][move.col] = self.player
        return Board(cells=cells)

    def to_markdown(self) -> str:
        """Prints the game board as a Markdown table"""
        ret = ""
        ret += "||A|B|C|\n"
        ret += "|---|---|---|---|\n"
        # Rows are rendered top-down with row 3 first, matching the
        # chess-like coordinates printed by Move.print().
        for i, row in reversed(list(enumerate(self.cells))):
            ret += f"|{i + 1}|"
            for cell in row:
                ret += f"{' ' if cell is None else _print_player(cell)}|"
            ret += "\n"
        return ret
def _get_model_fields(obj: BaseModel) -> Dict[str, FieldInfo]:
    """Return the declared fields of *obj*, bridging Pydantic v1/v2 APIs."""
    if PYDANTIC_V2:
        return obj.model_fields
    else:
        return obj.__fields__  # type: ignore


def _get_model_config_field(obj: BaseModel, field_name: str) -> Optional[Any]:
    """Read a single model-config option, bridging Pydantic v1/v2 APIs."""
    if PYDANTIC_V2:
        return obj.model_config.get(field_name)
    else:
        return getattr(obj.Config, field_name, None)  # type: ignore


def _model_iterate_fields(
    obj: Any, any_types: bool, loc: _Loc
) -> Iterator[Tuple[BaseModel, _Loc]]:
    """Depth-first traversal yielding every nested ``BaseModel`` with its path.

    ``loc`` is the tuple of field names / list indices / dict keys leading
    to the yielded model.  ``any_types`` carries the
    ``arbitrary_types_allowed`` setting of the nearest enclosing model: when
    it is off, a value of an unrecognized type is a programming error
    (enforced via ``assert`` below).
    """
    if isinstance(obj, BaseModel):
        yield (obj, loc)
        # Children inherit this model's arbitrary-types policy.
        any_types = (
            _get_model_config_field(obj, "arbitrary_types_allowed") or False
        )
        for field in _get_model_fields(obj):
            value = getattr(obj, field)
            yield from _model_iterate_fields(value, any_types, loc + (field,))

    elif isinstance(obj, list):
        for idx, item in enumerate(obj):
            yield from _model_iterate_fields(item, any_types, loc + (idx,))

    elif isinstance(obj, dict):
        for key, val in obj.items():
            yield from _model_iterate_fields(val, any_types, loc + (key,))

    elif isinstance(obj, (str, int, float, bool, type(None), Enum)):
        # Scalars cannot contain nested models — nothing to recurse into.
        pass

    else:
        err_message = f"Cannot iterate model fields within an object with the unexpected type: {type(obj)}, loc: {loc}"
        assert any_types, err_message


if PYDANTIC_V2:
    from pydantic import ValidationError
    from pydantic_core import InitErrorDetails, PydanticCustomError

    def model_validate_extra_fields(root_model: BaseModel) -> None:
        """Raise ``ValidationError`` if any nested model received extra inputs.

        Pydantic v2 flavour: extra inputs are collected from each model's
        ``model_extra`` mapping and reported with their full location path,
        mirroring the native ``extra_forbidden`` error.
        """
        errors: List[InitErrorDetails] = []

        extra_error_type = PydanticCustomError(
            "extra_forbidden", "Extra inputs are not permitted"
        )

        for model, loc in _model_iterate_fields(root_model, False, ()):

            for key, value in (model.model_extra or {}).items():
                errors.append(
                    {
                        "type": extra_error_type,
                        "loc": loc + (key,),
                        "input": value,
                    }
                )

        if errors:
            raise ValidationError.from_exception_data(
                type(root_model).__name__, errors
            )

else:
    from pydantic.v1.error_wrappers import ErrorWrapper, ValidationError
    from pydantic.v1.errors import ExtraError

    def model_validate_extra_fields(root_model: BaseModel) -> None:
        """Raise ``ValidationError`` if any nested model received extra inputs.

        v1-compat flavour (``pydantic.v1`` shim): anything present on a
        model's ``__dict__`` that is not a declared field counts as extra.
        """
        errors: List[ErrorWrapper] = []

        for model, loc in _model_iterate_fields(root_model, False, ()):
            declared = set(_get_model_fields(model).keys())
            for key in model.__dict__:
                if key not in declared:
                    errors.append(ErrorWrapper(ExtraError(), loc=loc + (key,)))

        if errors:
            raise ValidationError(errors, root_model.__class__)  # type: ignore
5 | """ 6 | 7 | import os 8 | from urllib.parse import urlparse 9 | 10 | import uvicorn 11 | 12 | from aidial_sdk import DIALApp 13 | from aidial_sdk import HTTPException as DIALException 14 | from aidial_sdk.chat_completion import ChatCompletion, Request, Response 15 | 16 | from .image import download_image_as_base64, get_image_base64_size 17 | 18 | DIAL_URL = os.environ.get("DIAL_URL") 19 | 20 | 21 | # A helper to distinguish relative URLs from absolute ones 22 | # Relative URLs are treated as URLs to the DIAL File storage 23 | # Absolute URLs are treated as publicly accessible URLs to external resources 24 | def is_relative_url(url) -> bool: 25 | parsed_url = urlparse(url) 26 | return ( 27 | not parsed_url.scheme 28 | and not parsed_url.netloc 29 | and not url.startswith("/") 30 | ) 31 | 32 | 33 | # ChatCompletion is an abstract class for applications and model adapters 34 | class ImageSizeApplication(ChatCompletion): 35 | async def chat_completion(self, request: Request, response: Response): 36 | # Create a single choice 37 | with response.create_single_choice() as choice: 38 | # Validate the request: 39 | # - the request must contain at least one message, and 40 | # - the last message must contain one and only one image attachment. 
            # --- continuation of ImageSizeApplication.chat_completion ---
            # Request validation:
            # - the request must contain at least one message, and
            # - the last message must contain one and only one image attachment.

            if len(request.messages) == 0:
                raise DIALException(
                    message="The request must contain at least one message",
                    status_code=422,
                )

            message = request.messages[-1]

            if (
                message.custom_content is None
                or message.custom_content.attachments is None
            ):
                raise DIALException(
                    message="No image attachment was found in the last message",
                    status_code=422,
                )

            attachments = message.custom_content.attachments

            if len(attachments) != 1:
                raise DIALException(
                    message="Only one attachment is expected in the last message",
                    status_code=422,
                )

            # Get the image from the last message attachments
            attachment = attachments[0]

            # The attachment contains either the image content as a base64 string or an image URL
            if attachment.data is not None:
                image_data = attachment.data
            elif attachment.url is not None:
                image_url = attachment.url

                # Download the image from the URL
                if is_relative_url(image_url):
                    if DIAL_URL is None:
                        # DIAL SDK automatically converts standard Python exceptions to 500 Internal Server Error
                        raise ValueError(
                            "DIAL_URL environment variable is unset"
                        )

                    # Relative URLs point into the DIAL File storage.
                    image_abs_url = f"{DIAL_URL}/v1/{image_url}"
                else:
                    image_abs_url = image_url

                image_data = await download_image_as_base64(image_abs_url)
            else:
                raise DIALException(
                    message="Either 'data' or 'url' field must be provided in the attachment",
                    status_code=422,
                )

            # Compute the image size
            (w, h) = get_image_base64_size(image_data)

            # Return the image size
            choice.append_content(f"Size: {w}x{h}px")


# DIALApp extends FastAPI to provide a user-friendly interface for routing requests to your applications.
# Auth-header propagation is enabled only when a DIAL backend is configured.
app = DIALApp(
    dial_url=DIAL_URL,
    propagate_auth_headers=DIAL_URL is not None,
    add_healthcheck=True,
)
"""
This module provides extensions of `ConfigDict` class and `Field`
descriptor with DIAL-specific features.

These extensions should be used instead of the native counterparts to avoid
deprecation warnings and type-checking issues.
"""

from typing import Any, Callable, Dict, List, Optional, Union

import pydantic as pyd2
from typing_extensions import Literal

from aidial_sdk._pydantic import PYDANTIC_V2


class ConfigDict(pyd2.ConfigDict):
    # DIAL extension: disables the chat message input in the UI form.
    chat_message_input_disabled: bool


if not PYDANTIC_V2:

    def Field(*args, **kwargs) -> Any:  # type: ignore
        raise ImportError("The Field helper is only supported in Pydantic v2")

else:
    from pydantic.aliases import AliasChoices, AliasPath
    from pydantic.fields import Field as PydanticField
    from pydantic_core import PydanticUndefined

    from aidial_sdk.chat_completion.form import Button

    _Unset: Any = PydanticUndefined

    def Field(
        default: Any = PydanticUndefined,
        *,
        default_factory: Optional[Callable[[], Any]] = _Unset,
        alias: Optional[str] = _Unset,
        alias_priority: Optional[int] = _Unset,
        validation_alias: Optional[
            Union[str, AliasPath, AliasChoices]
        ] = _Unset,
        serialization_alias: Optional[str] = _Unset,
        title: Optional[str] = _Unset,
        description: Optional[str] = _Unset,
        examples: Optional[List[Any]] = _Unset,
        exclude: Optional[bool] = _Unset,
        discriminator: Optional[str] = _Unset,
        json_schema_extra: Optional[
            (Union[Dict[str, Any], Callable[[Dict[str, Any]], None]])
        ] = _Unset,
        frozen: Optional[bool] = _Unset,
        validate_default: Optional[bool] = _Unset,
        repr: bool = _Unset,
        init_var: Optional[bool] = _Unset,
        kw_only: Optional[bool] = _Unset,
        pattern: Optional[str] = _Unset,
        strict: Optional[bool] = _Unset,
        gt: Optional[float] = _Unset,
        ge: Optional[float] = _Unset,
        lt: Optional[float] = _Unset,
        le: Optional[float] = _Unset,
        multiple_of: Optional[float] = _Unset,
        allow_inf_nan: Optional[bool] = _Unset,
        max_digits: Optional[int] = _Unset,
        decimal_places: Optional[int] = _Unset,
        min_length: Optional[int] = _Unset,
        max_length: Optional[int] = _Unset,
        union_mode: Literal["smart", "left_to_right"] = _Unset,
        buttons: Optional[List[Button]] = _Unset,
    ) -> Any:
        """Drop-in replacement for ``pydantic.Field`` with the DIAL-specific
        ``buttons`` extension.

        When ``buttons`` is given, it is folded into ``json_schema_extra``
        (as the ``"buttons"`` key) so it surfaces in the generated JSON
        schema; every other argument is forwarded to ``pydantic.Field``
        unchanged.
        """

        if buttons is not _Unset and buttons is not None:
            if json_schema_extra is _Unset or json_schema_extra is None:
                json_schema_extra = {}

            if not callable(json_schema_extra):
                # Dict form: merge, letting "buttons" win over a collision.
                new_extra = {**json_schema_extra, "buttons": buttons}
            else:
                # Callable form: pydantic invokes the callable with the
                # actual schema dict, which must be mutated *in place*.
                # FIX: run the user hook on the real dict first, then add
                # the buttons. The previous code called the hook with a
                # copy ({**x, "buttons": buttons}), discarding both the
                # hook's changes and the buttons.
                _user_extra = json_schema_extra

                def _extra(schema: Dict[str, Any]) -> None:
                    _user_extra(schema)
                    schema["buttons"] = buttons

                new_extra = _extra
        else:
            new_extra = json_schema_extra

        return PydanticField(
            default=default,
            default_factory=default_factory,
            alias=alias,
            alias_priority=alias_priority,
            validation_alias=validation_alias,
            serialization_alias=serialization_alias,
            title=title,
            description=description,
            examples=examples,
            exclude=exclude,
            discriminator=discriminator,
            json_schema_extra=new_extra,
            frozen=frozen,
            validate_default=validate_default,
            repr=repr,
            init_var=init_var,
            kw_only=kw_only,
            pattern=pattern,
            strict=strict,
            gt=gt,
            ge=ge,
            lt=lt,
            le=le,
            multiple_of=multiple_of,
            allow_inf_nan=allow_inf_nan,
            max_digits=max_digits,
            decimal_places=decimal_places,
            min_length=min_length,
            max_length=max_length,
            union_mode=union_mode,
        )
            # (continuation of ``HeaderPropagator._instrument_aiohttp``:
            # body of the ``_on_request_start`` trace hook)
            self._modify_headers(str(params.url), params.headers)

        def instrumented_init(wrapped, instance, args, kwargs):
            # Append our trace config to whatever the caller passed so the
            # hook above observes every request made by the new session.
            trace_config = aiohttp.TraceConfig()
            trace_config.on_request_start.append(_on_request_start)

            trace_configs = list(kwargs.get("trace_configs") or [])
            trace_configs.append(trace_config)

            kwargs["trace_configs"] = trace_configs
            return wrapped(*args, **kwargs)

        wrapt.wrap_function_wrapper(
            aiohttp.ClientSession, "__init__", instrumented_init
        )

    def _instrument_requests(self):
        """Patch ``requests.Session.send`` to inject the stored api-key.

        No-op when ``requests`` is not installed.
        """
        try:
            import requests
        except ImportError:
            return

        def instrumented_send(wrapped, instance, args, kwargs):
            request: requests.PreparedRequest = args[0]
            self._modify_headers(request.url or "", request.headers)
            return wrapped(*args, **kwargs)

        wrapt.wrap_function_wrapper(requests.Session, "send", instrumented_send)

    def _instrument_httpx(self):
        """Patch ``httpx`` request building (sync and async clients) to
        inject the stored api-key.

        No-op when ``httpx`` is not installed.
        """
        try:
            import httpx
        except ImportError:
            return

        def instrumented_build_request(wrapped, instance, args, kwargs):
            request: httpx.Request = wrapped(*args, **kwargs)
            self._modify_headers(str(request.url), request.headers)
            return request

        wrapt.wrap_function_wrapper(
            httpx.Client, "build_request", instrumented_build_request
        )

        wrapt.wrap_function_wrapper(
            httpx.AsyncClient, "build_request", instrumented_build_request
        )

    def _modify_headers(
        self, url: str, headers: MutableMapping[str, str]
    ) -> None:
        """Inject the propagated api-key into requests targeting DIAL.

        Only touches requests whose URL starts with the configured DIAL URL.
        The ``api-key`` header is always overwritten; ``Authorization`` is
        rewritten only when it previously mirrored the old api-key
        (``Bearer <old-api-key>``), so unrelated authorization schemes are
        left intact.
        """
        if url.startswith(self._dial_url):
            api_key = self._api_key.get()
            if api_key:
                old_api_key = headers.get("api-key")
                old_authz = headers.get("Authorization")

                if (
                    old_api_key
                    and old_authz
                    and old_authz == f"Bearer {old_api_key}"
                ):
                    headers["Authorization"] = f"Bearer {api_key}"

                headers["api-key"] = api_key
"""
Helper classes that unify model configuration between Pydantic v1 and v2.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Optional, Type, TypeVar

from aidial_sdk._pydantic import PYDANTIC_V2, BaseModel

_Model = TypeVar("_Model", bound=BaseModel)


class ModelConfigWrapper:
    """Version-agnostic facade over a model's configuration storage.

    Delegates reads/writes either to a v1-style ``Config`` class or a
    v2-style ``model_config`` dict, via the ``ModelConfigBase`` strategy.
    """

    _model_config: ModelConfigBase

    def __init__(self, model_config: ModelConfigBase):
        self._model_config = model_config

    def __getitem__(self, field: str) -> Any:
        # Missing options read as None rather than raising.
        return self._model_config.get_field(field, None)

    def __setitem__(self, field: str, value: Any) -> None:
        self._model_config.set_field(field, value)

    def post_process_schema(
        self, on_schema: Callable[[Dict[str, Any]], None]
    ) -> None:
        """Chain *on_schema* after any already-configured schema hook.

        The previously configured hook (if any) runs first; *on_schema*
        then mutates the resulting JSON schema dict in place.
        """
        attr_name = self._model_config.schema_extra_field
        old_schema_extra = self[attr_name]

        def _schema_extra(
            schema: Dict[str, Any], model: Type[BaseModel]
        ) -> None:
            if old_schema_extra:
                old_schema_extra(schema, model)
            on_schema(schema)

        self[attr_name] = _schema_extra

    @classmethod
    def create(
        cls, base_cls: Optional[Type[_Model]], namespace: Dict[str, Any]
    ) -> ModelConfigWrapper:
        """Build a wrapper for a class *being defined* (body in *namespace*).

        NOTE(review): mutates *namespace* so the synthesized config object
        is picked up when the class object is subsequently built —
        presumably invoked from a metaclass or class factory; confirm at
        the call site.
        """
        if PYDANTIC_V2:
            return cls(_ConfigV2.create(base_cls, namespace))
        else:
            return cls(_ConfigV1.create(base_cls, namespace))


class ModelConfigBase(ABC):
    """Strategy interface for reading/writing model-config options."""

    @abstractmethod
    def set_field(self, field: str, value: Any) -> None:
        pass

    @abstractmethod
    def get_field(self, field: str, default: Any) -> Any:
        pass

    @property
    @abstractmethod
    def schema_extra_field(self) -> str:
        # Name of the config option holding the schema post-processing hook
        # ("schema_extra" in v1, "json_schema_extra" in v2).
        pass

    @classmethod
    @abstractmethod
    def create(
        cls, base_cls: Optional[Type[_Model]], namespace: Dict[str, Any]
    ) -> ModelConfigBase:
        pass


class _ConfigV1(ModelConfigBase):
    """v1 strategy: options live as attributes of a nested ``Config`` class."""

    config_cls: Type

    def __init__(self, config_cls: Type):
        self.config_cls = config_cls

    def set_field(self, field: str, value: Any) -> None:
        setattr(self.config_cls, field, value)

    def get_field(self, field: str, default: Any) -> Any:
        return getattr(self.config_cls, field, default)

    @property
    def schema_extra_field(self) -> str:
        return "schema_extra"

    @classmethod
    def create(
        cls, base_cls: Optional[Type[_Model]], namespace: Dict[str, Any]
    ) -> ModelConfigBase:
        # If the class body declares no Config, synthesize one inheriting
        # from the base class' Config so parent options are preserved.
        if (config_cls := namespace.get("Config")) is None:
            conf_base_cls = (
                None if base_cls is None else getattr(base_cls, "Config", None)
            )

            config_cls = type("Config", (conf_base_cls or object,), {})

            # Make the synthesized class introspect like a hand-written
            # nested class (helps debugging/pickling tooling).
            if module := namespace.get("__module__"):
                config_cls.__module__ = module
            if qualname := namespace.get("__qualname__"):
                config_cls.__qualname__ = f"{qualname}.{config_cls.__name__}"

            namespace["Config"] = config_cls

        return cls(config_cls)


class _ConfigV2(ModelConfigBase):
    """v2 strategy: options live in a ``model_config`` dict."""

    model_config: Dict

    def __init__(self, model_config: Dict):
        self.model_config = model_config

    def set_field(self, field: str, value: Any) -> None:
        self.model_config[field] = value

    def get_field(self, field: str, default: Any) -> Any:
        return self.model_config.get(field, default)

    @property
    def schema_extra_field(self) -> str:
        return "json_schema_extra"

    @classmethod
    def create(
        cls, base_cls: Optional[Type[_Model]], namespace: Dict[str, Any]
    ) -> ModelConfigBase:
        # Merge base-class config with the class body's own, the latter
        # winning; store the merged dict back into the namespace so the
        # class under construction uses it.
        base_model_config = (
            {} if base_cls is None else getattr(base_cls, "model_config", {})
        )

        curr_model_config = namespace.get("model_config") or {}

        model_config = namespace["model_config"] = {
            **base_model_config,
            **curr_model_config,
        }

        return cls(model_config)
def init_telemetry(
    app: Optional[FastAPI],
    config: TelemetryConfig,
):
    """Configure OpenTelemetry tracing, logging and metrics for the service.

    Side effects only: installs global providers and instrumentors
    according to *config*. When *app* is given and tracing or metrics are
    enabled, the FastAPI app itself is instrumented as well.
    """
    resource = Resource.create(
        attributes=(
            {SERVICE_NAME: config.service_name} if config.service_name else None
        )
    )

    if config.tracing is not None:
        tracer_provider = TracerProvider(resource=resource)

        if config.tracing.otlp_export:
            tracer_provider.add_span_processor(
                BatchSpanProcessor(OTLPSpanExporter())
            )

        set_tracer_provider(tracer_provider)

        # HTTP-client instrumentations are optional extras: instrument each
        # client library only when its instrumentation package is installed.
        try:
            from opentelemetry.instrumentation.requests import (
                RequestsInstrumentor,
            )

            RequestsInstrumentor().instrument()
        except ImportError:
            pass

        try:
            from opentelemetry.instrumentation.aiohttp_client import (
                AioHttpClientInstrumentor,
            )

            AioHttpClientInstrumentor().instrument()
        except ImportError:
            pass

        URLLibInstrumentor().instrument()

        try:
            from opentelemetry.instrumentation.httpx import (
                HTTPXClientInstrumentor,
            )

            HTTPXClientInstrumentor().instrument()
        except ImportError:
            pass

        if config.tracing.logging:
            # Setting the root logger format in order to include
            # tracing information: span_id, trace_id
            LoggingInstrumentor().instrument(set_logging_format=True)

    if config.logs is not None:
        # Adding a handler to the root logger which exports the logs to OTLP
        provider = LoggerProvider(resource=resource)

        if config.logs.otlp_export:
            provider.add_log_record_processor(
                BatchLogRecordProcessor(OTLPLogExporter())
            )

        set_logger_provider(provider)

        handler = LoggingHandler(level=config.logs.level)
        logging.getLogger().addHandler(handler)

    if config.metrics is not None:
        metric_readers = []

        if config.metrics.prometheus_export:
            metric_readers.append(PrometheusMetricReader())

        if config.metrics.otlp_export:
            metric_readers.append(
                PeriodicExportingMetricReader(OTLPMetricExporter())
            )

        set_meter_provider(
            MeterProvider(resource=resource, metric_readers=metric_readers)
        )

        SystemMetricsInstrumentor().instrument()

        if config.metrics.prometheus_export:
            # Expose the Prometheus scrape endpoint on the configured port.
            start_http_server(port=config.metrics.port)

    if app and (config.tracing is not None or config.metrics is not None):
        # FastAPI instrumentor reports both metrics and traces
        FastAPIInstrumentor.instrument_app(app)
headers.get("Authorization") 40 | return remove_nones({"api-key": api_key, "authorization": authz}) 41 | 42 | 43 | @pytest.fixture 44 | def mock_requests(): 45 | with responses.mock as mock: 46 | 47 | def callback(request: requests.PreparedRequest): 48 | return ( 49 | 200, 50 | {"content-type": "application/json"}, 51 | json.dumps(_get_headers(request.headers)), 52 | ) 53 | 54 | mock.add_callback( 55 | responses.GET, 56 | URL_PATTERN, 57 | callback=callback, 58 | content_type="application/json", 59 | ) 60 | 61 | yield mock 62 | 63 | 64 | @pytest.fixture 65 | def mock_httpx(): 66 | with respx.mock as mock: 67 | 68 | @respx.route(method="GET", host__in=HOSTS, path="/") 69 | def handler(request: httpx.Request): 70 | return httpx.Response(200, json=_get_headers(request.headers)) 71 | 72 | yield mock 73 | 74 | 75 | @pytest.fixture 76 | def mock_aiohttp(): 77 | with aioresponses.aioresponses() as mock: 78 | 79 | def callback(url, **kwargs) -> aioresponses.CallbackResult: 80 | headers = CaseInsensitiveDict(kwargs.get("headers", {})) 81 | return aioresponses.CallbackResult(payload=_get_headers(headers)) 82 | 83 | mock.get(URL_PATTERN, callback=callback) 84 | yield mock 85 | 86 | 87 | @pytest.mark.parametrize( 88 | "lib, url, key_to_propagate, key_for_upstream, add_authz", 89 | product( 90 | ["aiohttp", "requests", "httpx_sync", "httpx_async"], 91 | [DIAL_URL, NON_DIAL_URL], 92 | ["test-api-key", None], 93 | ["dummy-api-key", None], 94 | [True, False], 95 | ), 96 | ) 97 | def test_send_request( 98 | client: TestClient, 99 | mock_requests, 100 | mock_httpx, 101 | mock_aiohttp, 102 | lib: str, 103 | url: str, 104 | key_to_propagate: Optional[str], 105 | key_for_upstream: Optional[str], 106 | add_authz: bool, 107 | ): 108 | headers_to_propagate = {} 109 | if key_to_propagate: 110 | headers_to_propagate["api-key"] = key_to_propagate 111 | if add_authz: 112 | headers_to_propagate["authorization"] = f"Bearer {key_to_propagate}" 113 | 114 | headers_for_upstream = {} 115 | if 
key_for_upstream: 116 | headers_for_upstream["api-key"] = key_for_upstream 117 | if add_authz: 118 | headers_for_upstream["authorization"] = f"Bearer {key_for_upstream}" 119 | 120 | response = client.post( 121 | "/", 122 | json={"url": url, "lib": lib, "headers": headers_for_upstream}, 123 | headers=headers_to_propagate, 124 | ) 125 | assert response.status_code == 200, response.json() 126 | 127 | expected_key = ( 128 | key_to_propagate if url == DIAL_URL else None 129 | ) or key_for_upstream 130 | 131 | expected_headers = {} 132 | if expected_key: 133 | expected_headers["api-key"] = expected_key 134 | if add_authz and key_for_upstream: 135 | expected_headers["authorization"] = f"Bearer {expected_key}" 136 | 137 | # NOTE: aioresponses doesn't call trace_configs in the mocked version, 138 | # and since we are patching the request via a dedicated trace config, 139 | # we can't test the header propagation for aiohttp. 140 | # https://github.com/pnuckowski/aioresponses/issues/246 141 | if lib == "aiohttp": 142 | expected_headers = headers_for_upstream 143 | 144 | assert response.json() == expected_headers 145 | -------------------------------------------------------------------------------- /aidial_sdk/chat_completion/stage.py: -------------------------------------------------------------------------------- 1 | from types import TracebackType 2 | from typing import Optional, Type, overload 3 | 4 | from aidial_sdk._pydantic import ValidationError 5 | from aidial_sdk.chat_completion._types import ChunkQueue 6 | from aidial_sdk.chat_completion.chunks import ( 7 | AttachmentStageChunk, 8 | ContentStageChunk, 9 | FinishStageChunk, 10 | NameStageChunk, 11 | StartStageChunk, 12 | ) 13 | from aidial_sdk.chat_completion.enums import Status 14 | from aidial_sdk.chat_completion.request import Attachment 15 | from aidial_sdk.utils._attachment import create_attachment 16 | from aidial_sdk.utils._content_stream import ContentStream 17 | from aidial_sdk.utils.errors import 
runtime_error


class Stage:
    """A stage of a chat completion choice.

    Emits its lifecycle as chunks (start, name, content, attachments,
    finish) into the shared chunk queue. Intended usage is as a context
    manager: ``open()`` on entry, ``close()`` on exit (COMPLETED on
    normal exit, FAILED on exception).
    """

    _queue: ChunkQueue          # shared queue the stage writes its chunks to
    _choice_index: int          # index of the parent choice
    _stage_index: int           # index of this stage within the choice
    _name: Optional[str]        # optional stage name sent with the start chunk
    _last_attachment_index: int # next attachment index to assign
    _closed: bool
    _opened: bool

    def __init__(
        self,
        queue: ChunkQueue,
        choice_index: int,
        stage_index: int,
        name: Optional[str] = None,
    ):
        self._queue = queue
        self._choice_index = choice_index
        self._stage_index = stage_index
        self._last_attachment_index = 0
        self._opened = False
        self._closed = False
        self._name = name

    def __enter__(self):
        self.open()
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> Optional[bool]:
        # Close as COMPLETED on a clean exit (unless already closed
        # manually inside the block) and as FAILED on an exception.
        if not exc:
            if not self._closed:
                self.close(Status.COMPLETED)
        else:
            # NOTE(review): if the stage was already closed inside the
            # block and an exception follows, close() raises "already
            # closed", masking the original exception — confirm intended.
            self.close(Status.FAILED)

        # Returning False propagates any exception to the caller.
        return False

    def append_content(self, content: str):
        """Append a piece of content text to an open stage."""
        if not self._opened:
            raise runtime_error("Trying to append content to an unopened stage")
        if self._closed:
            raise runtime_error("Trying to append content to a closed stage")

        self._queue.put_nowait(
            ContentStageChunk(self._choice_index, self._stage_index, content)
        )

    @property
    def content_stream(self) -> ContentStream:
        """A writable stream view over this stage's content."""
        return ContentStream(self)

    def append_name(self, name: str):
        """Append a piece of the stage name to an open stage."""
        if not self._opened:
            raise runtime_error("Trying to append name to an unopened stage")
        if self._closed:
            raise runtime_error("Trying to append name to a closed stage")

        self._queue.put_nowait(
            NameStageChunk(self._choice_index, self._stage_index, name)
        )

    @overload
    def add_attachment(self, attachment: Attachment) -> None: ...

    @overload
    def add_attachment(
        self,
        type: Optional[str] = None,
        title: Optional[str] = None,
        data: Optional[str] = None,
        url: Optional[str] = None,
        reference_url: Optional[str] = None,
        reference_type: Optional[str] = None,
    ) -> None: ...

    def add_attachment(self, *args, **kwargs) -> None:
        """Add an attachment to an open stage.

        Accepts either a ready-made ``Attachment`` or the individual
        attachment fields (see overloads). Validation errors are
        re-raised as runtime errors with the first error's message.
        """
        if not self._opened:
            raise runtime_error("Trying to add attachment to an unopened stage")
        if self._closed:
            raise runtime_error("Trying to add attachment to a closed stage")

        attachment_stage_chunk = None
        try:
            attachment_stage_chunk = AttachmentStageChunk(
                choice_index=self._choice_index,
                stage_index=self._stage_index,
                attachment_index=self._last_attachment_index,
                **create_attachment(*args, **kwargs).model_dump(),
            )
        except ValidationError as e:
            raise runtime_error(e.errors()[0]["msg"])

        self._queue.put_nowait(attachment_stage_chunk)
        self._last_attachment_index += 1

    def open(self):
        """Open the stage and emit the start chunk. May be called once."""
        if self._opened:
            raise runtime_error("The stage is already open")

        self._opened = True
        self._queue.put_nowait(
            StartStageChunk(self._choice_index, self._stage_index, self._name)
        )

    def close(self, status: Status = Status.COMPLETED):
        """Close an open stage with the given status and emit the finish chunk."""
        if not self._opened:
            raise runtime_error("Trying to close an unopened stage")
        if self._closed:
            raise runtime_error("The stage is already closed")

        self._closed = True
        self._queue.put_nowait(
            FinishStageChunk(self._choice_index, self._stage_index, status)
        )
--------------------------------------------------------------------------------
/tests/utils/tokenization.py:
--------------------------------------------------------------------------------
from typing import Callable, Optional, Set

from aidial_sdk.chat_completion.request import (
    ChatCompletionRequest,
    Message,
    Role,
)
from aidial_sdk.deployment.tokenize import (
    TokenizeInput,
    TokenizeOutput,
    TokenizeRequest,
    TokenizeResponse,
    TokenizeSuccess,
)
from aidial_sdk.deployment.truncate_prompt import (
    TruncatePromptError,
    TruncatePromptRequest,
    TruncatePromptResponse,
    TruncatePromptResult,
    TruncatePromptSuccess,
)
from tests.utils.pydantic import model_copy


def word_count_string(string: str) -> int:
    """Token count of a string: whitespace-separated word count."""
    return len(string.split())


def word_count_message(message: Message) -> int:
    """Token count of a single message's text."""
    return word_count_string(message.text())


def word_count_request(request: ChatCompletionRequest) -> int:
    """Token count of a request: sum over all of its messages."""
    return sum(map(word_count_message, request.messages))


def word_count_tokenize(request: TokenizeInput) -> TokenizeOutput:
    """Tokenize a single input (a whole request or a bare string) by word count."""
    if request.type == "request":
        token_count = word_count_request(request.value)
    elif request.type == "string":
        token_count = word_count_string(request.value)
    else:
        raise ValueError(f"Unknown tokenize input type: {request.type}")

    return TokenizeSuccess(token_count=token_count)


def make_batched_tokenize(
    tokenize: Callable[[TokenizeInput], TokenizeOutput]
) -> Callable[[TokenizeRequest], TokenizeResponse]:
    """Lift a per-input tokenizer to a batched tokenize endpoint handler."""

    def ret(request: TokenizeRequest) -> TokenizeResponse:
        return TokenizeResponse(
            outputs=[tokenize(inp) for inp in request.inputs]
        )

    return ret


def default_truncate_prompt(
    request: ChatCompletionRequest,
    count_request_tokens: Callable[[ChatCompletionRequest], int],
    model_max_prompt_tokens: int,
) -> TruncatePromptResult:
    """Default message-truncation algorithm.

    Always keeps every system message plus the last user message; then
    greedily re-adds the most recent remaining messages while the total
    stays within ``request.max_prompt_tokens``. Returns the sorted list
    of discarded message indices, or an error when even the required
    messages exceed the budget.
    """

    def _count_tokens_selected(indices: Set[int]) -> int:
        # Token count of the request restricted to the selected messages.
        messages = [
            message
            for idx, message in enumerate(request.messages)
            if idx in indices
        ]
        sub_request = model_copy(request, update={"messages": messages})
        return count_request_tokens(sub_request)

    all_indices = set(range(0, len(request.messages)))

    max_prompt_tokens: Optional[int] = request.max_prompt_tokens
    if max_prompt_tokens is None:
        # No truncation requested: only verify the model-wide limit.
        token_count = _count_tokens_selected(all_indices)
        if token_count > model_max_prompt_tokens:
            return TruncatePromptError(
                error=f"Token count of all messages ({token_count}) exceeds"
                f" the model maximum prompt tokens ({model_max_prompt_tokens}).",
            )
        return TruncatePromptSuccess(discarded_messages=[])

    token_count: int = 0
    found_user_message = False
    selected_indices: Set[int] = set()

    # Pass 1 (newest to oldest): select the required messages —
    # all system messages and the last user message.
    for idx in reversed(range(0, len(request.messages))):
        message = request.messages[idx]

        is_user_message = message.role == Role.USER
        is_last_user_message = not found_user_message and is_user_message
        found_user_message = found_user_message or is_user_message

        is_message_required = (
            message.role == Role.SYSTEM or is_last_user_message
        )

        if not is_message_required:
            continue

        selected_indices.add(idx)
        token_count = _count_tokens_selected(selected_indices)

    if token_count > max_prompt_tokens:
        return TruncatePromptError(
            error="Token count of the last user message and all system messages "
            f"({token_count}) exceeds the maximum prompt tokens ({max_prompt_tokens}).",
        )

    # Pass 2 (newest to oldest): greedily add back remaining messages
    # while they still fit into the budget.
    for idx in reversed(range(0, len(request.messages))):
        if idx in selected_indices:
            continue

        new_token_count = _count_tokens_selected({*selected_indices, idx})
        if new_token_count > max_prompt_tokens:
            break

        selected_indices.add(idx)
        token_count = new_token_count

    discarded_indices = all_indices - selected_indices
    return TruncatePromptSuccess(
        discarded_messages=list(sorted(discarded_indices))
    )


def make_batched_truncate_prompt(
    truncate_prompt: Callable[[ChatCompletionRequest], TruncatePromptResult],
) -> Callable[[TruncatePromptRequest], TruncatePromptResponse]:
    """Lift a per-request truncation to a batched truncate-prompt handler."""

    def ret(request: TruncatePromptRequest) -> TruncatePromptResponse:
        return TruncatePromptResponse(
            outputs=[truncate_prompt(req) for req in request.inputs]
        )

    return ret
--------------------------------------------------------------------------------
/tests/benchmark/benchmark_merge_chunks.py:
--------------------------------------------------------------------------------
import itertools
import timeit
from typing import Iterable, List

from pydantic import BaseModel

from aidial_sdk.utils.merge_chunks import merge_chat_completion_chunks
from tests.utils.chunks import create_single_choice_chunk


def _interleave(*iters):
    # Round-robin merge of the given iterables; shorter ones simply
    # run out (the sentinel padding from zip_longest is filtered away).
    sentinel = object()
    return [
        elem
        for tpl in itertools.zip_longest(*iters, fillvalue=sentinel)
        for elem in tpl
        if elem is not sentinel
    ]


class ChunkGenerator(BaseModel):
    """Parameterized generator of chat-completion chunk streams for benchmarking."""

    n_choices: int = 1
    n_chunks_per_choice: int = 1
    reversed_choices: bool = False
    n_attachments_per_choice: int = 0
    reversed_attachments: bool = False

    @property
    def desc(self) -> str:
        # Compact case label, e.g. "10rx(10+10r)".
        r1 = "r" if self.reversed_choices else ""
        r2 = "r" if self.reversed_attachments else ""
        return f"{self.n_choices}{r1}x({self.n_chunks_per_choice}+{self.n_attachments_per_choice}{r2})"

    def get_stream(self) -> Iterable[dict]:
        def _range(n: int, rev: bool):
            return reversed(range(n)) if rev else range(n)

        def gen_content(choice_idx: int):
            for chunk_idx in range(self.n_chunks_per_choice):
                yield create_single_choice_chunk(
                    choice_idx=choice_idx,
                    delta={"content": f"{chunk_idx} "},
                )

        def gen_attachments(choice_idx: int):
            for attachment_idx in _range(
                self.n_attachments_per_choice, self.reversed_attachments
            ):
                yield create_single_choice_chunk(
                    choice_idx=choice_idx,
                    delta={
                        "custom_content": {
                            "attachments": [
                                {
                                    "index": attachment_idx,
                                    "url": f"url{attachment_idx}",
                                }
                            ]
                        }
                    },
                )

        # Leading role chunk, then per-choice interleaved content and
        # attachment chunks, each choice terminated by a finish chunk.
        yield create_single_choice_chunk(
            delta={"role": "assistant", "content": None}
        )

        for choice_idx in _range(self.n_choices, self.reversed_choices):
            yield from _interleave(
                gen_content(choice_idx),
                gen_attachments(choice_idx),
            )

            yield create_single_choice_chunk(
                choice_idx=choice_idx, finish_reason="stop"
            )


def benchmark(gen: ChunkGenerator, *, repeat: int, number: int | None = None):
    """Time merging of the generated chunk stream and print one CSV row.

    When ``number`` is omitted, timeit's autorange picks a loop count;
    the best (minimum) per-loop time across ``repeat`` runs is reported.
    """

    def stmt():
        merge_chat_completion_chunks({}, *gen.get_stream())

    t = timeit.Timer(stmt=stmt)

    if number is None:
        number, _ = t.autorange()

    timings = t.repeat(number=number, repeat=repeat)

    best_sec = min(timings) / number
    best_msec = best_sec * 1e3
    best_usec = best_sec * 1e6

    n_chunks = len(list(gen.get_stream()))

    print(
        ",".join(
            [
                gen.desc,
                str(n_chunks),
                str(number),
                str(repeat),
                f"{best_sec:.3f}",
                f"{best_msec:.3f}",
                f"{best_usec:.3f}",
            ]
        )
    )


# Baseline configuration all other cases are derived from.
base_case = ChunkGenerator(
    n_choices=10,
    reversed_choices=False,
    n_chunks_per_choice=10,
    n_attachments_per_choice=10,
    reversed_attachments=False,
)

one_choice = base_case.model_copy(update={"n_choices": 1})

cases: List[ChunkGenerator] = [
    base_case,
    base_case.model_copy(update={"n_choices": 20}),
    base_case.model_copy(update={"n_chunks_per_choice": 20}),
    base_case.model_copy(update={"n_attachments_per_choice": 20}),
    base_case.model_copy(update={"reversed_choices": True}),
    base_case.model_copy(update={"reversed_attachments": True}),
    base_case.model_copy(
        update={"reversed_choices": True, "reversed_attachments": True}
    ),
    one_choice.model_copy(
        update={"n_chunks_per_choice": 0, "n_attachments_per_choice": 200}
    ),
    one_choice.model_copy(
        update={"n_chunks_per_choice": 0, "n_attachments_per_choice": 400}
    ),
    one_choice.model_copy(
        update={"n_chunks_per_choice": 200, "n_attachments_per_choice": 0}
    ),
    one_choice.model_copy(
        update={"n_chunks_per_choice": 400, "n_attachments_per_choice": 0}
    ),
]

if __name__ == "__main__":
    print("Description,N chunks,Number,Repeats,Best sec,Best msec,Best usec")
    for gen in cases:
        benchmark(gen, repeat=10)
--------------------------------------------------------------------------------
/aidial_sdk/utils/streaming.py:
--------------------------------------------------------------------------------
import asyncio
import json
from typing import (
    Any,
    AsyncIterator,
    Awaitable,
    Callable,
    Dict,
    Optional,
    TypeVar,
    Union,
    cast,
)

from typing_extensions import assert_never

from aidial_sdk.chat_completion.chunks import BaseChunkWithDefaults
from aidial_sdk.exceptions import HTTPException as DIALException
from aidial_sdk.utils._cancel_scope import CancelScope
from aidial_sdk.utils.logging import log_debug
from aidial_sdk.utils.merge_chunks import cleanup_indices, merge

_DONE_MARKER = "[DONE]"


async def merge_chunks(chunk_stream: AsyncIterator[dict]) -> Dict[str, Any]:
    """Fold a stream of chunk dicts into a single non-streaming response.

    Each choice's accumulated "delta" is renamed to "message" with
    merge indices cleaned up.
    """
    response: Dict[str, Any] = {}
    async for chunk in chunk_stream:
        response = merge(response, chunk)

    for choice in response["choices"]:
        choice["message"] = cleanup_indices(choice["delta"])
        del choice["delta"]

    return response


def _format_chunk(data: Union[dict, str]) -> str:
    # Serialize one SSE event: "data: <payload>\n\n" (dicts are
    # JSON-encoded compactly; strings — e.g. the [DONE] marker — as-is).
    data = "data: " + (
        json.dumps(data, separators=(",", ":"))
        if isinstance(data, dict)
        else data
    )
    log_debug(data)
    return f"{data}\n\n"


ResponseStream = AsyncIterator[Union[BaseChunkWithDefaults, DIALException]]

ResponseStreamWithStr = AsyncIterator[
    Union[BaseChunkWithDefaults, DIALException, str]
]


async def _handle_exceptions_in_block_response(
    stream: ResponseStream,
) -> AsyncIterator[dict]:
    """Convert a chunk stream to dicts, raising on DIAL exceptions."""
    is_first_chunk = True

    async for chunk in stream:
        if isinstance(chunk, DIALException):
            raise chunk.to_fastapi_exception()
        else:
            # Setting defaults only for the first chunk to make
            # the follow-up merging logic simpler.
            yield chunk.to_dict(with_defaults=is_first_chunk)

        is_first_chunk = False


async def to_block_response(stream: ResponseStream) -> dict:
    """Collect the whole chunk stream into a single (non-streaming) response dict."""
    chunk_stream = _handle_exceptions_in_block_response(stream)
    return await merge_chunks(chunk_stream)


async def to_streaming_response(
    stream: ResponseStreamWithStr,
) -> AsyncIterator[str]:
    """Turn a chunk stream into an SSE string stream ending with [DONE].

    The first chunk is awaited eagerly so that an immediate DIAL
    exception is raised as an HTTP error before any bytes are streamed;
    later exceptions are serialized into the stream as error events.
    """

    first_chunk = await stream.__anext__()

    if isinstance(first_chunk, DIALException):
        raise first_chunk.to_fastapi_exception()

    def _chunk_to_str(
        chunk: Union[BaseChunkWithDefaults, DIALException, str]
    ) -> str:
        if isinstance(chunk, DIALException):
            return _format_chunk(chunk.json_error())
        elif isinstance(chunk, str):
            return chunk
        elif isinstance(chunk, BaseChunkWithDefaults):
            return _format_chunk(chunk.to_dict(with_defaults=True))
        else:
            assert_never(chunk)

    async def _generator() -> AsyncIterator[str]:
        yield _chunk_to_str(first_chunk)

        async for chunk in stream:
            yield _chunk_to_str(chunk)

        yield _format_chunk(_DONE_MARKER)

    return _generator()


_T = TypeVar("_T")

# A heartbeat payload: either a ready value or a (possibly async)
# zero-argument factory producing one.
_HeartbeatObject = Union[_T, Callable[[], Union[_T, Awaitable[_T]]]]
_HeartbeatCallback = Callable[[], Union[None, Awaitable[None]]]


async def _eval_heartbeat_object(o: _HeartbeatObject[_T]) -> _T:
    """Resolve a heartbeat object: call it if callable, await if needed."""
    if callable(o):
        result = o()
        if isinstance(result, Awaitable):
            return await result
        return cast(_T, result)
    return o


async def _call_heartbeat_callback(c: _HeartbeatCallback) -> None:
    """Invoke a heartbeat callback, awaiting it when it is async."""
    result = c()
    if isinstance(result, Awaitable):
        await result


async def add_heartbeat(
    stream: AsyncIterator[_T],
    *,
    heartbeat_interval: float,
    heartbeat_object: Optional[_HeartbeatObject] = None,
    heartbeat_callback: Optional[_HeartbeatCallback] = None,
) -> AsyncIterator[_T]:
    """Re-yield ``stream``, emitting a heartbeat while the source is idle.

    Whenever the next stream item takes longer than
    ``heartbeat_interval`` seconds, ``heartbeat_object`` (if given) is
    evaluated and yielded and ``heartbeat_callback`` (if given) is
    invoked; the wait for the same pending item then continues. The
    CancelScope ensures the pending task is cancelled if the consumer
    abandons the generator.
    """
    async with CancelScope() as cs:
        chunk_task: Optional[asyncio.Task[_T]] = None

        while True:
            if chunk_task is None:
                chunk_task = cs.create_task(stream.__anext__())

            # Wait for the next chunk, but at most heartbeat_interval
            # seconds; on timeout `done` is empty and we heartbeat.
            done = (
                await asyncio.wait(
                    [chunk_task],
                    timeout=heartbeat_interval,
                    return_when=asyncio.FIRST_COMPLETED,
                )
            )[0]

            if chunk_task in done:
                try:
                    chunk, chunk_task = chunk_task.result(), None
                    yield chunk
                except StopAsyncIteration:
                    break
            else:
                if heartbeat_object is not None:
                    yield await _eval_heartbeat_object(heartbeat_object)

                if heartbeat_callback is not None:
                    await _call_heartbeat_callback(heartbeat_callback)