├── .devcontainer ├── Dockerfile └── devcontainer.json ├── .github └── workflows │ ├── ci.yml │ ├── publish-pypi.yml │ └── release-doctor.yml ├── .gitignore ├── .python-version ├── .release-please-manifest.json ├── .stats.yml ├── Brewfile ├── CHANGELOG.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── SECURITY.md ├── api.md ├── bin ├── check-release-environment └── publish-pypi ├── examples ├── .keep └── cookbooks │ ├── Atla_Selene_Absolute_Scoring.ipynb │ ├── Atla_Selene_Hallucination.ipynb │ ├── Atla_Selene_Mini_Guardrails.ipynb │ ├── Atla_Selene_Model_Selection.ipynb │ ├── Atla_Selene_Multi_Criteria_Evals.ipynb │ ├── Atla_Selene_Prompt_Improvement.ipynb │ ├── README.md │ └── langfuse │ ├── Atla_Langfuse_Monitoring.ipynb │ ├── Atla_Langfuse_Offline_Evals.ipynb │ └── README.md ├── mypy.ini ├── noxfile.py ├── pyproject.toml ├── release-please-config.json ├── requirements-dev.lock ├── requirements.lock ├── scripts ├── bootstrap ├── format ├── lint ├── mock ├── test └── utils │ └── ruffen-docs.py ├── src └── atla │ ├── __init__.py │ ├── _base_client.py │ ├── _client.py │ ├── _compat.py │ ├── _constants.py │ ├── _exceptions.py │ ├── _files.py │ ├── _models.py │ ├── _qs.py │ ├── _resource.py │ ├── _response.py │ ├── _streaming.py │ ├── _types.py │ ├── _utils │ ├── __init__.py │ ├── _logs.py │ ├── _proxy.py │ ├── _reflection.py │ ├── _streams.py │ ├── _sync.py │ ├── _transform.py │ ├── _typing.py │ └── _utils.py │ ├── _version.py │ ├── lib │ └── .keep │ ├── py.typed │ ├── resources │ ├── __init__.py │ ├── chat │ │ ├── __init__.py │ │ ├── chat.py │ │ └── completions.py │ ├── evaluation.py │ └── metrics │ │ ├── __init__.py │ │ ├── few_shot_examples.py │ │ ├── metrics.py │ │ └── prompts.py │ └── types │ ├── __init__.py │ ├── chat │ ├── __init__.py │ └── completion_create_params.py │ ├── chat_completion.py │ ├── evaluation.py │ ├── evaluation_create_params.py │ ├── metric.py │ ├── metric_create_params.py │ ├── metric_create_response.py │ ├── metric_delete_response.py │ 
├── metric_get_response.py │ ├── metric_list_params.py │ ├── metric_list_response.py │ └── metrics │ ├── __init__.py │ ├── few_shot_example.py │ ├── few_shot_example_param.py │ ├── few_shot_example_set_params.py │ ├── few_shot_example_set_response.py │ ├── prompt.py │ ├── prompt_create_params.py │ ├── prompt_create_response.py │ ├── prompt_get_response.py │ ├── prompt_list_response.py │ ├── prompt_set_active_prompt_version_params.py │ └── prompt_set_active_prompt_version_response.py └── tests ├── __init__.py ├── api_resources ├── __init__.py ├── chat │ ├── __init__.py │ └── test_completions.py ├── metrics │ ├── __init__.py │ ├── test_few_shot_examples.py │ └── test_prompts.py ├── test_evaluation.py └── test_metrics.py ├── conftest.py ├── sample_file.txt ├── test_client.py ├── test_deepcopy.py ├── test_extract_files.py ├── test_files.py ├── test_models.py ├── test_qs.py ├── test_required_args.py ├── test_response.py ├── test_streaming.py ├── test_transform.py ├── test_utils ├── test_proxy.py └── test_typing.py └── utils.py /.devcontainer/Dockerfile: -------------------------------------------------------------------------------- 1 | ARG VARIANT="3.9" 2 | FROM mcr.microsoft.com/vscode/devcontainers/python:0-${VARIANT} 3 | 4 | USER vscode 5 | 6 | RUN curl -sSf https://rye.astral.sh/get | RYE_VERSION="0.44.0" RYE_INSTALL_OPTION="--yes" bash 7 | ENV PATH=/home/vscode/.rye/shims:$PATH 8 | 9 | RUN echo "[[ -d .venv ]] && source .venv/bin/activate || export PATH=\$PATH" >> /home/vscode/.bashrc 10 | -------------------------------------------------------------------------------- /.devcontainer/devcontainer.json: -------------------------------------------------------------------------------- 1 | // For format details, see https://aka.ms/devcontainer.json. For config options, see the 2 | // README at: https://github.com/devcontainers/templates/tree/main/src/debian 3 | { 4 | "name": "Debian", 5 | "build": { 6 | "dockerfile": "Dockerfile", 7 | "context": ".." 
8 | }, 9 | 10 | "postStartCommand": "rye sync --all-features", 11 | 12 | "customizations": { 13 | "vscode": { 14 | "extensions": [ 15 | "ms-python.python" 16 | ], 17 | "settings": { 18 | "terminal.integrated.shell.linux": "/bin/bash", 19 | "python.pythonPath": ".venv/bin/python", 20 | "python.defaultInterpreterPath": ".venv/bin/python", 21 | "python.typeChecking": "basic", 22 | "terminal.integrated.env.linux": { 23 | "PATH": "/home/vscode/.rye/shims:${env:PATH}" 24 | } 25 | } 26 | } 27 | }, 28 | "features": { 29 | "ghcr.io/devcontainers/features/node:1": {} 30 | } 31 | 32 | // Features to add to the dev container. More info: https://containers.dev/features. 33 | // "features": {}, 34 | 35 | // Use 'forwardPorts' to make a list of ports inside the container available locally. 36 | // "forwardPorts": [], 37 | 38 | // Configure tool-specific properties. 39 | // "customizations": {}, 40 | 41 | // Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root. 42 | // "remoteUser": "root" 43 | } 44 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | on: 3 | push: 4 | branches-ignore: 5 | - 'generated' 6 | - 'codegen/**' 7 | - 'integrated/**' 8 | - 'stl-preview-head/**' 9 | - 'stl-preview-base/**' 10 | 11 | jobs: 12 | lint: 13 | timeout-minutes: 10 14 | name: lint 15 | runs-on: ${{ github.repository == 'stainless-sdks/atla-python' && 'depot-ubuntu-24.04' || 'ubuntu-latest' }} 16 | steps: 17 | - uses: actions/checkout@v4 18 | 19 | - name: Install Rye 20 | run: | 21 | curl -sSf https://rye.astral.sh/get | bash 22 | echo "$HOME/.rye/shims" >> $GITHUB_PATH 23 | env: 24 | RYE_VERSION: '0.44.0' 25 | RYE_INSTALL_OPTION: '--yes' 26 | 27 | - name: Install dependencies 28 | run: rye sync --all-features 29 | 30 | - name: Run lints 31 | run: ./scripts/lint 32 | 33 | test: 34 | timeout-minutes: 10 35 | 
name: test 36 | runs-on: ${{ github.repository == 'stainless-sdks/atla-python' && 'depot-ubuntu-24.04' || 'ubuntu-latest' }} 37 | steps: 38 | - uses: actions/checkout@v4 39 | 40 | - name: Install Rye 41 | run: | 42 | curl -sSf https://rye.astral.sh/get | bash 43 | echo "$HOME/.rye/shims" >> $GITHUB_PATH 44 | env: 45 | RYE_VERSION: '0.44.0' 46 | RYE_INSTALL_OPTION: '--yes' 47 | 48 | - name: Bootstrap 49 | run: ./scripts/bootstrap 50 | 51 | - name: Run tests 52 | run: ./scripts/test 53 | -------------------------------------------------------------------------------- /.github/workflows/publish-pypi.yml: -------------------------------------------------------------------------------- 1 | # This workflow is triggered when a GitHub release is created. 2 | # It can also be run manually to re-publish to PyPI in case it failed for some reason. 3 | # You can run this workflow by navigating to https://www.github.com/atla-ai/atla-sdk-python/actions/workflows/publish-pypi.yml 4 | name: Publish PyPI 5 | on: 6 | workflow_dispatch: 7 | 8 | release: 9 | types: [published] 10 | 11 | jobs: 12 | publish: 13 | name: publish 14 | runs-on: ubuntu-latest 15 | 16 | steps: 17 | - uses: actions/checkout@v4 18 | 19 | - name: Install Rye 20 | run: | 21 | curl -sSf https://rye.astral.sh/get | bash 22 | echo "$HOME/.rye/shims" >> $GITHUB_PATH 23 | env: 24 | RYE_VERSION: '0.44.0' 25 | RYE_INSTALL_OPTION: '--yes' 26 | 27 | - name: Publish to PyPI 28 | run: | 29 | bash ./bin/publish-pypi 30 | env: 31 | PYPI_TOKEN: ${{ secrets.ATLA_PYPI_TOKEN || secrets.PYPI_TOKEN }} 32 | -------------------------------------------------------------------------------- /.github/workflows/release-doctor.yml: -------------------------------------------------------------------------------- 1 | name: Release Doctor 2 | on: 3 | pull_request: 4 | branches: 5 | - main 6 | workflow_dispatch: 7 | 8 | jobs: 9 | release_doctor: 10 | name: release doctor 11 | runs-on: ubuntu-latest 12 | if: github.repository == 
'atla-ai/atla-sdk-python' && (github.event_name == 'push' || github.event_name == 'workflow_dispatch' || startsWith(github.head_ref, 'release-please') || github.head_ref == 'next') 13 | 14 | steps: 15 | - uses: actions/checkout@v4 16 | 17 | - name: Check release environment 18 | run: | 19 | bash ./bin/check-release-environment 20 | env: 21 | PYPI_TOKEN: ${{ secrets.ATLA_PYPI_TOKEN || secrets.PYPI_TOKEN }} 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .prism.log 2 | .vscode 3 | _dev 4 | 5 | __pycache__ 6 | .mypy_cache 7 | 8 | dist 9 | 10 | .venv 11 | .idea 12 | 13 | .env 14 | .envrc 15 | codegen.log 16 | Brewfile.lock.json 17 | -------------------------------------------------------------------------------- /.python-version: -------------------------------------------------------------------------------- 1 | 3.9.18 2 | -------------------------------------------------------------------------------- /.release-please-manifest.json: -------------------------------------------------------------------------------- 1 | { 2 | ".": "0.6.2" 3 | } -------------------------------------------------------------------------------- /.stats.yml: -------------------------------------------------------------------------------- 1 | configured_endpoints: 11 2 | openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/atla-ai%2Fatla-b96f3c8b1e2142d2afc0ca76a932b25808b86a040f523d46a83643b87e34406d.yml 3 | openapi_spec_hash: 38198fda99be9f85013c7b8bbe0aa828 4 | config_hash: daa189f0b068e63c9b7a867beabea21e 5 | -------------------------------------------------------------------------------- /Brewfile: -------------------------------------------------------------------------------- 1 | brew "rye" 2 | 3 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: 
-------------------------------------------------------------------------------- 1 | ## Setting up the environment 2 | 3 | ### With Rye 4 | 5 | We use [Rye](https://rye.astral.sh/) to manage dependencies because it will automatically provision a Python environment with the expected Python version. To set it up, run: 6 | 7 | ```sh 8 | $ ./scripts/bootstrap 9 | ``` 10 | 11 | Or [install Rye manually](https://rye.astral.sh/guide/installation/) and run: 12 | 13 | ```sh 14 | $ rye sync --all-features 15 | ``` 16 | 17 | You can then run scripts using `rye run python script.py` or by activating the virtual environment: 18 | 19 | ```sh 20 | $ rye shell 21 | # or manually activate - https://docs.python.org/3/library/venv.html#how-venvs-work 22 | $ source .venv/bin/activate 23 | 24 | # now you can omit the `rye run` prefix 25 | $ python script.py 26 | ``` 27 | 28 | ### Without Rye 29 | 30 | Alternatively, if you don't want to install `Rye`, you can stick with the standard `pip` setup: make sure you have the Python version specified in `.python-version`, create a virtual environment however you prefer, and then install dependencies using this command: 31 | 32 | ```sh 33 | $ pip install -r requirements-dev.lock 34 | ``` 35 | 36 | ## Modifying/Adding code 37 | 38 | Most of the SDK is generated code. Modifications to code will be persisted between generations, but may 39 | result in merge conflicts between manual patches and changes from the generator. The generator will never 40 | modify the contents of the `src/atla/lib/` and `examples/` directories. 41 | 42 | ## Adding and running examples 43 | 44 | Files in the `examples/` directory are never modified by the generator and can be freely edited or added to.
45 | 46 | ```py 47 | # add an example to examples/<your-example>.py 48 | 49 | #!/usr/bin/env -S rye run python 50 | … 51 | ``` 52 | 53 | ```sh 54 | $ chmod +x examples/<your-example>.py 55 | # run the example against your api 56 | $ ./examples/<your-example>.py 57 | ``` 58 | 59 | ## Using the repository from source 60 | 61 | If you’d like to use the repository from source, you can either install from git or link to a cloned repository: 62 | 63 | To install via git: 64 | 65 | ```sh 66 | $ pip install git+ssh://git@github.com/atla-ai/atla-sdk-python.git 67 | ``` 68 | 69 | Alternatively, you can build from source and install the wheel file: 70 | 71 | Building this package will create two files in the `dist/` directory: a `.tar.gz` containing the source files and a `.whl` that can be used to install the package efficiently. 72 | 73 | To create a distributable version of the library, all you have to do is run this command: 74 | 75 | ```sh 76 | $ rye build 77 | # or 78 | $ python -m build 79 | ``` 80 | 81 | Then to install: 82 | 83 | ```sh 84 | $ pip install ./path-to-wheel-file.whl 85 | ``` 86 | 87 | ## Running tests 88 | 89 | Most tests require you to [set up a mock server](https://github.com/stoplightio/prism) against the OpenAPI spec to run the tests. 90 | 91 | ```sh 92 | # you will need npm installed 93 | $ npx prism mock path/to/your/openapi.yml 94 | ``` 95 | 96 | ```sh 97 | $ ./scripts/test 98 | ``` 99 | 100 | ## Linting and formatting 101 | 102 | This repository uses [ruff](https://github.com/astral-sh/ruff) and 103 | [black](https://github.com/psf/black) to format the code in the repository. 104 | 105 | To lint: 106 | 107 | ```sh 108 | $ ./scripts/lint 109 | ``` 110 | 111 | To format and fix all ruff issues automatically: 112 | 113 | ```sh 114 | $ ./scripts/format 115 | ``` 116 | 117 | ## Publishing and releases 118 | 119 | Changes made to this repository via the automated release PR pipeline should publish to PyPI automatically.
If 120 | the changes aren't made through the automated pipeline, you may want to make releases manually. 121 | 122 | ### Publish with a GitHub workflow 123 | 124 | You can release to package managers by using [the `Publish PyPI` GitHub action](https://www.github.com/atla-ai/atla-sdk-python/actions/workflows/publish-pypi.yml). This requires an organization or repository secret (`ATLA_PYPI_TOKEN` or `PYPI_TOKEN`) to be set up. 125 | 126 | ### Publish manually 127 | 128 | If you need to manually release a package, you can run the `bin/publish-pypi` script with a `PYPI_TOKEN` set in 129 | the environment. 130 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | # Security Policy 2 | 3 | ## Reporting Security Issues 4 | 5 | This SDK is generated by [Stainless Software Inc](http://stainless.com). Stainless takes security seriously, and encourages you to report any security vulnerability promptly so that appropriate action can be taken. 6 | 7 | To report a security issue, please contact the Stainless team at security@stainless.com. 8 | 9 | ## Responsible Disclosure 10 | 11 | We appreciate the efforts of security researchers and individuals who help us maintain the security of 12 | SDKs we generate. If you believe you have found a security vulnerability, please adhere to responsible 13 | disclosure practices by allowing us a reasonable amount of time to investigate and address the issue 14 | before making any information public. 15 | 16 | ## Reporting Non-SDK Related Security Issues 17 | 18 | If you encounter security issues that are not directly related to SDKs but pertain to the services 19 | or products provided by Atla, please follow the respective company's security reporting guidelines. 20 | 21 | ### Atla Terms and Policies 22 | 23 | Our Security Policy can be found at [Security Policy URL](https://www.atla-ai.com/terms-of-service).
24 | 25 | Please contact support@atla-ai.com for any questions or concerns regarding security of our services. 26 | 27 | --- 28 | 29 | Thank you for helping us keep the SDKs and systems they interact with secure. 30 | -------------------------------------------------------------------------------- /api.md: -------------------------------------------------------------------------------- 1 | # Chat 2 | 3 | Types: 4 | 5 | ```python 6 | from atla.types import ChatCompletion 7 | ``` 8 | 9 | ## Completions 10 | 11 | Methods: 12 | 13 | - client.chat.completions.create(\*\*params) -> ChatCompletion 14 | 15 | # Evaluation 16 | 17 | Types: 18 | 19 | ```python 20 | from atla.types import Evaluation 21 | ``` 22 | 23 | Methods: 24 | 25 | - client.evaluation.create(\*\*params) -> Evaluation 26 | 27 | # Metrics 28 | 29 | Types: 30 | 31 | ```python 32 | from atla.types import ( 33 | Metric, 34 | MetricCreateResponse, 35 | MetricListResponse, 36 | MetricDeleteResponse, 37 | MetricGetResponse, 38 | ) 39 | ``` 40 | 41 | Methods: 42 | 43 | - client.metrics.create(\*\*params) -> MetricCreateResponse 44 | - client.metrics.list(\*\*params) -> MetricListResponse 45 | - client.metrics.delete(metric_id) -> MetricDeleteResponse 46 | - client.metrics.get(metric_id) -> MetricGetResponse 47 | 48 | ## Prompts 49 | 50 | Types: 51 | 52 | ```python 53 | from atla.types.metrics import ( 54 | Prompt, 55 | PromptCreateResponse, 56 | PromptListResponse, 57 | PromptGetResponse, 58 | PromptSetActivePromptVersionResponse, 59 | ) 60 | ``` 61 | 62 | Methods: 63 | 64 | - client.metrics.prompts.create(metric_id, \*\*params) -> PromptCreateResponse 65 | - client.metrics.prompts.list(metric_id) -> PromptListResponse 66 | - client.metrics.prompts.get(version, \*, metric_id) -> PromptGetResponse 67 | - client.metrics.prompts.set_active_prompt_version(metric_id, \*\*params) -> PromptSetActivePromptVersionResponse 68 | 69 | ## FewShotExamples 70 | 71 | Types: 72 | 73 | ```python 74 | from atla.types.metrics import 
FewShotExample, FewShotExampleSetResponse 75 | ``` 76 | 77 | Methods: 78 | 79 | - client.metrics.few_shot_examples.set(metric_id, \*\*params) -> FewShotExampleSetResponse 80 | -------------------------------------------------------------------------------- /bin/check-release-environment: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | errors=() 4 | 5 | if [ -z "${PYPI_TOKEN}" ]; then 6 | errors+=("The ATLA_PYPI_TOKEN secret has not been set. Please set it in either this repository's secrets or your organization secrets.") 7 | fi 8 | 9 | lenErrors=${#errors[@]} 10 | 11 | if [[ lenErrors -gt 0 ]]; then 12 | echo -e "Found the following errors in the release environment:\n" 13 | 14 | for error in "${errors[@]}"; do 15 | echo -e "- $error\n" 16 | done 17 | 18 | exit 1 19 | fi 20 | 21 | echo "The environment is ready to push releases!" 22 | -------------------------------------------------------------------------------- /bin/publish-pypi: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -eux 4 | mkdir -p dist 5 | rye build --clean 6 | rye publish --yes --token=$PYPI_TOKEN 7 | -------------------------------------------------------------------------------- /examples/.keep: -------------------------------------------------------------------------------- 1 | File generated from our OpenAPI spec by Stainless. 2 | 3 | This directory can be used to store example files demonstrating usage of this SDK. 4 | It is ignored by Stainless code generation and its content (other than this keep file) won't be touched. 
-------------------------------------------------------------------------------- /examples/cookbooks/README.md: -------------------------------------------------------------------------------- 1 | # Choosing a Model (Guide) 2 | 3 | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/atla-ai/atla-sdk-python/blob/main/examples/cookbooks/Atla_Selene_Model_Selection.ipynb) 4 | 5 | This cookbook presents a structured way to approach picking the right model for your use case. 6 | 7 | We take Chat as an example use case, where we build a playful and helpful assistant that is cost-effective. We evaluate the performance of two popular models against criteria we are interested in - clarity, objectivity and tone. 8 | 9 | We demonstrate how Selene can be used to guide the decision. 10 | 11 | # Improving your Prompts (Guide) 12 | 13 | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/atla-ai/atla-sdk-python/blob/main/examples/cookbooks/Atla_Selene_Prompt_Improvement.ipynb) 14 | 15 | This cookbook presents a structured way to improve your prompts to get the best out of your foundation model for your use case. 16 | 17 | We take Chat as an example use case and demonstrate how Selene can be used to guide the decision. 18 | 19 | # Implementing Guardrails (Guide) 20 | 21 | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/atla-ai/atla-sdk-python/blob/main/examples/cookbooks/Atla_Selene_Mini_Guardrails.ipynb) 22 | 23 | This cookbook demonstrates how to implement **inference-time guardrails to validate and filter your AI outputs.** We evaluate GPT-4o outputs against example safety dimensions (toxicity, bias, and medical advice) to replace problematic outputs before they are delivered to users. 
24 | 25 | We use Selene Mini, our state-of-the-art small-LLM-as-a-Judge that excels in low-latency use cases. 26 | 27 | # Absolute Scoring (Tutorial) 28 | 29 | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/atla-ai/atla-sdk-python/blob/main/examples/cookbooks/Atla_Selene_Absolute_Scoring.ipynb) 30 | 31 | This cookbook gets you started running evals with absolute scores using Selene, using a sample set from the public [FLASK](https://arxiv.org/pdf/2307.10928) benchmark dataset - a collection of 1,740 human-annotated samples from 120 NLP datasets. Evaluators assign scores ranging from 1 to 5 for each annotated skill based on the reference (ground-truth) answer and skill-specific scoring rubrics. 32 | 33 | We evaluate logical robustness (whether the model avoids logical contradictions in its reasoning) and completeness (whether the response provides sufficient explanation) using default and custom-defined metrics respectively, then compare how Selene's scores align with the human labels. 34 | 35 | # Hallucination Scoring (Tutorial) 36 | 37 | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/atla-ai/atla-sdk-python/blob/main/examples/cookbooks/Atla_Selene_Hallucination.ipynb) 38 | 39 | This cookbook gets you started detecting hallucinations using Selene, and runs over a sample set from the public [RAGTruth](https://arxiv.org/abs/2401.00396) benchmark - a large-scale corpus of naturally generated hallucinations, featuring detailed word-level annotations specifically designed for retrieval-augmented generation (RAG) scenarios. 40 | 41 | We check for hallucination in AI responses, i.e., 'Is the information provided in the response directly supported by the context given in the related passages?', and compare how Selene's scores align with the human labels.
42 | 43 | # Multi-Criteria Evals (Tutorial) 44 | 45 | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/atla-ai/atla-sdk-python/blob/main/examples/cookbooks/Atla_Selene_Multi_Criteria_Evals.ipynb) 46 | 47 | This cookbook gets you started on running multi-criteria evals with Selene, to help you get a comprehensive picture of your model's performance. We follow eval best practices by evaluating each criterion as an individual metric to receive clearer insights and more reliable scores. 48 | 49 | The first section will show you how to run multi-criteria evals on one/many datapoints across 3 criteria using our async client. The second section will showcase how our model performs on multi-criteria evals, across 12 criteria on the public [FLASK](https://arxiv.org/pdf/2307.10928) dataset. 50 | 51 | # Atla on Langfuse 52 | 53 | You can use Selene as an LLM Judge in Langfuse to monitor your app’s performance in production using traces, as well as to run experiments over datasets pre-production. We provide demo videos and cookbooks for both use cases. Click [here](https://github.com/atla-ai/atla-sdk-python/blob/main/examples/cookbooks/langfuse) to go to our Langfuse cookbooks. 54 | 55 | # Contact 56 | Get in touch with us if there's another use case you'd like to see a cookbook for! 57 | 58 |

59 | 60 | 61 | 62 |

-------------------------------------------------------------------------------- /examples/cookbooks/langfuse/README.md: -------------------------------------------------------------------------------- 1 | # Monitoring (Tutorial) 2 | 3 | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/atla-ai/atla-sdk-python/blob/main/examples/cookbooks/langfuse/Atla_Langfuse_Monitoring.ipynb) 4 | 5 | This cookbook builds a Gradio application with a complete RAG pipeline. The app is a simple chatbot that answers questions about a single webpage, which in this example is Google’s Q4 2024 earnings call transcript. 6 | 7 | Traces are automatically sent to Langfuse and scored by Selene. The evaluation example in this cookbook evaluates the retrieval component of the RAG app by assessing ‘context relevance’. 8 | 9 |
10 | 11 | # Offline Evals (Tutorial) 12 | 13 | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/atla-ai/atla-sdk-python/blob/main/examples/cookbooks/langfuse/Atla_Langfuse_Offline_Evals.ipynb) 14 | 15 | This cookbook compares the performance of various models (o1-mini, o3-mini, and gpt-4o) on function calling tasks using the [Salesforce ShareGPT dataset](https://huggingface.co/datasets/arcee-ai/agent-data/viewer/default/train?f%5Bdataset%5D%5Bvalue%5D=%27glaive-function-calling-v2-extended%27&sql=SELECT+*%0AFROM+train%0AWHERE+dataset+%3D+%27salesforce_sharegpt%27%0ALIMIT+10%3B&views%5B%5D=train). The notebook uploads the dataset to Langfuse and sets up experiment runs on different models. The various outputs are automatically evaluated by Selene. 16 | 17 | # Contact 18 | Get in touch with us if there's another use case you'd like to see a cookbook for! 19 | 20 |

21 | 22 | 23 | 24 |

-------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | pretty = True 3 | show_error_codes = True 4 | 5 | # Exclude _files.py because mypy isn't smart enough to apply 6 | # the correct type narrowing and as this is an internal module 7 | # it's fine to just use Pyright. 8 | # 9 | # We also exclude our `tests` as mypy doesn't always infer 10 | # types correctly and Pyright will still catch any type errors. 11 | exclude = ^(src/atla/_files\.py|_dev/.*\.py|tests/.*)$ 12 | 13 | strict_equality = True 14 | implicit_reexport = True 15 | check_untyped_defs = True 16 | no_implicit_optional = True 17 | 18 | warn_return_any = True 19 | warn_unreachable = True 20 | warn_unused_configs = True 21 | 22 | # Turn these options off as it could cause conflicts 23 | # with the Pyright options. 24 | warn_unused_ignores = False 25 | warn_redundant_casts = False 26 | 27 | disallow_any_generics = True 28 | disallow_untyped_defs = True 29 | disallow_untyped_calls = True 30 | disallow_subclassing_any = True 31 | disallow_incomplete_defs = True 32 | disallow_untyped_decorators = True 33 | cache_fine_grained = True 34 | 35 | # By default, mypy reports an error if you assign a value to the result 36 | # of a function call that doesn't return anything. We do this in our test 37 | # cases: 38 | # ``` 39 | # result = ... 40 | # assert result is None 41 | # ``` 42 | # Changing this codegen to make mypy happy would increase complexity 43 | # and would not be worth it. 
44 | disable_error_code = func-returns-value,overload-cannot-match 45 | 46 | # https://github.com/python/mypy/issues/12162 47 | [mypy.overrides] 48 | module = "black.files.*" 49 | ignore_errors = true 50 | ignore_missing_imports = true 51 | -------------------------------------------------------------------------------- /noxfile.py: -------------------------------------------------------------------------------- 1 | import nox 2 | 3 | 4 | @nox.session(reuse_venv=True, name="test-pydantic-v1") 5 | def test_pydantic_v1(session: nox.Session) -> None: 6 | session.install("-r", "requirements-dev.lock") 7 | session.install("pydantic<2") 8 | 9 | session.run("pytest", "--showlocals", "--ignore=tests/functional", *session.posargs) 10 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "atla" 3 | version = "0.6.2" 4 | description = "The official Python library for the atla API" 5 | dynamic = ["readme"] 6 | license = "Apache-2.0" 7 | authors = [ 8 | { name = "Atla", email = "support@atla-ai.com" }, 9 | ] 10 | dependencies = [ 11 | "httpx>=0.23.0, <1", 12 | "pydantic>=1.9.0, <3", 13 | "typing-extensions>=4.10, <5", 14 | "anyio>=3.5.0, <5", 15 | "distro>=1.7.0, <2", 16 | "sniffio", 17 | ] 18 | requires-python = ">= 3.8" 19 | classifiers = [ 20 | "Typing :: Typed", 21 | "Intended Audience :: Developers", 22 | "Programming Language :: Python :: 3.8", 23 | "Programming Language :: Python :: 3.9", 24 | "Programming Language :: Python :: 3.10", 25 | "Programming Language :: Python :: 3.11", 26 | "Programming Language :: Python :: 3.12", 27 | "Operating System :: OS Independent", 28 | "Operating System :: POSIX", 29 | "Operating System :: MacOS", 30 | "Operating System :: POSIX :: Linux", 31 | "Operating System :: Microsoft :: Windows", 32 | "Topic :: Software Development :: Libraries :: Python Modules", 33 | "License :: OSI 
Approved :: Apache Software License" 34 | ] 35 | 36 | [project.urls] 37 | Homepage = "https://github.com/atla-ai/atla-sdk-python" 38 | Repository = "https://github.com/atla-ai/atla-sdk-python" 39 | 40 | 41 | [tool.rye] 42 | managed = true 43 | # version pins are in requirements-dev.lock 44 | dev-dependencies = [ 45 | "pyright==1.1.399", 46 | "mypy", 47 | "respx", 48 | "pytest", 49 | "pytest-asyncio", 50 | "ruff", 51 | "time-machine", 52 | "nox", 53 | "dirty-equals>=0.6.0", 54 | "importlib-metadata>=6.7.0", 55 | "rich>=13.7.1", 56 | "nest_asyncio==1.6.0", 57 | ] 58 | 59 | [tool.rye.scripts] 60 | format = { chain = [ 61 | "format:ruff", 62 | "format:docs", 63 | "fix:ruff", 64 | # run formatting again to fix any inconsistencies when imports are stripped 65 | "format:ruff", 66 | ]} 67 | "format:docs" = "python scripts/utils/ruffen-docs.py README.md api.md" 68 | "format:ruff" = "ruff format" 69 | 70 | "lint" = { chain = [ 71 | "check:ruff", 72 | "typecheck", 73 | "check:importable", 74 | ]} 75 | "check:ruff" = "ruff check ." 76 | "fix:ruff" = "ruff check --fix ." 77 | 78 | "check:importable" = "python -c 'import atla'" 79 | 80 | typecheck = { chain = [ 81 | "typecheck:pyright", 82 | "typecheck:mypy" 83 | ]} 84 | "typecheck:pyright" = "pyright" 85 | "typecheck:verify-types" = "pyright --verifytypes atla --ignoreexternal" 86 | "typecheck:mypy" = "mypy ." 
87 | 88 | [build-system] 89 | requires = ["hatchling==1.26.3", "hatch-fancy-pypi-readme"] 90 | build-backend = "hatchling.build" 91 | 92 | [tool.hatch.build] 93 | include = [ 94 | "src/*" 95 | ] 96 | 97 | [tool.hatch.build.targets.wheel] 98 | packages = ["src/atla"] 99 | 100 | [tool.hatch.build.targets.sdist] 101 | # Basically everything except hidden files/directories (such as .github, .devcontainers, .python-version, etc) 102 | include = [ 103 | "/*.toml", 104 | "/*.json", 105 | "/*.lock", 106 | "/*.md", 107 | "/mypy.ini", 108 | "/noxfile.py", 109 | "bin/*", 110 | "examples/*", 111 | "src/*", 112 | "tests/*", 113 | ] 114 | 115 | [tool.hatch.metadata.hooks.fancy-pypi-readme] 116 | content-type = "text/markdown" 117 | 118 | [[tool.hatch.metadata.hooks.fancy-pypi-readme.fragments]] 119 | path = "README.md" 120 | 121 | [[tool.hatch.metadata.hooks.fancy-pypi-readme.substitutions]] 122 | # replace relative links with absolute links 123 | pattern = '\[(.+?)\]\(((?!https?://)\S+?)\)' 124 | replacement = '[\1](https://github.com/atla-ai/atla-sdk-python/tree/main/\g<2>)' 125 | 126 | [tool.pytest.ini_options] 127 | testpaths = ["tests"] 128 | addopts = "--tb=short" 129 | xfail_strict = true 130 | asyncio_mode = "auto" 131 | asyncio_default_fixture_loop_scope = "session" 132 | filterwarnings = [ 133 | "error" 134 | ] 135 | 136 | [tool.pyright] 137 | # this enables practically every flag given by pyright. 138 | # there are a couple of flags that are still disabled by 139 | # default in strict mode as they are experimental and niche. 
140 | typeCheckingMode = "strict" 141 | pythonVersion = "3.8" 142 | 143 | exclude = [ 144 | "_dev", 145 | ".venv", 146 | ".nox", 147 | ] 148 | 149 | reportImplicitOverride = true 150 | reportOverlappingOverload = false 151 | 152 | reportImportCycles = false 153 | reportPrivateUsage = false 154 | 155 | [tool.ruff] 156 | line-length = 120 157 | output-format = "grouped" 158 | target-version = "py37" 159 | exclude = ["examples"] 160 | 161 | [tool.ruff.format] 162 | docstring-code-format = true 163 | 164 | [tool.ruff.lint] 165 | select = [ 166 | # isort 167 | "I", 168 | # bugbear rules 169 | "B", 170 | # remove unused imports 171 | "F401", 172 | # bare except statements 173 | "E722", 174 | # unused arguments 175 | "ARG", 176 | # print statements 177 | "T201", 178 | "T203", 179 | # misuse of typing.TYPE_CHECKING 180 | "TC004", 181 | # import rules 182 | "TID251", 183 | ] 184 | ignore = [ 185 | # mutable defaults 186 | "B006", 187 | ] 188 | unfixable = [ 189 | # disable auto fix for print statements 190 | "T201", 191 | "T203", 192 | ] 193 | 194 | [tool.ruff.lint.flake8-tidy-imports.banned-api] 195 | "functools.lru_cache".msg = "This function does not retain type information for the wrapped function's arguments; The `lru_cache` function from `_utils` should be used instead" 196 | 197 | [tool.ruff.lint.isort] 198 | length-sort = true 199 | length-sort-straight = true 200 | combine-as-imports = true 201 | extra-standard-library = ["typing_extensions"] 202 | known-first-party = ["atla", "tests"] 203 | 204 | [tool.ruff.lint.per-file-ignores] 205 | "bin/**.py" = ["T201", "T203"] 206 | "scripts/**.py" = ["T201", "T203"] 207 | "tests/**.py" = ["T201", "T203"] 208 | "examples/**.py" = ["T201", "T203"] 209 | -------------------------------------------------------------------------------- /release-please-config.json: -------------------------------------------------------------------------------- 1 | { 2 | "packages": { 3 | ".": {} 4 | }, 5 | "$schema": 
"https://raw.githubusercontent.com/stainless-api/release-please/main/schemas/config.json", 6 | "include-v-in-tag": true, 7 | "include-component-in-tag": false, 8 | "versioning": "prerelease", 9 | "prerelease": true, 10 | "bump-minor-pre-major": true, 11 | "bump-patch-for-minor-pre-major": false, 12 | "pull-request-header": "Automated Release PR", 13 | "pull-request-title-pattern": "release: ${version}", 14 | "changelog-sections": [ 15 | { 16 | "type": "feat", 17 | "section": "Features" 18 | }, 19 | { 20 | "type": "fix", 21 | "section": "Bug Fixes" 22 | }, 23 | { 24 | "type": "perf", 25 | "section": "Performance Improvements" 26 | }, 27 | { 28 | "type": "revert", 29 | "section": "Reverts" 30 | }, 31 | { 32 | "type": "chore", 33 | "section": "Chores" 34 | }, 35 | { 36 | "type": "docs", 37 | "section": "Documentation" 38 | }, 39 | { 40 | "type": "style", 41 | "section": "Styles" 42 | }, 43 | { 44 | "type": "refactor", 45 | "section": "Refactors" 46 | }, 47 | { 48 | "type": "test", 49 | "section": "Tests", 50 | "hidden": true 51 | }, 52 | { 53 | "type": "build", 54 | "section": "Build System" 55 | }, 56 | { 57 | "type": "ci", 58 | "section": "Continuous Integration", 59 | "hidden": true 60 | } 61 | ], 62 | "release-type": "python", 63 | "extra-files": [ 64 | "src/atla/_version.py" 65 | ] 66 | } -------------------------------------------------------------------------------- /requirements-dev.lock: -------------------------------------------------------------------------------- 1 | # generated by rye 2 | # use `rye lock` or `rye sync` to update this lockfile 3 | # 4 | # last locked with the following flags: 5 | # pre: false 6 | # features: [] 7 | # all-features: true 8 | # with-sources: false 9 | # generate-hashes: false 10 | # universal: false 11 | 12 | -e file:. 
13 | annotated-types==0.6.0 14 | # via pydantic 15 | anyio==4.4.0 16 | # via atla 17 | # via httpx 18 | argcomplete==3.1.2 19 | # via nox 20 | certifi==2023.7.22 21 | # via httpcore 22 | # via httpx 23 | colorlog==6.7.0 24 | # via nox 25 | dirty-equals==0.6.0 26 | distlib==0.3.7 27 | # via virtualenv 28 | distro==1.8.0 29 | # via atla 30 | exceptiongroup==1.2.2 31 | # via anyio 32 | # via pytest 33 | filelock==3.12.4 34 | # via virtualenv 35 | h11==0.14.0 36 | # via httpcore 37 | httpcore==1.0.2 38 | # via httpx 39 | httpx==0.28.1 40 | # via atla 41 | # via respx 42 | idna==3.4 43 | # via anyio 44 | # via httpx 45 | importlib-metadata==7.0.0 46 | iniconfig==2.0.0 47 | # via pytest 48 | markdown-it-py==3.0.0 49 | # via rich 50 | mdurl==0.1.2 51 | # via markdown-it-py 52 | mypy==1.14.1 53 | mypy-extensions==1.0.0 54 | # via mypy 55 | nest-asyncio==1.6.0 56 | nodeenv==1.8.0 57 | # via pyright 58 | nox==2023.4.22 59 | packaging==23.2 60 | # via nox 61 | # via pytest 62 | platformdirs==3.11.0 63 | # via virtualenv 64 | pluggy==1.5.0 65 | # via pytest 66 | pydantic==2.10.3 67 | # via atla 68 | pydantic-core==2.27.1 69 | # via pydantic 70 | pygments==2.18.0 71 | # via rich 72 | pyright==1.1.399 73 | pytest==8.3.3 74 | # via pytest-asyncio 75 | pytest-asyncio==0.24.0 76 | python-dateutil==2.8.2 77 | # via time-machine 78 | pytz==2023.3.post1 79 | # via dirty-equals 80 | respx==0.22.0 81 | rich==13.7.1 82 | ruff==0.9.4 83 | setuptools==68.2.2 84 | # via nodeenv 85 | six==1.16.0 86 | # via python-dateutil 87 | sniffio==1.3.0 88 | # via anyio 89 | # via atla 90 | time-machine==2.9.0 91 | tomli==2.0.2 92 | # via mypy 93 | # via pytest 94 | typing-extensions==4.12.2 95 | # via anyio 96 | # via atla 97 | # via mypy 98 | # via pydantic 99 | # via pydantic-core 100 | # via pyright 101 | virtualenv==20.24.5 102 | # via nox 103 | zipp==3.17.0 104 | # via importlib-metadata 105 | -------------------------------------------------------------------------------- /requirements.lock: 
-------------------------------------------------------------------------------- 1 | # generated by rye 2 | # use `rye lock` or `rye sync` to update this lockfile 3 | # 4 | # last locked with the following flags: 5 | # pre: false 6 | # features: [] 7 | # all-features: true 8 | # with-sources: false 9 | # generate-hashes: false 10 | # universal: false 11 | 12 | -e file:. 13 | annotated-types==0.6.0 14 | # via pydantic 15 | anyio==4.4.0 16 | # via atla 17 | # via httpx 18 | certifi==2023.7.22 19 | # via httpcore 20 | # via httpx 21 | distro==1.8.0 22 | # via atla 23 | exceptiongroup==1.2.2 24 | # via anyio 25 | h11==0.14.0 26 | # via httpcore 27 | httpcore==1.0.2 28 | # via httpx 29 | httpx==0.28.1 30 | # via atla 31 | idna==3.4 32 | # via anyio 33 | # via httpx 34 | pydantic==2.10.3 35 | # via atla 36 | pydantic-core==2.27.1 37 | # via pydantic 38 | sniffio==1.3.0 39 | # via anyio 40 | # via atla 41 | typing-extensions==4.12.2 42 | # via anyio 43 | # via atla 44 | # via pydantic 45 | # via pydantic-core 46 | -------------------------------------------------------------------------------- /scripts/bootstrap: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | cd "$(dirname "$0")/.." 6 | 7 | if ! command -v rye >/dev/null 2>&1 && [ -f "Brewfile" ] && [ "$(uname -s)" = "Darwin" ]; then 8 | brew bundle check >/dev/null 2>&1 || { 9 | echo "==> Installing Homebrew dependencies…" 10 | brew bundle 11 | } 12 | fi 13 | 14 | echo "==> Installing Python dependencies…" 15 | 16 | # experimental uv support makes installations significantly faster 17 | rye config --set-bool behavior.use-uv=true 18 | 19 | rye sync --all-features 20 | -------------------------------------------------------------------------------- /scripts/format: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | cd "$(dirname "$0")/.." 
6 | 7 | echo "==> Running formatters" 8 | rye run format 9 | -------------------------------------------------------------------------------- /scripts/lint: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | cd "$(dirname "$0")/.." 6 | 7 | echo "==> Running lints" 8 | rye run lint 9 | 10 | echo "==> Making sure it imports" 11 | rye run python -c 'import atla' 12 | -------------------------------------------------------------------------------- /scripts/mock: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | cd "$(dirname "$0")/.." 6 | 7 | if [[ -n "$1" && "$1" != '--'* ]]; then 8 | URL="$1" 9 | shift 10 | else 11 | URL="$(grep 'openapi_spec_url' .stats.yml | cut -d' ' -f2)" 12 | fi 13 | 14 | # Check if the URL is empty 15 | if [ -z "$URL" ]; then 16 | echo "Error: No OpenAPI spec path/url provided or found in .stats.yml" 17 | exit 1 18 | fi 19 | 20 | echo "==> Starting mock server with URL ${URL}" 21 | 22 | # Run prism mock on the given spec 23 | if [ "$1" == "--daemon" ]; then 24 | npm exec --package=@stainless-api/prism-cli@5.8.5 -- prism mock "$URL" &> .prism.log & 25 | 26 | # Wait for server to come online 27 | echo -n "Waiting for server" 28 | while ! grep -q "✖ fatal\|Prism is listening" ".prism.log" ; do 29 | echo -n "." 30 | sleep 0.1 31 | done 32 | 33 | if grep -q "✖ fatal" ".prism.log"; then 34 | cat .prism.log 35 | exit 1 36 | fi 37 | 38 | echo 39 | else 40 | npm exec --package=@stainless-api/prism-cli@5.8.5 -- prism mock "$URL" 41 | fi 42 | -------------------------------------------------------------------------------- /scripts/test: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | cd "$(dirname "$0")/.." 
6 | 7 | RED='\033[0;31m' 8 | GREEN='\033[0;32m' 9 | YELLOW='\033[0;33m' 10 | NC='\033[0m' # No Color 11 | 12 | function prism_is_running() { 13 | curl --silent "http://localhost:4010" >/dev/null 2>&1 14 | } 15 | 16 | kill_server_on_port() { 17 | pids=$(lsof -t -i tcp:"$1" || echo "") 18 | if [ "$pids" != "" ]; then 19 | kill "$pids" 20 | echo "Stopped $pids." 21 | fi 22 | } 23 | 24 | function is_overriding_api_base_url() { 25 | [ -n "$TEST_API_BASE_URL" ] 26 | } 27 | 28 | if ! is_overriding_api_base_url && ! prism_is_running ; then 29 | # When we exit this script, make sure to kill the background mock server process 30 | trap 'kill_server_on_port 4010' EXIT 31 | 32 | # Start the dev server 33 | ./scripts/mock --daemon 34 | fi 35 | 36 | if is_overriding_api_base_url ; then 37 | echo -e "${GREEN}✔ Running tests against ${TEST_API_BASE_URL}${NC}" 38 | echo 39 | elif ! prism_is_running ; then 40 | echo -e "${RED}ERROR:${NC} The test suite will not run without a mock Prism server" 41 | echo -e "running against your OpenAPI spec." 
42 | echo 43 | echo -e "To run the server, pass in the path or url of your OpenAPI" 44 | echo -e "spec to the prism command:" 45 | echo 46 | echo -e " \$ ${YELLOW}npm exec --package=@stoplight/prism-cli@~5.3.2 -- prism mock path/to/your.openapi.yml${NC}" 47 | echo 48 | 49 | exit 1 50 | else 51 | echo -e "${GREEN}✔ Mock prism server is running with your OpenAPI spec${NC}" 52 | echo 53 | fi 54 | 55 | export DEFER_PYDANTIC_BUILD=false 56 | 57 | echo "==> Running tests" 58 | rye run pytest "$@" 59 | 60 | echo "==> Running Pydantic v1 tests" 61 | rye run nox -s test-pydantic-v1 -- "$@" 62 | -------------------------------------------------------------------------------- /scripts/utils/ruffen-docs.py: -------------------------------------------------------------------------------- 1 | # fork of https://github.com/asottile/blacken-docs adapted for ruff 2 | from __future__ import annotations 3 | 4 | import re 5 | import sys 6 | import argparse 7 | import textwrap 8 | import contextlib 9 | import subprocess 10 | from typing import Match, Optional, Sequence, Generator, NamedTuple, cast 11 | 12 | MD_RE = re.compile( 13 | r"(?P^(?P *)```\s*python\n)" r"(?P.*?)" r"(?P^(?P=indent)```\s*$)", 14 | re.DOTALL | re.MULTILINE, 15 | ) 16 | MD_PYCON_RE = re.compile( 17 | r"(?P^(?P *)```\s*pycon\n)" r"(?P.*?)" r"(?P^(?P=indent)```.*$)", 18 | re.DOTALL | re.MULTILINE, 19 | ) 20 | PYCON_PREFIX = ">>> " 21 | PYCON_CONTINUATION_PREFIX = "..." 
22 | PYCON_CONTINUATION_RE = re.compile( 23 | rf"^{re.escape(PYCON_CONTINUATION_PREFIX)}( |$)", 24 | ) 25 | DEFAULT_LINE_LENGTH = 100 26 | 27 | 28 | class CodeBlockError(NamedTuple): 29 | offset: int 30 | exc: Exception 31 | 32 | 33 | def format_str( 34 | src: str, 35 | ) -> tuple[str, Sequence[CodeBlockError]]: 36 | errors: list[CodeBlockError] = [] 37 | 38 | @contextlib.contextmanager 39 | def _collect_error(match: Match[str]) -> Generator[None, None, None]: 40 | try: 41 | yield 42 | except Exception as e: 43 | errors.append(CodeBlockError(match.start(), e)) 44 | 45 | def _md_match(match: Match[str]) -> str: 46 | code = textwrap.dedent(match["code"]) 47 | with _collect_error(match): 48 | code = format_code_block(code) 49 | code = textwrap.indent(code, match["indent"]) 50 | return f"{match['before']}{code}{match['after']}" 51 | 52 | def _pycon_match(match: Match[str]) -> str: 53 | code = "" 54 | fragment = cast(Optional[str], None) 55 | 56 | def finish_fragment() -> None: 57 | nonlocal code 58 | nonlocal fragment 59 | 60 | if fragment is not None: 61 | with _collect_error(match): 62 | fragment = format_code_block(fragment) 63 | fragment_lines = fragment.splitlines() 64 | code += f"{PYCON_PREFIX}{fragment_lines[0]}\n" 65 | for line in fragment_lines[1:]: 66 | # Skip blank lines to handle Black adding a blank above 67 | # functions within blocks. A blank line would end the REPL 68 | # continuation prompt. 69 | # 70 | # >>> if True: 71 | # ... def f(): 72 | # ... pass 73 | # ... 
74 | if line: 75 | code += f"{PYCON_CONTINUATION_PREFIX} {line}\n" 76 | if fragment_lines[-1].startswith(" "): 77 | code += f"{PYCON_CONTINUATION_PREFIX}\n" 78 | fragment = None 79 | 80 | indentation = None 81 | for line in match["code"].splitlines(): 82 | orig_line, line = line, line.lstrip() 83 | if indentation is None and line: 84 | indentation = len(orig_line) - len(line) 85 | continuation_match = PYCON_CONTINUATION_RE.match(line) 86 | if continuation_match and fragment is not None: 87 | fragment += line[continuation_match.end() :] + "\n" 88 | else: 89 | finish_fragment() 90 | if line.startswith(PYCON_PREFIX): 91 | fragment = line[len(PYCON_PREFIX) :] + "\n" 92 | else: 93 | code += orig_line[indentation:] + "\n" 94 | finish_fragment() 95 | return code 96 | 97 | def _md_pycon_match(match: Match[str]) -> str: 98 | code = _pycon_match(match) 99 | code = textwrap.indent(code, match["indent"]) 100 | return f"{match['before']}{code}{match['after']}" 101 | 102 | src = MD_RE.sub(_md_match, src) 103 | src = MD_PYCON_RE.sub(_md_pycon_match, src) 104 | return src, errors 105 | 106 | 107 | def format_code_block(code: str) -> str: 108 | return subprocess.check_output( 109 | [ 110 | sys.executable, 111 | "-m", 112 | "ruff", 113 | "format", 114 | "--stdin-filename=script.py", 115 | f"--line-length={DEFAULT_LINE_LENGTH}", 116 | ], 117 | encoding="utf-8", 118 | input=code, 119 | ) 120 | 121 | 122 | def format_file( 123 | filename: str, 124 | skip_errors: bool, 125 | ) -> int: 126 | with open(filename, encoding="UTF-8") as f: 127 | contents = f.read() 128 | new_contents, errors = format_str(contents) 129 | for error in errors: 130 | lineno = contents[: error.offset].count("\n") + 1 131 | print(f"{filename}:{lineno}: code block parse error {error.exc}") 132 | if errors and not skip_errors: 133 | return 1 134 | if contents != new_contents: 135 | print(f"{filename}: Rewriting...") 136 | with open(filename, "w", encoding="UTF-8") as f: 137 | f.write(new_contents) 138 | return 0 139 
| else: 140 | return 0 141 | 142 | 143 | def main(argv: Sequence[str] | None = None) -> int: 144 | parser = argparse.ArgumentParser() 145 | parser.add_argument( 146 | "-l", 147 | "--line-length", 148 | type=int, 149 | default=DEFAULT_LINE_LENGTH, 150 | ) 151 | parser.add_argument( 152 | "-S", 153 | "--skip-string-normalization", 154 | action="store_true", 155 | ) 156 | parser.add_argument("-E", "--skip-errors", action="store_true") 157 | parser.add_argument("filenames", nargs="*") 158 | args = parser.parse_args(argv) 159 | 160 | retv = 0 161 | for filename in args.filenames: 162 | retv |= format_file(filename, skip_errors=args.skip_errors) 163 | return retv 164 | 165 | 166 | if __name__ == "__main__": 167 | raise SystemExit(main()) 168 | -------------------------------------------------------------------------------- /src/atla/__init__.py: -------------------------------------------------------------------------------- 1 | # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 2 | 3 | from . 
import types 4 | from ._types import NOT_GIVEN, Omit, NoneType, NotGiven, Transport, ProxiesTypes 5 | from ._utils import file_from_path 6 | from ._client import Atla, Client, Stream, Timeout, AsyncAtla, Transport, AsyncClient, AsyncStream, RequestOptions 7 | from ._models import BaseModel 8 | from ._version import __title__, __version__ 9 | from ._response import APIResponse as APIResponse, AsyncAPIResponse as AsyncAPIResponse 10 | from ._constants import DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES, DEFAULT_CONNECTION_LIMITS 11 | from ._exceptions import ( 12 | APIError, 13 | AtlaError, 14 | ConflictError, 15 | NotFoundError, 16 | APIStatusError, 17 | RateLimitError, 18 | APITimeoutError, 19 | BadRequestError, 20 | APIConnectionError, 21 | AuthenticationError, 22 | InternalServerError, 23 | PermissionDeniedError, 24 | UnprocessableEntityError, 25 | APIResponseValidationError, 26 | ) 27 | from ._base_client import DefaultHttpxClient, DefaultAsyncHttpxClient 28 | from ._utils._logs import setup_logging as _setup_logging 29 | 30 | __all__ = [ 31 | "types", 32 | "__version__", 33 | "__title__", 34 | "NoneType", 35 | "Transport", 36 | "ProxiesTypes", 37 | "NotGiven", 38 | "NOT_GIVEN", 39 | "Omit", 40 | "AtlaError", 41 | "APIError", 42 | "APIStatusError", 43 | "APITimeoutError", 44 | "APIConnectionError", 45 | "APIResponseValidationError", 46 | "BadRequestError", 47 | "AuthenticationError", 48 | "PermissionDeniedError", 49 | "NotFoundError", 50 | "ConflictError", 51 | "UnprocessableEntityError", 52 | "RateLimitError", 53 | "InternalServerError", 54 | "Timeout", 55 | "RequestOptions", 56 | "Client", 57 | "AsyncClient", 58 | "Stream", 59 | "AsyncStream", 60 | "Atla", 61 | "AsyncAtla", 62 | "file_from_path", 63 | "BaseModel", 64 | "DEFAULT_TIMEOUT", 65 | "DEFAULT_MAX_RETRIES", 66 | "DEFAULT_CONNECTION_LIMITS", 67 | "DefaultHttpxClient", 68 | "DefaultAsyncHttpxClient", 69 | ] 70 | 71 | _setup_logging() 72 | 73 | # Update the __module__ attribute for exported symbols so that 74 | # 
error messages point to this module instead of the module 75 | # it was originally defined in, e.g. 76 | # atla._exceptions.NotFoundError -> atla.NotFoundError 77 | __locals = locals() 78 | for __name in __all__: 79 | if not __name.startswith("__"): 80 | try: 81 | __locals[__name].__module__ = "atla" 82 | except (TypeError, AttributeError): 83 | # Some of our exported symbols are builtins which we can't set attributes for. 84 | pass 85 | -------------------------------------------------------------------------------- /src/atla/_compat.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, cast, overload 4 | from datetime import date, datetime 5 | from typing_extensions import Self, Literal 6 | 7 | import pydantic 8 | from pydantic.fields import FieldInfo 9 | 10 | from ._types import IncEx, StrBytesIntFloat 11 | 12 | _T = TypeVar("_T") 13 | _ModelT = TypeVar("_ModelT", bound=pydantic.BaseModel) 14 | 15 | # --------------- Pydantic v2 compatibility --------------- 16 | 17 | # Pyright incorrectly reports some of our functions as overriding a method when they don't 18 | # pyright: reportIncompatibleMethodOverride=false 19 | 20 | PYDANTIC_V2 = pydantic.VERSION.startswith("2.") 21 | 22 | # v1 re-exports 23 | if TYPE_CHECKING: 24 | 25 | def parse_date(value: date | StrBytesIntFloat) -> date: # noqa: ARG001 26 | ... 27 | 28 | def parse_datetime(value: Union[datetime, StrBytesIntFloat]) -> datetime: # noqa: ARG001 29 | ... 30 | 31 | def get_args(t: type[Any]) -> tuple[Any, ...]: # noqa: ARG001 32 | ... 33 | 34 | def is_union(tp: type[Any] | None) -> bool: # noqa: ARG001 35 | ... 36 | 37 | def get_origin(t: type[Any]) -> type[Any] | None: # noqa: ARG001 38 | ... 39 | 40 | def is_literal_type(type_: type[Any]) -> bool: # noqa: ARG001 41 | ... 42 | 43 | def is_typeddict(type_: type[Any]) -> bool: # noqa: ARG001 44 | ... 
45 | 46 | else: 47 | if PYDANTIC_V2: 48 | from pydantic.v1.typing import ( 49 | get_args as get_args, 50 | is_union as is_union, 51 | get_origin as get_origin, 52 | is_typeddict as is_typeddict, 53 | is_literal_type as is_literal_type, 54 | ) 55 | from pydantic.v1.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime 56 | else: 57 | from pydantic.typing import ( 58 | get_args as get_args, 59 | is_union as is_union, 60 | get_origin as get_origin, 61 | is_typeddict as is_typeddict, 62 | is_literal_type as is_literal_type, 63 | ) 64 | from pydantic.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime 65 | 66 | 67 | # refactored config 68 | if TYPE_CHECKING: 69 | from pydantic import ConfigDict as ConfigDict 70 | else: 71 | if PYDANTIC_V2: 72 | from pydantic import ConfigDict 73 | else: 74 | # TODO: provide an error message here? 75 | ConfigDict = None 76 | 77 | 78 | # renamed methods / properties 79 | def parse_obj(model: type[_ModelT], value: object) -> _ModelT: 80 | if PYDANTIC_V2: 81 | return model.model_validate(value) 82 | else: 83 | return cast(_ModelT, model.parse_obj(value)) # pyright: ignore[reportDeprecated, reportUnnecessaryCast] 84 | 85 | 86 | def field_is_required(field: FieldInfo) -> bool: 87 | if PYDANTIC_V2: 88 | return field.is_required() 89 | return field.required # type: ignore 90 | 91 | 92 | def field_get_default(field: FieldInfo) -> Any: 93 | value = field.get_default() 94 | if PYDANTIC_V2: 95 | from pydantic_core import PydanticUndefined 96 | 97 | if value == PydanticUndefined: 98 | return None 99 | return value 100 | return value 101 | 102 | 103 | def field_outer_type(field: FieldInfo) -> Any: 104 | if PYDANTIC_V2: 105 | return field.annotation 106 | return field.outer_type_ # type: ignore 107 | 108 | 109 | def get_model_config(model: type[pydantic.BaseModel]) -> Any: 110 | if PYDANTIC_V2: 111 | return model.model_config 112 | return model.__config__ # type: ignore 113 | 114 | 115 | def 
get_model_fields(model: type[pydantic.BaseModel]) -> dict[str, FieldInfo]: 116 | if PYDANTIC_V2: 117 | return model.model_fields 118 | return model.__fields__ # type: ignore 119 | 120 | 121 | def model_copy(model: _ModelT, *, deep: bool = False) -> _ModelT: 122 | if PYDANTIC_V2: 123 | return model.model_copy(deep=deep) 124 | return model.copy(deep=deep) # type: ignore 125 | 126 | 127 | def model_json(model: pydantic.BaseModel, *, indent: int | None = None) -> str: 128 | if PYDANTIC_V2: 129 | return model.model_dump_json(indent=indent) 130 | return model.json(indent=indent) # type: ignore 131 | 132 | 133 | def model_dump( 134 | model: pydantic.BaseModel, 135 | *, 136 | exclude: IncEx | None = None, 137 | exclude_unset: bool = False, 138 | exclude_defaults: bool = False, 139 | warnings: bool = True, 140 | mode: Literal["json", "python"] = "python", 141 | ) -> dict[str, Any]: 142 | if PYDANTIC_V2 or hasattr(model, "model_dump"): 143 | return model.model_dump( 144 | mode=mode, 145 | exclude=exclude, 146 | exclude_unset=exclude_unset, 147 | exclude_defaults=exclude_defaults, 148 | # warnings are not supported in Pydantic v1 149 | warnings=warnings if PYDANTIC_V2 else True, 150 | ) 151 | return cast( 152 | "dict[str, Any]", 153 | model.dict( # pyright: ignore[reportDeprecated, reportUnnecessaryCast] 154 | exclude=exclude, 155 | exclude_unset=exclude_unset, 156 | exclude_defaults=exclude_defaults, 157 | ), 158 | ) 159 | 160 | 161 | def model_parse(model: type[_ModelT], data: Any) -> _ModelT: 162 | if PYDANTIC_V2: 163 | return model.model_validate(data) 164 | return model.parse_obj(data) # pyright: ignore[reportDeprecated] 165 | 166 | 167 | # generic models 168 | if TYPE_CHECKING: 169 | 170 | class GenericModel(pydantic.BaseModel): ... 
171 | 172 | else: 173 | if PYDANTIC_V2: 174 | # there no longer needs to be a distinction in v2 but 175 | # we still have to create our own subclass to avoid 176 | # inconsistent MRO ordering errors 177 | class GenericModel(pydantic.BaseModel): ... 178 | 179 | else: 180 | import pydantic.generics 181 | 182 | class GenericModel(pydantic.generics.GenericModel, pydantic.BaseModel): ... 183 | 184 | 185 | # cached properties 186 | if TYPE_CHECKING: 187 | cached_property = property 188 | 189 | # we define a separate type (copied from typeshed) 190 | # that represents that `cached_property` is `set`able 191 | # at runtime, which differs from `@property`. 192 | # 193 | # this is a separate type as editors likely special case 194 | # `@property` and we don't want to cause issues just to have 195 | # more helpful internal types. 196 | 197 | class typed_cached_property(Generic[_T]): 198 | func: Callable[[Any], _T] 199 | attrname: str | None 200 | 201 | def __init__(self, func: Callable[[Any], _T]) -> None: ... 202 | 203 | @overload 204 | def __get__(self, instance: None, owner: type[Any] | None = None) -> Self: ... 205 | 206 | @overload 207 | def __get__(self, instance: object, owner: type[Any] | None = None) -> _T: ... 208 | 209 | def __get__(self, instance: object, owner: type[Any] | None = None) -> _T | Self: 210 | raise NotImplementedError() 211 | 212 | def __set_name__(self, owner: type[Any], name: str) -> None: ... 213 | 214 | # __set__ is not defined at runtime, but @cached_property is designed to be settable 215 | def __set__(self, instance: object, value: _T) -> None: ... 216 | else: 217 | from functools import cached_property as cached_property 218 | 219 | typed_cached_property = cached_property 220 | -------------------------------------------------------------------------------- /src/atla/_constants.py: -------------------------------------------------------------------------------- 1 | # File generated from our OpenAPI spec by Stainless. 
See CONTRIBUTING.md for details. 2 | 3 | import httpx 4 | 5 | RAW_RESPONSE_HEADER = "X-Stainless-Raw-Response" 6 | OVERRIDE_CAST_TO_HEADER = "____stainless_override_cast_to" 7 | 8 | # default timeout is 1 minute 9 | DEFAULT_TIMEOUT = httpx.Timeout(timeout=60, connect=5.0) 10 | DEFAULT_MAX_RETRIES = 2 11 | DEFAULT_CONNECTION_LIMITS = httpx.Limits(max_connections=100, max_keepalive_connections=20) 12 | 13 | INITIAL_RETRY_DELAY = 0.5 14 | MAX_RETRY_DELAY = 8.0 15 | -------------------------------------------------------------------------------- /src/atla/_exceptions.py: -------------------------------------------------------------------------------- 1 | # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 2 | 3 | from __future__ import annotations 4 | 5 | from typing_extensions import Literal 6 | 7 | import httpx 8 | 9 | __all__ = [ 10 | "BadRequestError", 11 | "AuthenticationError", 12 | "PermissionDeniedError", 13 | "NotFoundError", 14 | "ConflictError", 15 | "UnprocessableEntityError", 16 | "RateLimitError", 17 | "InternalServerError", 18 | ] 19 | 20 | 21 | class AtlaError(Exception): 22 | pass 23 | 24 | 25 | class APIError(AtlaError): 26 | message: str 27 | request: httpx.Request 28 | 29 | body: object | None 30 | """The API response body. 31 | 32 | If the API responded with a valid JSON structure then this property will be the 33 | decoded result. 34 | 35 | If it isn't a valid JSON structure then this will be the raw response. 36 | 37 | If there was no response associated with this error then it will be `None`. 
38 | """ 39 | 40 | def __init__(self, message: str, request: httpx.Request, *, body: object | None) -> None: # noqa: ARG002 41 | super().__init__(message) 42 | self.request = request 43 | self.message = message 44 | self.body = body 45 | 46 | 47 | class APIResponseValidationError(APIError): 48 | response: httpx.Response 49 | status_code: int 50 | 51 | def __init__(self, response: httpx.Response, body: object | None, *, message: str | None = None) -> None: 52 | super().__init__(message or "Data returned by API invalid for expected schema.", response.request, body=body) 53 | self.response = response 54 | self.status_code = response.status_code 55 | 56 | 57 | class APIStatusError(APIError): 58 | """Raised when an API response has a status code of 4xx or 5xx.""" 59 | 60 | response: httpx.Response 61 | status_code: int 62 | 63 | def __init__(self, message: str, *, response: httpx.Response, body: object | None) -> None: 64 | super().__init__(message, response.request, body=body) 65 | self.response = response 66 | self.status_code = response.status_code 67 | 68 | 69 | class APIConnectionError(APIError): 70 | def __init__(self, *, message: str = "Connection error.", request: httpx.Request) -> None: 71 | super().__init__(message, request, body=None) 72 | 73 | 74 | class APITimeoutError(APIConnectionError): 75 | def __init__(self, request: httpx.Request) -> None: 76 | super().__init__(message="Request timed out.", request=request) 77 | 78 | 79 | class BadRequestError(APIStatusError): 80 | status_code: Literal[400] = 400 # pyright: ignore[reportIncompatibleVariableOverride] 81 | 82 | 83 | class AuthenticationError(APIStatusError): 84 | status_code: Literal[401] = 401 # pyright: ignore[reportIncompatibleVariableOverride] 85 | 86 | 87 | class PermissionDeniedError(APIStatusError): 88 | status_code: Literal[403] = 403 # pyright: ignore[reportIncompatibleVariableOverride] 89 | 90 | 91 | class NotFoundError(APIStatusError): 92 | status_code: Literal[404] = 404 # pyright: 
ignore[reportIncompatibleVariableOverride] 93 | 94 | 95 | class ConflictError(APIStatusError): 96 | status_code: Literal[409] = 409 # pyright: ignore[reportIncompatibleVariableOverride] 97 | 98 | 99 | class UnprocessableEntityError(APIStatusError): 100 | status_code: Literal[422] = 422 # pyright: ignore[reportIncompatibleVariableOverride] 101 | 102 | 103 | class RateLimitError(APIStatusError): 104 | status_code: Literal[429] = 429 # pyright: ignore[reportIncompatibleVariableOverride] 105 | 106 | 107 | class InternalServerError(APIStatusError): 108 | pass 109 | -------------------------------------------------------------------------------- /src/atla/_files.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import io 4 | import os 5 | import pathlib 6 | from typing import overload 7 | from typing_extensions import TypeGuard 8 | 9 | import anyio 10 | 11 | from ._types import ( 12 | FileTypes, 13 | FileContent, 14 | RequestFiles, 15 | HttpxFileTypes, 16 | Base64FileInput, 17 | HttpxFileContent, 18 | HttpxRequestFiles, 19 | ) 20 | from ._utils import is_tuple_t, is_mapping_t, is_sequence_t 21 | 22 | 23 | def is_base64_file_input(obj: object) -> TypeGuard[Base64FileInput]: 24 | return isinstance(obj, io.IOBase) or isinstance(obj, os.PathLike) 25 | 26 | 27 | def is_file_content(obj: object) -> TypeGuard[FileContent]: 28 | return ( 29 | isinstance(obj, bytes) or isinstance(obj, tuple) or isinstance(obj, io.IOBase) or isinstance(obj, os.PathLike) 30 | ) 31 | 32 | 33 | def assert_is_file_content(obj: object, *, key: str | None = None) -> None: 34 | if not is_file_content(obj): 35 | prefix = f"Expected entry at `{key}`" if key is not None else f"Expected file input `{obj!r}`" 36 | raise RuntimeError( 37 | f"{prefix} to be bytes, an io.IOBase instance, PathLike or a tuple but received {type(obj)} instead." 38 | ) from None 39 | 40 | 41 | @overload 42 | def to_httpx_files(files: None) -> None: ... 
43 | 44 | 45 | @overload 46 | def to_httpx_files(files: RequestFiles) -> HttpxRequestFiles: ... 47 | 48 | 49 | def to_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None: 50 | if files is None: 51 | return None 52 | 53 | if is_mapping_t(files): 54 | files = {key: _transform_file(file) for key, file in files.items()} 55 | elif is_sequence_t(files): 56 | files = [(key, _transform_file(file)) for key, file in files] 57 | else: 58 | raise TypeError(f"Unexpected file type input {type(files)}, expected mapping or sequence") 59 | 60 | return files 61 | 62 | 63 | def _transform_file(file: FileTypes) -> HttpxFileTypes: 64 | if is_file_content(file): 65 | if isinstance(file, os.PathLike): 66 | path = pathlib.Path(file) 67 | return (path.name, path.read_bytes()) 68 | 69 | return file 70 | 71 | if is_tuple_t(file): 72 | return (file[0], _read_file_content(file[1]), *file[2:]) 73 | 74 | raise TypeError(f"Expected file types input to be a FileContent type or to be a tuple") 75 | 76 | 77 | def _read_file_content(file: FileContent) -> HttpxFileContent: 78 | if isinstance(file, os.PathLike): 79 | return pathlib.Path(file).read_bytes() 80 | return file 81 | 82 | 83 | @overload 84 | async def async_to_httpx_files(files: None) -> None: ... 85 | 86 | 87 | @overload 88 | async def async_to_httpx_files(files: RequestFiles) -> HttpxRequestFiles: ... 
89 | 90 | 91 | async def async_to_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None: 92 | if files is None: 93 | return None 94 | 95 | if is_mapping_t(files): 96 | files = {key: await _async_transform_file(file) for key, file in files.items()} 97 | elif is_sequence_t(files): 98 | files = [(key, await _async_transform_file(file)) for key, file in files] 99 | else: 100 | raise TypeError("Unexpected file type input {type(files)}, expected mapping or sequence") 101 | 102 | return files 103 | 104 | 105 | async def _async_transform_file(file: FileTypes) -> HttpxFileTypes: 106 | if is_file_content(file): 107 | if isinstance(file, os.PathLike): 108 | path = anyio.Path(file) 109 | return (path.name, await path.read_bytes()) 110 | 111 | return file 112 | 113 | if is_tuple_t(file): 114 | return (file[0], await _async_read_file_content(file[1]), *file[2:]) 115 | 116 | raise TypeError(f"Expected file types input to be a FileContent type or to be a tuple") 117 | 118 | 119 | async def _async_read_file_content(file: FileContent) -> HttpxFileContent: 120 | if isinstance(file, os.PathLike): 121 | return await anyio.Path(file).read_bytes() 122 | 123 | return file 124 | -------------------------------------------------------------------------------- /src/atla/_qs.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Any, List, Tuple, Union, Mapping, TypeVar 4 | from urllib.parse import parse_qs, urlencode 5 | from typing_extensions import Literal, get_args 6 | 7 | from ._types import NOT_GIVEN, NotGiven, NotGivenOr 8 | from ._utils import flatten 9 | 10 | _T = TypeVar("_T") 11 | 12 | 13 | ArrayFormat = Literal["comma", "repeat", "indices", "brackets"] 14 | NestedFormat = Literal["dots", "brackets"] 15 | 16 | PrimitiveData = Union[str, int, float, bool, None] 17 | # this should be Data = Union[PrimitiveData, "List[Data]", "Tuple[Data]", "Mapping[str, Data]"] 18 | # 
https://github.com/microsoft/pyright/issues/3555 19 | Data = Union[PrimitiveData, List[Any], Tuple[Any], "Mapping[str, Any]"] 20 | Params = Mapping[str, Data] 21 | 22 | 23 | class Querystring: 24 | array_format: ArrayFormat 25 | nested_format: NestedFormat 26 | 27 | def __init__( 28 | self, 29 | *, 30 | array_format: ArrayFormat = "repeat", 31 | nested_format: NestedFormat = "brackets", 32 | ) -> None: 33 | self.array_format = array_format 34 | self.nested_format = nested_format 35 | 36 | def parse(self, query: str) -> Mapping[str, object]: 37 | # Note: custom format syntax is not supported yet 38 | return parse_qs(query) 39 | 40 | def stringify( 41 | self, 42 | params: Params, 43 | *, 44 | array_format: NotGivenOr[ArrayFormat] = NOT_GIVEN, 45 | nested_format: NotGivenOr[NestedFormat] = NOT_GIVEN, 46 | ) -> str: 47 | return urlencode( 48 | self.stringify_items( 49 | params, 50 | array_format=array_format, 51 | nested_format=nested_format, 52 | ) 53 | ) 54 | 55 | def stringify_items( 56 | self, 57 | params: Params, 58 | *, 59 | array_format: NotGivenOr[ArrayFormat] = NOT_GIVEN, 60 | nested_format: NotGivenOr[NestedFormat] = NOT_GIVEN, 61 | ) -> list[tuple[str, str]]: 62 | opts = Options( 63 | qs=self, 64 | array_format=array_format, 65 | nested_format=nested_format, 66 | ) 67 | return flatten([self._stringify_item(key, value, opts) for key, value in params.items()]) 68 | 69 | def _stringify_item( 70 | self, 71 | key: str, 72 | value: Data, 73 | opts: Options, 74 | ) -> list[tuple[str, str]]: 75 | if isinstance(value, Mapping): 76 | items: list[tuple[str, str]] = [] 77 | nested_format = opts.nested_format 78 | for subkey, subvalue in value.items(): 79 | items.extend( 80 | self._stringify_item( 81 | # TODO: error if unknown format 82 | f"{key}.{subkey}" if nested_format == "dots" else f"{key}[{subkey}]", 83 | subvalue, 84 | opts, 85 | ) 86 | ) 87 | return items 88 | 89 | if isinstance(value, (list, tuple)): 90 | array_format = opts.array_format 91 | if array_format == 
"comma": 92 | return [ 93 | ( 94 | key, 95 | ",".join(self._primitive_value_to_str(item) for item in value if item is not None), 96 | ), 97 | ] 98 | elif array_format == "repeat": 99 | items = [] 100 | for item in value: 101 | items.extend(self._stringify_item(key, item, opts)) 102 | return items 103 | elif array_format == "indices": 104 | raise NotImplementedError("The array indices format is not supported yet") 105 | elif array_format == "brackets": 106 | items = [] 107 | key = key + "[]" 108 | for item in value: 109 | items.extend(self._stringify_item(key, item, opts)) 110 | return items 111 | else: 112 | raise NotImplementedError( 113 | f"Unknown array_format value: {array_format}, choose from {', '.join(get_args(ArrayFormat))}" 114 | ) 115 | 116 | serialised = self._primitive_value_to_str(value) 117 | if not serialised: 118 | return [] 119 | return [(key, serialised)] 120 | 121 | def _primitive_value_to_str(self, value: PrimitiveData) -> str: 122 | # copied from httpx 123 | if value is True: 124 | return "true" 125 | elif value is False: 126 | return "false" 127 | elif value is None: 128 | return "" 129 | return str(value) 130 | 131 | 132 | _qs = Querystring() 133 | parse = _qs.parse 134 | stringify = _qs.stringify 135 | stringify_items = _qs.stringify_items 136 | 137 | 138 | class Options: 139 | array_format: ArrayFormat 140 | nested_format: NestedFormat 141 | 142 | def __init__( 143 | self, 144 | qs: Querystring = _qs, 145 | *, 146 | array_format: NotGivenOr[ArrayFormat] = NOT_GIVEN, 147 | nested_format: NotGivenOr[NestedFormat] = NOT_GIVEN, 148 | ) -> None: 149 | self.array_format = qs.array_format if isinstance(array_format, NotGiven) else array_format 150 | self.nested_format = qs.nested_format if isinstance(nested_format, NotGiven) else nested_format 151 | -------------------------------------------------------------------------------- /src/atla/_resource.py: -------------------------------------------------------------------------------- 1 | # File 
generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 2 | 3 | from __future__ import annotations 4 | 5 | import time 6 | from typing import TYPE_CHECKING 7 | 8 | import anyio 9 | 10 | if TYPE_CHECKING: 11 | from ._client import Atla, AsyncAtla 12 | 13 | 14 | class SyncAPIResource: 15 | _client: Atla 16 | 17 | def __init__(self, client: Atla) -> None: 18 | self._client = client 19 | self._get = client.get 20 | self._post = client.post 21 | self._patch = client.patch 22 | self._put = client.put 23 | self._delete = client.delete 24 | self._get_api_list = client.get_api_list 25 | 26 | def _sleep(self, seconds: float) -> None: 27 | time.sleep(seconds) 28 | 29 | 30 | class AsyncAPIResource: 31 | _client: AsyncAtla 32 | 33 | def __init__(self, client: AsyncAtla) -> None: 34 | self._client = client 35 | self._get = client.get 36 | self._post = client.post 37 | self._patch = client.patch 38 | self._put = client.put 39 | self._delete = client.delete 40 | self._get_api_list = client.get_api_list 41 | 42 | async def _sleep(self, seconds: float) -> None: 43 | await anyio.sleep(seconds) 44 | -------------------------------------------------------------------------------- /src/atla/_streaming.py: -------------------------------------------------------------------------------- 1 | # Note: initially copied from https://github.com/florimondmanca/httpx-sse/blob/master/src/httpx_sse/_decoders.py 2 | from __future__ import annotations 3 | 4 | import json 5 | import inspect 6 | from types import TracebackType 7 | from typing import TYPE_CHECKING, Any, Generic, TypeVar, Iterator, AsyncIterator, cast 8 | from typing_extensions import Self, Protocol, TypeGuard, override, get_origin, runtime_checkable 9 | 10 | import httpx 11 | 12 | from ._utils import extract_type_var_from_base 13 | 14 | if TYPE_CHECKING: 15 | from ._client import Atla, AsyncAtla 16 | 17 | 18 | _T = TypeVar("_T") 19 | 20 | 21 | class Stream(Generic[_T]): 22 | """Provides the core interface to 
iterate over a synchronous stream response.""" 23 | 24 | response: httpx.Response 25 | 26 | _decoder: SSEBytesDecoder 27 | 28 | def __init__( 29 | self, 30 | *, 31 | cast_to: type[_T], 32 | response: httpx.Response, 33 | client: Atla, 34 | ) -> None: 35 | self.response = response 36 | self._cast_to = cast_to 37 | self._client = client 38 | self._decoder = client._make_sse_decoder() 39 | self._iterator = self.__stream__() 40 | 41 | def __next__(self) -> _T: 42 | return self._iterator.__next__() 43 | 44 | def __iter__(self) -> Iterator[_T]: 45 | for item in self._iterator: 46 | yield item 47 | 48 | def _iter_events(self) -> Iterator[ServerSentEvent]: 49 | yield from self._decoder.iter_bytes(self.response.iter_bytes()) 50 | 51 | def __stream__(self) -> Iterator[_T]: 52 | cast_to = cast(Any, self._cast_to) 53 | response = self.response 54 | process_data = self._client._process_response_data 55 | iterator = self._iter_events() 56 | 57 | for sse in iterator: 58 | yield process_data(data=sse.json(), cast_to=cast_to, response=response) 59 | 60 | # Ensure the entire stream is consumed 61 | for _sse in iterator: 62 | ... 63 | 64 | def __enter__(self) -> Self: 65 | return self 66 | 67 | def __exit__( 68 | self, 69 | exc_type: type[BaseException] | None, 70 | exc: BaseException | None, 71 | exc_tb: TracebackType | None, 72 | ) -> None: 73 | self.close() 74 | 75 | def close(self) -> None: 76 | """ 77 | Close the response and release the connection. 78 | 79 | Automatically called if the response body is read to completion. 
80 | """ 81 | self.response.close() 82 | 83 | 84 | class AsyncStream(Generic[_T]): 85 | """Provides the core interface to iterate over an asynchronous stream response.""" 86 | 87 | response: httpx.Response 88 | 89 | _decoder: SSEDecoder | SSEBytesDecoder 90 | 91 | def __init__( 92 | self, 93 | *, 94 | cast_to: type[_T], 95 | response: httpx.Response, 96 | client: AsyncAtla, 97 | ) -> None: 98 | self.response = response 99 | self._cast_to = cast_to 100 | self._client = client 101 | self._decoder = client._make_sse_decoder() 102 | self._iterator = self.__stream__() 103 | 104 | async def __anext__(self) -> _T: 105 | return await self._iterator.__anext__() 106 | 107 | async def __aiter__(self) -> AsyncIterator[_T]: 108 | async for item in self._iterator: 109 | yield item 110 | 111 | async def _iter_events(self) -> AsyncIterator[ServerSentEvent]: 112 | async for sse in self._decoder.aiter_bytes(self.response.aiter_bytes()): 113 | yield sse 114 | 115 | async def __stream__(self) -> AsyncIterator[_T]: 116 | cast_to = cast(Any, self._cast_to) 117 | response = self.response 118 | process_data = self._client._process_response_data 119 | iterator = self._iter_events() 120 | 121 | async for sse in iterator: 122 | yield process_data(data=sse.json(), cast_to=cast_to, response=response) 123 | 124 | # Ensure the entire stream is consumed 125 | async for _sse in iterator: 126 | ... 127 | 128 | async def __aenter__(self) -> Self: 129 | return self 130 | 131 | async def __aexit__( 132 | self, 133 | exc_type: type[BaseException] | None, 134 | exc: BaseException | None, 135 | exc_tb: TracebackType | None, 136 | ) -> None: 137 | await self.close() 138 | 139 | async def close(self) -> None: 140 | """ 141 | Close the response and release the connection. 142 | 143 | Automatically called if the response body is read to completion. 
144 | """ 145 | await self.response.aclose() 146 | 147 | 148 | class ServerSentEvent: 149 | def __init__( 150 | self, 151 | *, 152 | event: str | None = None, 153 | data: str | None = None, 154 | id: str | None = None, 155 | retry: int | None = None, 156 | ) -> None: 157 | if data is None: 158 | data = "" 159 | 160 | self._id = id 161 | self._data = data 162 | self._event = event or None 163 | self._retry = retry 164 | 165 | @property 166 | def event(self) -> str | None: 167 | return self._event 168 | 169 | @property 170 | def id(self) -> str | None: 171 | return self._id 172 | 173 | @property 174 | def retry(self) -> int | None: 175 | return self._retry 176 | 177 | @property 178 | def data(self) -> str: 179 | return self._data 180 | 181 | def json(self) -> Any: 182 | return json.loads(self.data) 183 | 184 | @override 185 | def __repr__(self) -> str: 186 | return f"ServerSentEvent(event={self.event}, data={self.data}, id={self.id}, retry={self.retry})" 187 | 188 | 189 | class SSEDecoder: 190 | _data: list[str] 191 | _event: str | None 192 | _retry: int | None 193 | _last_event_id: str | None 194 | 195 | def __init__(self) -> None: 196 | self._event = None 197 | self._data = [] 198 | self._last_event_id = None 199 | self._retry = None 200 | 201 | def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[ServerSentEvent]: 202 | """Given an iterator that yields raw binary data, iterate over it & yield every event encountered""" 203 | for chunk in self._iter_chunks(iterator): 204 | # Split before decoding so splitlines() only uses \r and \n 205 | for raw_line in chunk.splitlines(): 206 | line = raw_line.decode("utf-8") 207 | sse = self.decode(line) 208 | if sse: 209 | yield sse 210 | 211 | def _iter_chunks(self, iterator: Iterator[bytes]) -> Iterator[bytes]: 212 | """Given an iterator that yields raw binary data, iterate over it and yield individual SSE chunks""" 213 | data = b"" 214 | for chunk in iterator: 215 | for line in chunk.splitlines(keepends=True): 216 | 
data += line 217 | if data.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")): 218 | yield data 219 | data = b"" 220 | if data: 221 | yield data 222 | 223 | async def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[ServerSentEvent]: 224 | """Given an iterator that yields raw binary data, iterate over it & yield every event encountered""" 225 | async for chunk in self._aiter_chunks(iterator): 226 | # Split before decoding so splitlines() only uses \r and \n 227 | for raw_line in chunk.splitlines(): 228 | line = raw_line.decode("utf-8") 229 | sse = self.decode(line) 230 | if sse: 231 | yield sse 232 | 233 | async def _aiter_chunks(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[bytes]: 234 | """Given an iterator that yields raw binary data, iterate over it and yield individual SSE chunks""" 235 | data = b"" 236 | async for chunk in iterator: 237 | for line in chunk.splitlines(keepends=True): 238 | data += line 239 | if data.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")): 240 | yield data 241 | data = b"" 242 | if data: 243 | yield data 244 | 245 | def decode(self, line: str) -> ServerSentEvent | None: 246 | # See: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation # noqa: E501 247 | 248 | if not line: 249 | if not self._event and not self._data and not self._last_event_id and self._retry is None: 250 | return None 251 | 252 | sse = ServerSentEvent( 253 | event=self._event, 254 | data="\n".join(self._data), 255 | id=self._last_event_id, 256 | retry=self._retry, 257 | ) 258 | 259 | # NOTE: as per the SSE spec, do not reset last_event_id. 
260 | self._event = None 261 | self._data = [] 262 | self._retry = None 263 | 264 | return sse 265 | 266 | if line.startswith(":"): 267 | return None 268 | 269 | fieldname, _, value = line.partition(":") 270 | 271 | if value.startswith(" "): 272 | value = value[1:] 273 | 274 | if fieldname == "event": 275 | self._event = value 276 | elif fieldname == "data": 277 | self._data.append(value) 278 | elif fieldname == "id": 279 | if "\0" in value: 280 | pass 281 | else: 282 | self._last_event_id = value 283 | elif fieldname == "retry": 284 | try: 285 | self._retry = int(value) 286 | except (TypeError, ValueError): 287 | pass 288 | else: 289 | pass # Field is ignored. 290 | 291 | return None 292 | 293 | 294 | @runtime_checkable 295 | class SSEBytesDecoder(Protocol): 296 | def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[ServerSentEvent]: 297 | """Given an iterator that yields raw binary data, iterate over it & yield every event encountered""" 298 | ... 299 | 300 | def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[ServerSentEvent]: 301 | """Given an async iterator that yields raw binary data, iterate over it & yield every event encountered""" 302 | ... 303 | 304 | 305 | def is_stream_class_type(typ: type) -> TypeGuard[type[Stream[object]] | type[AsyncStream[object]]]: 306 | """TypeGuard for determining whether or not the given type is a subclass of `Stream` / `AsyncStream`""" 307 | origin = get_origin(typ) or typ 308 | return inspect.isclass(origin) and issubclass(origin, (Stream, AsyncStream)) 309 | 310 | 311 | def extract_stream_chunk_type( 312 | stream_cls: type, 313 | *, 314 | failure_message: str | None = None, 315 | ) -> type: 316 | """Given a type like `Stream[T]`, returns the generic type variable `T`. 317 | 318 | This also handles the case where a concrete subclass is given, e.g. 319 | ```py 320 | class MyStream(Stream[bytes]): 321 | ... 
322 | 323 | extract_stream_chunk_type(MyStream) -> bytes 324 | ``` 325 | """ 326 | from ._base_client import Stream, AsyncStream 327 | 328 | return extract_type_var_from_base( 329 | stream_cls, 330 | index=0, 331 | generic_bases=cast("tuple[type, ...]", (Stream, AsyncStream)), 332 | failure_message=failure_message, 333 | ) 334 | -------------------------------------------------------------------------------- /src/atla/_types.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from os import PathLike 4 | from typing import ( 5 | IO, 6 | TYPE_CHECKING, 7 | Any, 8 | Dict, 9 | List, 10 | Type, 11 | Tuple, 12 | Union, 13 | Mapping, 14 | TypeVar, 15 | Callable, 16 | Optional, 17 | Sequence, 18 | ) 19 | from typing_extensions import Set, Literal, Protocol, TypeAlias, TypedDict, override, runtime_checkable 20 | 21 | import httpx 22 | import pydantic 23 | from httpx import URL, Proxy, Timeout, Response, BaseTransport, AsyncBaseTransport 24 | 25 | if TYPE_CHECKING: 26 | from ._models import BaseModel 27 | from ._response import APIResponse, AsyncAPIResponse 28 | 29 | Transport = BaseTransport 30 | AsyncTransport = AsyncBaseTransport 31 | Query = Mapping[str, object] 32 | Body = object 33 | AnyMapping = Mapping[str, object] 34 | ModelT = TypeVar("ModelT", bound=pydantic.BaseModel) 35 | _T = TypeVar("_T") 36 | 37 | 38 | # Approximates httpx internal ProxiesTypes and RequestFiles types 39 | # while adding support for `PathLike` instances 40 | ProxiesDict = Dict["str | URL", Union[None, str, URL, Proxy]] 41 | ProxiesTypes = Union[str, Proxy, ProxiesDict] 42 | if TYPE_CHECKING: 43 | Base64FileInput = Union[IO[bytes], PathLike[str]] 44 | FileContent = Union[IO[bytes], bytes, PathLike[str]] 45 | else: 46 | Base64FileInput = Union[IO[bytes], PathLike] 47 | FileContent = Union[IO[bytes], bytes, PathLike] # PathLike is not subscriptable in Python 3.8. 
48 | FileTypes = Union[ 49 | # file (or bytes) 50 | FileContent, 51 | # (filename, file (or bytes)) 52 | Tuple[Optional[str], FileContent], 53 | # (filename, file (or bytes), content_type) 54 | Tuple[Optional[str], FileContent, Optional[str]], 55 | # (filename, file (or bytes), content_type, headers) 56 | Tuple[Optional[str], FileContent, Optional[str], Mapping[str, str]], 57 | ] 58 | RequestFiles = Union[Mapping[str, FileTypes], Sequence[Tuple[str, FileTypes]]] 59 | 60 | # duplicate of the above but without our custom file support 61 | HttpxFileContent = Union[IO[bytes], bytes] 62 | HttpxFileTypes = Union[ 63 | # file (or bytes) 64 | HttpxFileContent, 65 | # (filename, file (or bytes)) 66 | Tuple[Optional[str], HttpxFileContent], 67 | # (filename, file (or bytes), content_type) 68 | Tuple[Optional[str], HttpxFileContent, Optional[str]], 69 | # (filename, file (or bytes), content_type, headers) 70 | Tuple[Optional[str], HttpxFileContent, Optional[str], Mapping[str, str]], 71 | ] 72 | HttpxRequestFiles = Union[Mapping[str, HttpxFileTypes], Sequence[Tuple[str, HttpxFileTypes]]] 73 | 74 | # Workaround to support (cast_to: Type[ResponseT]) -> ResponseT 75 | # where ResponseT includes `None`. In order to support directly 76 | # passing `None`, overloads would have to be defined for every 77 | # method that uses `ResponseT` which would lead to an unacceptable 78 | # amount of code duplication and make it unreadable. See _base_client.py 79 | # for example usage. 
80 | # 81 | # This unfortunately means that you will either have 82 | # to import this type and pass it explicitly: 83 | # 84 | # from atla import NoneType 85 | # client.get('/foo', cast_to=NoneType) 86 | # 87 | # or build it yourself: 88 | # 89 | # client.get('/foo', cast_to=type(None)) 90 | if TYPE_CHECKING: 91 | NoneType: Type[None] 92 | else: 93 | NoneType = type(None) 94 | 95 | 96 | class RequestOptions(TypedDict, total=False): 97 | headers: Headers 98 | max_retries: int 99 | timeout: float | Timeout | None 100 | params: Query 101 | extra_json: AnyMapping 102 | idempotency_key: str 103 | 104 | 105 | # Sentinel class used until PEP 0661 is accepted 106 | class NotGiven: 107 | """ 108 | A sentinel singleton class used to distinguish omitted keyword arguments 109 | from those passed in with the value None (which may have different behavior). 110 | 111 | For example: 112 | 113 | ```py 114 | def get(timeout: Union[int, NotGiven, None] = NotGiven()) -> Response: ... 115 | 116 | 117 | get(timeout=1) # 1s timeout 118 | get(timeout=None) # No timeout 119 | get() # Default timeout behavior, which may not be statically known at the method definition. 
120 | ``` 121 | """ 122 | 123 | def __bool__(self) -> Literal[False]: 124 | return False 125 | 126 | @override 127 | def __repr__(self) -> str: 128 | return "NOT_GIVEN" 129 | 130 | 131 | NotGivenOr = Union[_T, NotGiven] 132 | NOT_GIVEN = NotGiven() 133 | 134 | 135 | class Omit: 136 | """In certain situations you need to be able to represent a case where a default value has 137 | to be explicitly removed and `None` is not an appropriate substitute, for example: 138 | 139 | ```py 140 | # as the default `Content-Type` header is `application/json` that will be sent 141 | client.post("/upload/files", files={"file": b"my raw file content"}) 142 | 143 | # you can't explicitly override the header as it has to be dynamically generated 144 | # to look something like: 'multipart/form-data; boundary=0d8382fcf5f8c3be01ca2e11002d2983' 145 | client.post(..., headers={"Content-Type": "multipart/form-data"}) 146 | 147 | # instead you can remove the default `application/json` header by passing Omit 148 | client.post(..., headers={"Content-Type": Omit()}) 149 | ``` 150 | """ 151 | 152 | def __bool__(self) -> Literal[False]: 153 | return False 154 | 155 | 156 | @runtime_checkable 157 | class ModelBuilderProtocol(Protocol): 158 | @classmethod 159 | def build( 160 | cls: type[_T], 161 | *, 162 | response: Response, 163 | data: object, 164 | ) -> _T: ... 165 | 166 | 167 | Headers = Mapping[str, Union[str, Omit]] 168 | 169 | 170 | class HeadersLikeProtocol(Protocol): 171 | def get(self, __key: str) -> str | None: ... 
172 | 173 | 174 | HeadersLike = Union[Headers, HeadersLikeProtocol] 175 | 176 | ResponseT = TypeVar( 177 | "ResponseT", 178 | bound=Union[ 179 | object, 180 | str, 181 | None, 182 | "BaseModel", 183 | List[Any], 184 | Dict[str, Any], 185 | Response, 186 | ModelBuilderProtocol, 187 | "APIResponse[Any]", 188 | "AsyncAPIResponse[Any]", 189 | ], 190 | ) 191 | 192 | StrBytesIntFloat = Union[str, bytes, int, float] 193 | 194 | # Note: copied from Pydantic 195 | # https://github.com/pydantic/pydantic/blob/6f31f8f68ef011f84357330186f603ff295312fd/pydantic/main.py#L79 196 | IncEx: TypeAlias = Union[Set[int], Set[str], Mapping[int, Union["IncEx", bool]], Mapping[str, Union["IncEx", bool]]] 197 | 198 | PostParser = Callable[[Any], Any] 199 | 200 | 201 | @runtime_checkable 202 | class InheritsGeneric(Protocol): 203 | """Represents a type that has inherited from `Generic` 204 | 205 | The `__orig_bases__` property can be used to determine the resolved 206 | type variable for a given base class. 207 | """ 208 | 209 | __orig_bases__: tuple[_GenericAlias] 210 | 211 | 212 | class _GenericAlias(Protocol): 213 | __origin__: type[object] 214 | 215 | 216 | class HttpxSendArgs(TypedDict, total=False): 217 | auth: httpx.Auth 218 | -------------------------------------------------------------------------------- /src/atla/_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from ._sync import asyncify as asyncify 2 | from ._proxy import LazyProxy as LazyProxy 3 | from ._utils import ( 4 | flatten as flatten, 5 | is_dict as is_dict, 6 | is_list as is_list, 7 | is_given as is_given, 8 | is_tuple as is_tuple, 9 | json_safe as json_safe, 10 | lru_cache as lru_cache, 11 | is_mapping as is_mapping, 12 | is_tuple_t as is_tuple_t, 13 | parse_date as parse_date, 14 | is_iterable as is_iterable, 15 | is_sequence as is_sequence, 16 | coerce_float as coerce_float, 17 | is_mapping_t as is_mapping_t, 18 | removeprefix as removeprefix, 19 | removesuffix 
as removesuffix, 20 | extract_files as extract_files, 21 | is_sequence_t as is_sequence_t, 22 | required_args as required_args, 23 | coerce_boolean as coerce_boolean, 24 | coerce_integer as coerce_integer, 25 | file_from_path as file_from_path, 26 | parse_datetime as parse_datetime, 27 | strip_not_given as strip_not_given, 28 | deepcopy_minimal as deepcopy_minimal, 29 | get_async_library as get_async_library, 30 | maybe_coerce_float as maybe_coerce_float, 31 | get_required_header as get_required_header, 32 | maybe_coerce_boolean as maybe_coerce_boolean, 33 | maybe_coerce_integer as maybe_coerce_integer, 34 | ) 35 | from ._typing import ( 36 | is_list_type as is_list_type, 37 | is_union_type as is_union_type, 38 | extract_type_arg as extract_type_arg, 39 | is_iterable_type as is_iterable_type, 40 | is_required_type as is_required_type, 41 | is_annotated_type as is_annotated_type, 42 | is_type_alias_type as is_type_alias_type, 43 | strip_annotated_type as strip_annotated_type, 44 | extract_type_var_from_base as extract_type_var_from_base, 45 | ) 46 | from ._streams import consume_sync_iterator as consume_sync_iterator, consume_async_iterator as consume_async_iterator 47 | from ._transform import ( 48 | PropertyInfo as PropertyInfo, 49 | transform as transform, 50 | async_transform as async_transform, 51 | maybe_transform as maybe_transform, 52 | async_maybe_transform as async_maybe_transform, 53 | ) 54 | from ._reflection import ( 55 | function_has_argument as function_has_argument, 56 | assert_signatures_in_sync as assert_signatures_in_sync, 57 | ) 58 | -------------------------------------------------------------------------------- /src/atla/_utils/_logs.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | 4 | logger: logging.Logger = logging.getLogger("atla") 5 | httpx_logger: logging.Logger = logging.getLogger("httpx") 6 | 7 | 8 | def _basic_config() -> None: 9 | # e.g. 
[2023-10-05 14:12:26 - atla._base_client:818 - DEBUG] HTTP Request: POST http://127.0.0.1:4010/foo/bar "200 OK" 10 | logging.basicConfig( 11 | format="[%(asctime)s - %(name)s:%(lineno)d - %(levelname)s] %(message)s", 12 | datefmt="%Y-%m-%d %H:%M:%S", 13 | ) 14 | 15 | 16 | def setup_logging() -> None: 17 | env = os.environ.get("ATLA_LOG") 18 | if env == "debug": 19 | _basic_config() 20 | logger.setLevel(logging.DEBUG) 21 | httpx_logger.setLevel(logging.DEBUG) 22 | elif env == "info": 23 | _basic_config() 24 | logger.setLevel(logging.INFO) 25 | httpx_logger.setLevel(logging.INFO) 26 | -------------------------------------------------------------------------------- /src/atla/_utils/_proxy.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from abc import ABC, abstractmethod 4 | from typing import Generic, TypeVar, Iterable, cast 5 | from typing_extensions import override 6 | 7 | T = TypeVar("T") 8 | 9 | 10 | class LazyProxy(Generic[T], ABC): 11 | """Implements data methods to pretend that an instance is another instance. 12 | 13 | This includes forwarding attribute access and other methods. 14 | """ 15 | 16 | # Note: we have to special case proxies that themselves return proxies 17 | # to support using a proxy as a catch-all for any random access, e.g. 
`proxy.foo.bar.baz` 18 | 19 | def __getattr__(self, attr: str) -> object: 20 | proxied = self.__get_proxied__() 21 | if isinstance(proxied, LazyProxy): 22 | return proxied # pyright: ignore 23 | return getattr(proxied, attr) 24 | 25 | @override 26 | def __repr__(self) -> str: 27 | proxied = self.__get_proxied__() 28 | if isinstance(proxied, LazyProxy): 29 | return proxied.__class__.__name__ 30 | return repr(self.__get_proxied__()) 31 | 32 | @override 33 | def __str__(self) -> str: 34 | proxied = self.__get_proxied__() 35 | if isinstance(proxied, LazyProxy): 36 | return proxied.__class__.__name__ 37 | return str(proxied) 38 | 39 | @override 40 | def __dir__(self) -> Iterable[str]: 41 | proxied = self.__get_proxied__() 42 | if isinstance(proxied, LazyProxy): 43 | return [] 44 | return proxied.__dir__() 45 | 46 | @property # type: ignore 47 | @override 48 | def __class__(self) -> type: # pyright: ignore 49 | proxied = self.__get_proxied__() 50 | if issubclass(type(proxied), LazyProxy): 51 | return type(proxied) 52 | return proxied.__class__ 53 | 54 | def __get_proxied__(self) -> T: 55 | return self.__load__() 56 | 57 | def __as_proxied__(self) -> T: 58 | """Helper method that returns the current proxy, typed as the loaded object""" 59 | return cast(T, self) 60 | 61 | @abstractmethod 62 | def __load__(self) -> T: ... 
63 | -------------------------------------------------------------------------------- /src/atla/_utils/_reflection.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import inspect 4 | from typing import Any, Callable 5 | 6 | 7 | def function_has_argument(func: Callable[..., Any], arg_name: str) -> bool: 8 | """Returns whether or not the given function has a specific parameter""" 9 | sig = inspect.signature(func) 10 | return arg_name in sig.parameters 11 | 12 | 13 | def assert_signatures_in_sync( 14 | source_func: Callable[..., Any], 15 | check_func: Callable[..., Any], 16 | *, 17 | exclude_params: set[str] = set(), 18 | ) -> None: 19 | """Ensure that the signature of the second function matches the first.""" 20 | 21 | check_sig = inspect.signature(check_func) 22 | source_sig = inspect.signature(source_func) 23 | 24 | errors: list[str] = [] 25 | 26 | for name, source_param in source_sig.parameters.items(): 27 | if name in exclude_params: 28 | continue 29 | 30 | custom_param = check_sig.parameters.get(name) 31 | if not custom_param: 32 | errors.append(f"the `{name}` param is missing") 33 | continue 34 | 35 | if custom_param.annotation != source_param.annotation: 36 | errors.append( 37 | f"types for the `{name}` param are do not match; source={repr(source_param.annotation)} checking={repr(custom_param.annotation)}" 38 | ) 39 | continue 40 | 41 | if errors: 42 | raise AssertionError(f"{len(errors)} errors encountered when comparing signatures:\n\n" + "\n\n".join(errors)) 43 | -------------------------------------------------------------------------------- /src/atla/_utils/_streams.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | from typing_extensions import Iterator, AsyncIterator 3 | 4 | 5 | def consume_sync_iterator(iterator: Iterator[Any]) -> None: 6 | for _ in iterator: 7 | ... 
8 | 9 | 10 | async def consume_async_iterator(iterator: AsyncIterator[Any]) -> None: 11 | async for _ in iterator: 12 | ... 13 | -------------------------------------------------------------------------------- /src/atla/_utils/_sync.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import sys 4 | import asyncio 5 | import functools 6 | import contextvars 7 | from typing import Any, TypeVar, Callable, Awaitable 8 | from typing_extensions import ParamSpec 9 | 10 | import anyio 11 | import sniffio 12 | import anyio.to_thread 13 | 14 | T_Retval = TypeVar("T_Retval") 15 | T_ParamSpec = ParamSpec("T_ParamSpec") 16 | 17 | 18 | if sys.version_info >= (3, 9): 19 | _asyncio_to_thread = asyncio.to_thread 20 | else: 21 | # backport of https://docs.python.org/3/library/asyncio-task.html#asyncio.to_thread 22 | # for Python 3.8 support 23 | async def _asyncio_to_thread( 24 | func: Callable[T_ParamSpec, T_Retval], /, *args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs 25 | ) -> Any: 26 | """Asynchronously run function *func* in a separate thread. 27 | 28 | Any *args and **kwargs supplied for this function are directly passed 29 | to *func*. Also, the current :class:`contextvars.Context` is propagated, 30 | allowing context variables from the main thread to be accessed in the 31 | separate thread. 32 | 33 | Returns a coroutine that can be awaited to get the eventual result of *func*. 
34 | """ 35 | loop = asyncio.events.get_running_loop() 36 | ctx = contextvars.copy_context() 37 | func_call = functools.partial(ctx.run, func, *args, **kwargs) 38 | return await loop.run_in_executor(None, func_call) 39 | 40 | 41 | async def to_thread( 42 | func: Callable[T_ParamSpec, T_Retval], /, *args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs 43 | ) -> T_Retval: 44 | if sniffio.current_async_library() == "asyncio": 45 | return await _asyncio_to_thread(func, *args, **kwargs) 46 | 47 | return await anyio.to_thread.run_sync( 48 | functools.partial(func, *args, **kwargs), 49 | ) 50 | 51 | 52 | # inspired by `asyncer`, https://github.com/tiangolo/asyncer 53 | def asyncify(function: Callable[T_ParamSpec, T_Retval]) -> Callable[T_ParamSpec, Awaitable[T_Retval]]: 54 | """ 55 | Take a blocking function and create an async one that receives the same 56 | positional and keyword arguments. For python version 3.9 and above, it uses 57 | asyncio.to_thread to run the function in a separate thread. For python version 58 | 3.8, it uses locally defined copy of the asyncio.to_thread function which was 59 | introduced in python 3.9. 60 | 61 | Usage: 62 | 63 | ```python 64 | def blocking_func(arg1, arg2, kwarg1=None): 65 | # blocking code 66 | return result 67 | 68 | 69 | result = asyncify(blocking_function)(arg1, arg2, kwarg1=value1) 70 | ``` 71 | 72 | ## Arguments 73 | 74 | `function`: a blocking regular callable (e.g. a function) 75 | 76 | ## Return 77 | 78 | An async function that takes the same positional and keyword arguments as the 79 | original one, that when called runs the same original function in a thread worker 80 | and returns the result. 
81 | """ 82 | 83 | async def wrapper(*args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs) -> T_Retval: 84 | return await to_thread(function, *args, **kwargs) 85 | 86 | return wrapper 87 | -------------------------------------------------------------------------------- /src/atla/_utils/_typing.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import sys 4 | import typing 5 | import typing_extensions 6 | from typing import Any, TypeVar, Iterable, cast 7 | from collections import abc as _c_abc 8 | from typing_extensions import ( 9 | TypeIs, 10 | Required, 11 | Annotated, 12 | get_args, 13 | get_origin, 14 | ) 15 | 16 | from ._utils import lru_cache 17 | from .._types import InheritsGeneric 18 | from .._compat import is_union as _is_union 19 | 20 | 21 | def is_annotated_type(typ: type) -> bool: 22 | return get_origin(typ) == Annotated 23 | 24 | 25 | def is_list_type(typ: type) -> bool: 26 | return (get_origin(typ) or typ) == list 27 | 28 | 29 | def is_iterable_type(typ: type) -> bool: 30 | """If the given type is `typing.Iterable[T]`""" 31 | origin = get_origin(typ) or typ 32 | return origin == Iterable or origin == _c_abc.Iterable 33 | 34 | 35 | def is_union_type(typ: type) -> bool: 36 | return _is_union(get_origin(typ)) 37 | 38 | 39 | def is_required_type(typ: type) -> bool: 40 | return get_origin(typ) == Required 41 | 42 | 43 | def is_typevar(typ: type) -> bool: 44 | # type ignore is required because type checkers 45 | # think this expression will always return False 46 | return type(typ) == TypeVar # type: ignore 47 | 48 | 49 | _TYPE_ALIAS_TYPES: tuple[type[typing_extensions.TypeAliasType], ...] 
= (typing_extensions.TypeAliasType,) 50 | if sys.version_info >= (3, 12): 51 | _TYPE_ALIAS_TYPES = (*_TYPE_ALIAS_TYPES, typing.TypeAliasType) 52 | 53 | 54 | def is_type_alias_type(tp: Any, /) -> TypeIs[typing_extensions.TypeAliasType]: 55 | """Return whether the provided argument is an instance of `TypeAliasType`. 56 | 57 | ```python 58 | type Int = int 59 | is_type_alias_type(Int) 60 | # > True 61 | Str = TypeAliasType("Str", str) 62 | is_type_alias_type(Str) 63 | # > True 64 | ``` 65 | """ 66 | return isinstance(tp, _TYPE_ALIAS_TYPES) 67 | 68 | 69 | # Extracts T from Annotated[T, ...] or from Required[Annotated[T, ...]] 70 | @lru_cache(maxsize=8096) 71 | def strip_annotated_type(typ: type) -> type: 72 | if is_required_type(typ) or is_annotated_type(typ): 73 | return strip_annotated_type(cast(type, get_args(typ)[0])) 74 | 75 | return typ 76 | 77 | 78 | def extract_type_arg(typ: type, index: int) -> type: 79 | args = get_args(typ) 80 | try: 81 | return cast(type, args[index]) 82 | except IndexError as err: 83 | raise RuntimeError(f"Expected type {typ} to have a type argument at index {index} but it did not") from err 84 | 85 | 86 | def extract_type_var_from_base( 87 | typ: type, 88 | *, 89 | generic_bases: tuple[type, ...], 90 | index: int, 91 | failure_message: str | None = None, 92 | ) -> type: 93 | """Given a type like `Foo[T]`, returns the generic type variable `T`. 94 | 95 | This also handles the case where a concrete subclass is given, e.g. 96 | ```py 97 | class MyResponse(Foo[bytes]): 98 | ... 99 | 100 | extract_type_var(MyResponse, bases=(Foo,), index=0) -> bytes 101 | ``` 102 | 103 | And where a generic subclass is given: 104 | ```py 105 | _T = TypeVar('_T') 106 | class MyResponse(Foo[_T]): 107 | ... 
108 | 109 | extract_type_var(MyResponse[bytes], bases=(Foo,), index=0) -> bytes 110 | ``` 111 | """ 112 | cls = cast(object, get_origin(typ) or typ) 113 | if cls in generic_bases: # pyright: ignore[reportUnnecessaryContains] 114 | # we're given the class directly 115 | return extract_type_arg(typ, index) 116 | 117 | # if a subclass is given 118 | # --- 119 | # this is needed as __orig_bases__ is not present in the typeshed stubs 120 | # because it is intended to be for internal use only, however there does 121 | # not seem to be a way to resolve generic TypeVars for inherited subclasses 122 | # without using it. 123 | if isinstance(cls, InheritsGeneric): 124 | target_base_class: Any | None = None 125 | for base in cls.__orig_bases__: 126 | if base.__origin__ in generic_bases: 127 | target_base_class = base 128 | break 129 | 130 | if target_base_class is None: 131 | raise RuntimeError( 132 | "Could not find the generic base class;\n" 133 | "This should never happen;\n" 134 | f"Does {cls} inherit from one of {generic_bases} ?" 135 | ) 136 | 137 | extracted = extract_type_arg(target_base_class, index) 138 | if is_typevar(extracted): 139 | # If the extracted type argument is itself a type variable 140 | # then that means the subclass itself is generic, so we have 141 | # to resolve the type argument from the class itself, not 142 | # the base class. 143 | # 144 | # Note: if there is more than 1 type argument, the subclass could 145 | # change the ordering of the type arguments, this is not currently 146 | # supported. 147 | return extract_type_arg(typ, index) 148 | 149 | return extracted 150 | 151 | raise RuntimeError(failure_message or f"Could not resolve inner type variable at index {index} for {typ}") 152 | -------------------------------------------------------------------------------- /src/atla/_version.py: -------------------------------------------------------------------------------- 1 | # File generated from our OpenAPI spec by Stainless. 
See CONTRIBUTING.md for details. 2 | 3 | __title__ = "atla" 4 | __version__ = "0.6.2" # x-release-please-version 5 | -------------------------------------------------------------------------------- /src/atla/lib/.keep: -------------------------------------------------------------------------------- 1 | File generated from our OpenAPI spec by Stainless. 2 | 3 | This directory can be used to store custom files to expand the SDK. 4 | It is ignored by Stainless code generation and its content (other than this keep file) won't be touched. -------------------------------------------------------------------------------- /src/atla/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/atla-ai/atla-sdk-python/6c8abda183cf41c9fa631358d9458d48b6bf05ce/src/atla/py.typed -------------------------------------------------------------------------------- /src/atla/resources/__init__.py: -------------------------------------------------------------------------------- 1 | # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
2 | 3 | from .chat import ( 4 | ChatResource, 5 | AsyncChatResource, 6 | ChatResourceWithRawResponse, 7 | AsyncChatResourceWithRawResponse, 8 | ChatResourceWithStreamingResponse, 9 | AsyncChatResourceWithStreamingResponse, 10 | ) 11 | from .metrics import ( 12 | MetricsResource, 13 | AsyncMetricsResource, 14 | MetricsResourceWithRawResponse, 15 | AsyncMetricsResourceWithRawResponse, 16 | MetricsResourceWithStreamingResponse, 17 | AsyncMetricsResourceWithStreamingResponse, 18 | ) 19 | from .evaluation import ( 20 | EvaluationResource, 21 | AsyncEvaluationResource, 22 | EvaluationResourceWithRawResponse, 23 | AsyncEvaluationResourceWithRawResponse, 24 | EvaluationResourceWithStreamingResponse, 25 | AsyncEvaluationResourceWithStreamingResponse, 26 | ) 27 | 28 | __all__ = [ 29 | "ChatResource", 30 | "AsyncChatResource", 31 | "ChatResourceWithRawResponse", 32 | "AsyncChatResourceWithRawResponse", 33 | "ChatResourceWithStreamingResponse", 34 | "AsyncChatResourceWithStreamingResponse", 35 | "EvaluationResource", 36 | "AsyncEvaluationResource", 37 | "EvaluationResourceWithRawResponse", 38 | "AsyncEvaluationResourceWithRawResponse", 39 | "EvaluationResourceWithStreamingResponse", 40 | "AsyncEvaluationResourceWithStreamingResponse", 41 | "MetricsResource", 42 | "AsyncMetricsResource", 43 | "MetricsResourceWithRawResponse", 44 | "AsyncMetricsResourceWithRawResponse", 45 | "MetricsResourceWithStreamingResponse", 46 | "AsyncMetricsResourceWithStreamingResponse", 47 | ] 48 | -------------------------------------------------------------------------------- /src/atla/resources/chat/__init__.py: -------------------------------------------------------------------------------- 1 | # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
2 | 3 | from .chat import ( 4 | ChatResource, 5 | AsyncChatResource, 6 | ChatResourceWithRawResponse, 7 | AsyncChatResourceWithRawResponse, 8 | ChatResourceWithStreamingResponse, 9 | AsyncChatResourceWithStreamingResponse, 10 | ) 11 | from .completions import ( 12 | CompletionsResource, 13 | AsyncCompletionsResource, 14 | CompletionsResourceWithRawResponse, 15 | AsyncCompletionsResourceWithRawResponse, 16 | CompletionsResourceWithStreamingResponse, 17 | AsyncCompletionsResourceWithStreamingResponse, 18 | ) 19 | 20 | __all__ = [ 21 | "CompletionsResource", 22 | "AsyncCompletionsResource", 23 | "CompletionsResourceWithRawResponse", 24 | "AsyncCompletionsResourceWithRawResponse", 25 | "CompletionsResourceWithStreamingResponse", 26 | "AsyncCompletionsResourceWithStreamingResponse", 27 | "ChatResource", 28 | "AsyncChatResource", 29 | "ChatResourceWithRawResponse", 30 | "AsyncChatResourceWithRawResponse", 31 | "ChatResourceWithStreamingResponse", 32 | "AsyncChatResourceWithStreamingResponse", 33 | ] 34 | -------------------------------------------------------------------------------- /src/atla/resources/chat/chat.py: -------------------------------------------------------------------------------- 1 | # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
from __future__ import annotations

from ..._compat import cached_property
from ..._resource import SyncAPIResource, AsyncAPIResource
from .completions import (
    CompletionsResource,
    AsyncCompletionsResource,
    CompletionsResourceWithRawResponse,
    AsyncCompletionsResourceWithRawResponse,
    CompletionsResourceWithStreamingResponse,
    AsyncCompletionsResourceWithStreamingResponse,
)

__all__ = ["ChatResource", "AsyncChatResource"]


class ChatResource(SyncAPIResource):
    """Synchronous `chat` namespace; its only sub-resource is `completions`."""

    @cached_property
    def completions(self) -> CompletionsResource:
        """Chat completions endpoints, sharing this resource's client."""
        client = self._client
        return CompletionsResource(client)

    @cached_property
    def with_raw_response(self) -> ChatResourceWithRawResponse:
        """
        Prefix any HTTP method call with this property to get the raw response
        object back instead of the parsed content.

        For more information, see https://www.github.com/atla-ai/atla-sdk-python#accessing-raw-response-data-eg-headers
        """
        raw_view = ChatResourceWithRawResponse(self)
        return raw_view

    @cached_property
    def with_streaming_response(self) -> ChatResourceWithStreamingResponse:
        """
        Like `.with_raw_response`, except that the response body is not read eagerly.

        For more information, see https://www.github.com/atla-ai/atla-sdk-python#with_streaming_response
        """
        streaming_view = ChatResourceWithStreamingResponse(self)
        return streaming_view


class AsyncChatResource(AsyncAPIResource):
    """Asynchronous `chat` namespace; its only sub-resource is `completions`."""

    @cached_property
    def completions(self) -> AsyncCompletionsResource:
        """Async chat completions endpoints, sharing this resource's client."""
        client = self._client
        return AsyncCompletionsResource(client)

    @cached_property
    def with_raw_response(self) -> AsyncChatResourceWithRawResponse:
        """
        Prefix any HTTP method call with this property to get the raw response
        object back instead of the parsed content.

        For more information, see https://www.github.com/atla-ai/atla-sdk-python#accessing-raw-response-data-eg-headers
        """
        raw_view = AsyncChatResourceWithRawResponse(self)
        return raw_view

    @cached_property
    def with_streaming_response(self) -> AsyncChatResourceWithStreamingResponse:
        """
        Like `.with_raw_response`, except that the response body is not read eagerly.

        For more information, see https://www.github.com/atla-ai/atla-sdk-python#with_streaming_response
        """
        streaming_view = AsyncChatResourceWithStreamingResponse(self)
        return streaming_view


class ChatResourceWithRawResponse:
    """View of a `ChatResource` whose sub-resources return raw responses."""

    def __init__(self, chat: ChatResource) -> None:
        self._chat = chat

    @cached_property
    def completions(self) -> CompletionsResourceWithRawResponse:
        """Raw-response view of `chat.completions`."""
        inner = self._chat.completions
        return CompletionsResourceWithRawResponse(inner)


class AsyncChatResourceWithRawResponse:
    """View of an `AsyncChatResource` whose sub-resources return raw responses."""

    def __init__(self, chat: AsyncChatResource) -> None:
        self._chat = chat

    @cached_property
    def completions(self) -> AsyncCompletionsResourceWithRawResponse:
        """Raw-response view of `chat.completions`."""
        inner = self._chat.completions
        return AsyncCompletionsResourceWithRawResponse(inner)


class ChatResourceWithStreamingResponse:
    """View of a `ChatResource` whose sub-resources return streamed responses."""

    def __init__(self, chat: ChatResource) -> None:
        self._chat = chat

    @cached_property
    def completions(self) -> CompletionsResourceWithStreamingResponse:
        """Streaming-response view of `chat.completions`."""
        inner = self._chat.completions
        return CompletionsResourceWithStreamingResponse(inner)


class AsyncChatResourceWithStreamingResponse:
    """View of an `AsyncChatResource` whose sub-resources return streamed responses."""

    def __init__(self, chat: AsyncChatResource) -> None:
        self._chat = chat

    @cached_property
    def completions(self) -> AsyncCompletionsResourceWithStreamingResponse:
        """Streaming-response view of `chat.completions`."""
        inner = self._chat.completions
        return AsyncCompletionsResourceWithStreamingResponse(inner)


# --------------------------------------------------------------------------------
# /src/atla/resources/metrics/__init__.py:
-------------------------------------------------------------------------------- 1 | # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 2 | 3 | from .metrics import ( 4 | MetricsResource, 5 | AsyncMetricsResource, 6 | MetricsResourceWithRawResponse, 7 | AsyncMetricsResourceWithRawResponse, 8 | MetricsResourceWithStreamingResponse, 9 | AsyncMetricsResourceWithStreamingResponse, 10 | ) 11 | from .prompts import ( 12 | PromptsResource, 13 | AsyncPromptsResource, 14 | PromptsResourceWithRawResponse, 15 | AsyncPromptsResourceWithRawResponse, 16 | PromptsResourceWithStreamingResponse, 17 | AsyncPromptsResourceWithStreamingResponse, 18 | ) 19 | from .few_shot_examples import ( 20 | FewShotExamplesResource, 21 | AsyncFewShotExamplesResource, 22 | FewShotExamplesResourceWithRawResponse, 23 | AsyncFewShotExamplesResourceWithRawResponse, 24 | FewShotExamplesResourceWithStreamingResponse, 25 | AsyncFewShotExamplesResourceWithStreamingResponse, 26 | ) 27 | 28 | __all__ = [ 29 | "PromptsResource", 30 | "AsyncPromptsResource", 31 | "PromptsResourceWithRawResponse", 32 | "AsyncPromptsResourceWithRawResponse", 33 | "PromptsResourceWithStreamingResponse", 34 | "AsyncPromptsResourceWithStreamingResponse", 35 | "FewShotExamplesResource", 36 | "AsyncFewShotExamplesResource", 37 | "FewShotExamplesResourceWithRawResponse", 38 | "AsyncFewShotExamplesResourceWithRawResponse", 39 | "FewShotExamplesResourceWithStreamingResponse", 40 | "AsyncFewShotExamplesResourceWithStreamingResponse", 41 | "MetricsResource", 42 | "AsyncMetricsResource", 43 | "MetricsResourceWithRawResponse", 44 | "AsyncMetricsResourceWithRawResponse", 45 | "MetricsResourceWithStreamingResponse", 46 | "AsyncMetricsResourceWithStreamingResponse", 47 | ] 48 | -------------------------------------------------------------------------------- /src/atla/resources/metrics/few_shot_examples.py: -------------------------------------------------------------------------------- 1 | # File generated 
from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 2 | 3 | from __future__ import annotations 4 | 5 | from typing import Iterable 6 | 7 | import httpx 8 | 9 | from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven 10 | from ..._utils import maybe_transform, async_maybe_transform 11 | from ..._compat import cached_property 12 | from ..._resource import SyncAPIResource, AsyncAPIResource 13 | from ..._response import ( 14 | to_raw_response_wrapper, 15 | to_streamed_response_wrapper, 16 | async_to_raw_response_wrapper, 17 | async_to_streamed_response_wrapper, 18 | ) 19 | from ..._base_client import make_request_options 20 | from ...types.metrics import few_shot_example_set_params 21 | from ...types.metrics.few_shot_example_param import FewShotExampleParam 22 | from ...types.metrics.few_shot_example_set_response import FewShotExampleSetResponse 23 | 24 | __all__ = ["FewShotExamplesResource", "AsyncFewShotExamplesResource"] 25 | 26 | 27 | class FewShotExamplesResource(SyncAPIResource): 28 | @cached_property 29 | def with_raw_response(self) -> FewShotExamplesResourceWithRawResponse: 30 | """ 31 | This property can be used as a prefix for any HTTP method call to return 32 | the raw response object instead of the parsed content. 33 | 34 | For more information, see https://www.github.com/atla-ai/atla-sdk-python#accessing-raw-response-data-eg-headers 35 | """ 36 | return FewShotExamplesResourceWithRawResponse(self) 37 | 38 | @cached_property 39 | def with_streaming_response(self) -> FewShotExamplesResourceWithStreamingResponse: 40 | """ 41 | An alternative to `.with_raw_response` that doesn't eagerly read the response body. 
42 | 43 | For more information, see https://www.github.com/atla-ai/atla-sdk-python#with_streaming_response 44 | """ 45 | return FewShotExamplesResourceWithStreamingResponse(self) 46 | 47 | def set( 48 | self, 49 | metric_id: str, 50 | *, 51 | few_shot_examples: Iterable[FewShotExampleParam], 52 | # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. 53 | # The extra values given here take precedence over values defined on the client or passed to this method. 54 | extra_headers: Headers | None = None, 55 | extra_query: Query | None = None, 56 | extra_body: Body | None = None, 57 | timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, 58 | ) -> FewShotExampleSetResponse: 59 | """ 60 | Set few-shot examples for a metric. 61 | 62 | Args: 63 | metric_id: The ID of the metric to set few-shot examples for. 64 | 65 | few_shot_examples: The few-shot examples to upsert. 66 | 67 | extra_headers: Send extra headers 68 | 69 | extra_query: Add additional query parameters to the request 70 | 71 | extra_body: Add additional JSON properties to the request 72 | 73 | timeout: Override the client-level default timeout for this request, in seconds 74 | """ 75 | if not metric_id: 76 | raise ValueError(f"Expected a non-empty value for `metric_id` but received {metric_id!r}") 77 | return self._put( 78 | f"/v1/metrics/{metric_id}/few_shot_examples", 79 | body=maybe_transform( 80 | {"few_shot_examples": few_shot_examples}, few_shot_example_set_params.FewShotExampleSetParams 81 | ), 82 | options=make_request_options( 83 | extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout 84 | ), 85 | cast_to=FewShotExampleSetResponse, 86 | ) 87 | 88 | 89 | class AsyncFewShotExamplesResource(AsyncAPIResource): 90 | @cached_property 91 | def with_raw_response(self) -> AsyncFewShotExamplesResourceWithRawResponse: 92 | """ 93 | This property can be used as a prefix for any HTTP method call to return 
94 | the raw response object instead of the parsed content. 95 | 96 | For more information, see https://www.github.com/atla-ai/atla-sdk-python#accessing-raw-response-data-eg-headers 97 | """ 98 | return AsyncFewShotExamplesResourceWithRawResponse(self) 99 | 100 | @cached_property 101 | def with_streaming_response(self) -> AsyncFewShotExamplesResourceWithStreamingResponse: 102 | """ 103 | An alternative to `.with_raw_response` that doesn't eagerly read the response body. 104 | 105 | For more information, see https://www.github.com/atla-ai/atla-sdk-python#with_streaming_response 106 | """ 107 | return AsyncFewShotExamplesResourceWithStreamingResponse(self) 108 | 109 | async def set( 110 | self, 111 | metric_id: str, 112 | *, 113 | few_shot_examples: Iterable[FewShotExampleParam], 114 | # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. 115 | # The extra values given here take precedence over values defined on the client or passed to this method. 116 | extra_headers: Headers | None = None, 117 | extra_query: Query | None = None, 118 | extra_body: Body | None = None, 119 | timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, 120 | ) -> FewShotExampleSetResponse: 121 | """ 122 | Set few-shot examples for a metric. 123 | 124 | Args: 125 | metric_id: The ID of the metric to set few-shot examples for. 126 | 127 | few_shot_examples: The few-shot examples to upsert. 
128 | 129 | extra_headers: Send extra headers 130 | 131 | extra_query: Add additional query parameters to the request 132 | 133 | extra_body: Add additional JSON properties to the request 134 | 135 | timeout: Override the client-level default timeout for this request, in seconds 136 | """ 137 | if not metric_id: 138 | raise ValueError(f"Expected a non-empty value for `metric_id` but received {metric_id!r}") 139 | return await self._put( 140 | f"/v1/metrics/{metric_id}/few_shot_examples", 141 | body=await async_maybe_transform( 142 | {"few_shot_examples": few_shot_examples}, few_shot_example_set_params.FewShotExampleSetParams 143 | ), 144 | options=make_request_options( 145 | extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout 146 | ), 147 | cast_to=FewShotExampleSetResponse, 148 | ) 149 | 150 | 151 | class FewShotExamplesResourceWithRawResponse: 152 | def __init__(self, few_shot_examples: FewShotExamplesResource) -> None: 153 | self._few_shot_examples = few_shot_examples 154 | 155 | self.set = to_raw_response_wrapper( 156 | few_shot_examples.set, 157 | ) 158 | 159 | 160 | class AsyncFewShotExamplesResourceWithRawResponse: 161 | def __init__(self, few_shot_examples: AsyncFewShotExamplesResource) -> None: 162 | self._few_shot_examples = few_shot_examples 163 | 164 | self.set = async_to_raw_response_wrapper( 165 | few_shot_examples.set, 166 | ) 167 | 168 | 169 | class FewShotExamplesResourceWithStreamingResponse: 170 | def __init__(self, few_shot_examples: FewShotExamplesResource) -> None: 171 | self._few_shot_examples = few_shot_examples 172 | 173 | self.set = to_streamed_response_wrapper( 174 | few_shot_examples.set, 175 | ) 176 | 177 | 178 | class AsyncFewShotExamplesResourceWithStreamingResponse: 179 | def __init__(self, few_shot_examples: AsyncFewShotExamplesResource) -> None: 180 | self._few_shot_examples = few_shot_examples 181 | 182 | self.set = async_to_streamed_response_wrapper( 183 | few_shot_examples.set, 184 | ) 
185 | -------------------------------------------------------------------------------- /src/atla/types/__init__.py: -------------------------------------------------------------------------------- 1 | # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 2 | 3 | from __future__ import annotations 4 | 5 | from .metric import Metric as Metric 6 | from .evaluation import Evaluation as Evaluation 7 | from .chat_completion import ChatCompletion as ChatCompletion 8 | from .metric_list_params import MetricListParams as MetricListParams 9 | from .metric_get_response import MetricGetResponse as MetricGetResponse 10 | from .metric_create_params import MetricCreateParams as MetricCreateParams 11 | from .metric_list_response import MetricListResponse as MetricListResponse 12 | from .metric_create_response import MetricCreateResponse as MetricCreateResponse 13 | from .metric_delete_response import MetricDeleteResponse as MetricDeleteResponse 14 | from .evaluation_create_params import EvaluationCreateParams as EvaluationCreateParams 15 | -------------------------------------------------------------------------------- /src/atla/types/chat/__init__.py: -------------------------------------------------------------------------------- 1 | # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 2 | 3 | from __future__ import annotations 4 | 5 | from .completion_create_params import CompletionCreateParams as CompletionCreateParams 6 | -------------------------------------------------------------------------------- /src/atla/types/chat/completion_create_params.py: -------------------------------------------------------------------------------- 1 | # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
from __future__ import annotations

from typing import Union, Iterable, Optional
from typing_extensions import Literal, Required, TypeAlias, TypedDict

__all__ = [
    "CompletionCreateParams",
    "Message",
    "MessageChatCompletionDeveloperMessageParam",
    "MessageChatCompletionSystemMessageParam",
    "MessageChatCompletionUserMessageParam",
    "MessageChatCompletionAssistantMessageParam",
    "MessageChatCompletionAssistantMessageParamAudio",
    "MessageChatCompletionAssistantMessageParamFunctionCall",
    "MessageChatCompletionAssistantMessageParamToolCall",
    "MessageChatCompletionAssistantMessageParamToolCallFunction",
    "MessageChatCompletionToolMessageParam",
    "MessageChatCompletionFunctionMessageParam",
]


class CompletionCreateParams(TypedDict, total=False):
    """Request parameters for creating a chat completion with an Atla evaluator model."""

    messages: Required[Iterable[Message]]
    """A list of messages comprising the conversation so far.

    See the
    [OpenAI API reference](https://platform.openai.com/docs/api-reference/chat/create)
    for more information.
    """

    model: Required[str]
    """The ID or name of the Atla evaluator model to use.

    This may point to a specific model version or a model family. If a model family
    is provided, the default model version for that family will be used.
    """

    max_completion_tokens: Optional[int]
    """An upper bound for the number of tokens that can be generated for an evaluation.

    See the
    [OpenAI API reference](https://platform.openai.com/docs/api-reference/chat/create)
    for more information.
    """

    max_tokens: Optional[int]
    """The maximum number of tokens that can be generated in the evaluation.

    This value is now deprecated in favor of `max_completion_tokens`. See the
    [OpenAI API reference](https://platform.openai.com/docs/api-reference/chat/create)
    for more information.
    """

    temperature: Optional[float]
    """What sampling temperature to use, between 0 and 2.

    See the
    [OpenAI API reference](https://platform.openai.com/docs/api-reference/chat/create)
    for more information.
    """

    top_p: Optional[float]
    """
    An alternative to sampling with temperature, called nucleus sampling, where the
    model considers the results of the tokens with top_p probability mass. See the
    [OpenAI API reference](https://platform.openai.com/docs/api-reference/chat/create)
    for more information.
    """


class MessageChatCompletionDeveloperMessageParam(TypedDict, total=False):
    """A message with role `developer`."""

    content: Required[str]

    role: Required[Literal["developer"]]

    name: str


class MessageChatCompletionSystemMessageParam(TypedDict, total=False):
    """A message with role `system`."""

    content: Required[str]

    role: Required[Literal["system"]]

    name: str


class MessageChatCompletionUserMessageParam(TypedDict, total=False):
    """A message with role `user`."""

    content: Required[str]

    role: Required[Literal["user"]]

    name: str


class MessageChatCompletionAssistantMessageParamAudio(TypedDict, total=False):
    id: Required[str]


class MessageChatCompletionAssistantMessageParamFunctionCall(TypedDict, total=False):
    arguments: Required[str]

    name: Required[str]


class MessageChatCompletionAssistantMessageParamToolCallFunction(TypedDict, total=False):
    arguments: Required[str]

    name: Required[str]


class MessageChatCompletionAssistantMessageParamToolCall(TypedDict, total=False):
    id: Required[str]

    function: Required[MessageChatCompletionAssistantMessageParamToolCallFunction]

    type: Required[Literal["function"]]


class MessageChatCompletionAssistantMessageParam(TypedDict, total=False):
    """A message with role `assistant`; only `role` is required."""

    role: Required[Literal["assistant"]]

    audio: Optional[MessageChatCompletionAssistantMessageParamAudio]

    content: Optional[str]

    function_call: Optional[MessageChatCompletionAssistantMessageParamFunctionCall]

    name: str

    refusal: Optional[str]

    tool_calls: Iterable[MessageChatCompletionAssistantMessageParamToolCall]


class MessageChatCompletionToolMessageParam(TypedDict, total=False):
    """A message with role `tool`, answering the tool call identified by `tool_call_id`."""

    content: Required[str]

    role: Required[Literal["tool"]]

    tool_call_id: Required[str]


class MessageChatCompletionFunctionMessageParam(TypedDict, total=False):
    """A message with role `function`."""

    content: Required[Optional[str]]

    name: Required[str]

    role: Required[Literal["function"]]


# Any single entry of `CompletionCreateParams.messages`.
Message: TypeAlias = Union[
    MessageChatCompletionDeveloperMessageParam,
    MessageChatCompletionSystemMessageParam,
    MessageChatCompletionUserMessageParam,
    MessageChatCompletionAssistantMessageParam,
    MessageChatCompletionToolMessageParam,
    MessageChatCompletionFunctionMessageParam,
]


# --------------------------------------------------------------------------------
# /src/atla/types/chat_completion.py:
# --------------------------------------------------------------------------------
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
2 | 3 | import builtins 4 | from typing import TYPE_CHECKING, List, Optional 5 | from typing_extensions import Literal 6 | 7 | from .._models import BaseModel 8 | 9 | __all__ = [ 10 | "ChatCompletion", 11 | "Choice", 12 | "ChoiceMessage", 13 | "ChoiceMessageAudio", 14 | "ChoiceMessageFunctionCall", 15 | "ChoiceMessageToolCall", 16 | "ChoiceMessageToolCallFunction", 17 | "ChoiceLogprobs", 18 | "ChoiceLogprobsContent", 19 | "ChoiceLogprobsContentTopLogprob", 20 | "ChoiceLogprobsRefusal", 21 | "ChoiceLogprobsRefusalTopLogprob", 22 | "Usage", 23 | "UsageCompletionTokensDetails", 24 | "UsagePromptTokensDetails", 25 | ] 26 | 27 | 28 | class ChoiceMessageAudio(BaseModel): 29 | id: str 30 | 31 | data: str 32 | 33 | expires_at: int 34 | 35 | transcript: str 36 | 37 | if TYPE_CHECKING: 38 | # Stub to indicate that arbitrary properties are accepted. 39 | # To access properties that are not valid identifiers you can use `getattr`, e.g. 40 | # `getattr(obj, '$type')` 41 | def __getattr__(self, attr: str) -> object: ... 42 | 43 | 44 | class ChoiceMessageFunctionCall(BaseModel): 45 | arguments: str 46 | 47 | name: str 48 | 49 | if TYPE_CHECKING: 50 | # Stub to indicate that arbitrary properties are accepted. 51 | # To access properties that are not valid identifiers you can use `getattr`, e.g. 52 | # `getattr(obj, '$type')` 53 | def __getattr__(self, attr: str) -> object: ... 54 | 55 | 56 | class ChoiceMessageToolCallFunction(BaseModel): 57 | arguments: str 58 | 59 | name: str 60 | 61 | if TYPE_CHECKING: 62 | # Stub to indicate that arbitrary properties are accepted. 63 | # To access properties that are not valid identifiers you can use `getattr`, e.g. 64 | # `getattr(obj, '$type')` 65 | def __getattr__(self, attr: str) -> object: ... 66 | 67 | 68 | class ChoiceMessageToolCall(BaseModel): 69 | id: str 70 | 71 | function: ChoiceMessageToolCallFunction 72 | 73 | type: Literal["function"] 74 | 75 | if TYPE_CHECKING: 76 | # Stub to indicate that arbitrary properties are accepted. 
77 | # To access properties that are not valid identifiers you can use `getattr`, e.g. 78 | # `getattr(obj, '$type')` 79 | def __getattr__(self, attr: str) -> object: ... 80 | 81 | 82 | class ChoiceMessage(BaseModel): 83 | role: Literal["assistant"] 84 | 85 | audio: Optional[ChoiceMessageAudio] = None 86 | 87 | content: Optional[str] = None 88 | 89 | function_call: Optional[ChoiceMessageFunctionCall] = None 90 | 91 | refusal: Optional[str] = None 92 | 93 | tool_calls: Optional[List[ChoiceMessageToolCall]] = None 94 | 95 | if TYPE_CHECKING: 96 | # Stub to indicate that arbitrary properties are accepted. 97 | # To access properties that are not valid identifiers you can use `getattr`, e.g. 98 | # `getattr(obj, '$type')` 99 | def __getattr__(self, attr: str) -> object: ... 100 | 101 | 102 | class ChoiceLogprobsContentTopLogprob(BaseModel): 103 | token: str 104 | 105 | logprob: float 106 | 107 | bytes: Optional[List[int]] = None 108 | 109 | if TYPE_CHECKING: 110 | # Stub to indicate that arbitrary properties are accepted. 111 | # To access properties that are not valid identifiers you can use `getattr`, e.g. 112 | # `getattr(obj, '$type')` 113 | def __getattr__(self, attr: str) -> object: ... 114 | 115 | 116 | class ChoiceLogprobsContent(BaseModel): 117 | token: str 118 | 119 | logprob: float 120 | 121 | top_logprobs: List[ChoiceLogprobsContentTopLogprob] 122 | 123 | bytes: Optional[List[int]] = None 124 | 125 | if TYPE_CHECKING: 126 | # Stub to indicate that arbitrary properties are accepted. 127 | # To access properties that are not valid identifiers you can use `getattr`, e.g. 128 | # `getattr(obj, '$type')` 129 | def __getattr__(self, attr: str) -> object: ... 130 | 131 | 132 | class ChoiceLogprobsRefusalTopLogprob(BaseModel): 133 | token: str 134 | 135 | logprob: float 136 | 137 | bytes: Optional[List[int]] = None 138 | 139 | if TYPE_CHECKING: 140 | # Stub to indicate that arbitrary properties are accepted. 
141 | # To access properties that are not valid identifiers you can use `getattr`, e.g. 142 | # `getattr(obj, '$type')` 143 | def __getattr__(self, attr: str) -> object: ... 144 | 145 | 146 | class ChoiceLogprobsRefusal(BaseModel): 147 | token: str 148 | 149 | logprob: float 150 | 151 | top_logprobs: List[ChoiceLogprobsRefusalTopLogprob] 152 | 153 | bytes: Optional[List[int]] = None 154 | 155 | if TYPE_CHECKING: 156 | # Stub to indicate that arbitrary properties are accepted. 157 | # To access properties that are not valid identifiers you can use `getattr`, e.g. 158 | # `getattr(obj, '$type')` 159 | def __getattr__(self, attr: str) -> object: ... 160 | 161 | 162 | class ChoiceLogprobs(BaseModel): 163 | content: Optional[List[ChoiceLogprobsContent]] = None 164 | 165 | refusal: Optional[List[ChoiceLogprobsRefusal]] = None 166 | 167 | if TYPE_CHECKING: 168 | # Stub to indicate that arbitrary properties are accepted. 169 | # To access properties that are not valid identifiers you can use `getattr`, e.g. 170 | # `getattr(obj, '$type')` 171 | def __getattr__(self, attr: str) -> object: ... 172 | 173 | 174 | class Choice(BaseModel): 175 | finish_reason: Literal["stop", "length", "tool_calls", "content_filter", "function_call"] 176 | 177 | index: int 178 | 179 | message: ChoiceMessage 180 | 181 | logprobs: Optional[ChoiceLogprobs] = None 182 | 183 | if TYPE_CHECKING: 184 | # Stub to indicate that arbitrary properties are accepted. 185 | # To access properties that are not valid identifiers you can use `getattr`, e.g. 186 | # `getattr(obj, '$type')` 187 | def __getattr__(self, attr: str) -> object: ... 188 | 189 | 190 | class UsageCompletionTokensDetails(BaseModel): 191 | accepted_prediction_tokens: Optional[int] = None 192 | 193 | audio_tokens: Optional[int] = None 194 | 195 | reasoning_tokens: Optional[int] = None 196 | 197 | rejected_prediction_tokens: Optional[int] = None 198 | 199 | if TYPE_CHECKING: 200 | # Stub to indicate that arbitrary properties are accepted. 
201 | # To access properties that are not valid identifiers you can use `getattr`, e.g. 202 | # `getattr(obj, '$type')` 203 | def __getattr__(self, attr: str) -> object: ... 204 | 205 | 206 | class UsagePromptTokensDetails(BaseModel): 207 | audio_tokens: Optional[int] = None 208 | 209 | cached_tokens: Optional[int] = None 210 | 211 | if TYPE_CHECKING: 212 | # Stub to indicate that arbitrary properties are accepted. 213 | # To access properties that are not valid identifiers you can use `getattr`, e.g. 214 | # `getattr(obj, '$type')` 215 | def __getattr__(self, attr: str) -> object: ... 216 | 217 | 218 | class Usage(BaseModel): 219 | completion_tokens: int 220 | 221 | prompt_tokens: int 222 | 223 | total_tokens: int 224 | 225 | completion_tokens_details: Optional[UsageCompletionTokensDetails] = None 226 | 227 | prompt_tokens_details: Optional[UsagePromptTokensDetails] = None 228 | 229 | if TYPE_CHECKING: 230 | # Stub to indicate that arbitrary properties are accepted. 231 | # To access properties that are not valid identifiers you can use `getattr`, e.g. 232 | # `getattr(obj, '$type')` 233 | def __getattr__(self, attr: str) -> object: ... 234 | 235 | 236 | class ChatCompletion(BaseModel): 237 | id: str 238 | 239 | choices: List[Choice] 240 | 241 | created: int 242 | 243 | model: str 244 | 245 | object: Literal["chat.completion"] 246 | 247 | service_tier: Optional[Literal["scale", "default"]] = None 248 | 249 | system_fingerprint: Optional[str] = None 250 | 251 | usage: Optional[Usage] = None 252 | 253 | if TYPE_CHECKING: 254 | # Stub to indicate that arbitrary properties are accepted. 255 | # To access properties that are not valid identifiers you can use `getattr`, e.g. 256 | # `getattr(obj, '$type')` 257 | def __getattr__(self, attr: str) -> builtins.object: ... 
258 | -------------------------------------------------------------------------------- /src/atla/types/evaluation.py: -------------------------------------------------------------------------------- 1 | # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 2 | 3 | from typing import Optional 4 | from typing_extensions import Literal 5 | 6 | from .._compat import PYDANTIC_V2, ConfigDict 7 | from .._models import BaseModel 8 | 9 | __all__ = ["Evaluation", "Result", "ResultEvaluation"] 10 | 11 | 12 | class ResultEvaluation(BaseModel): 13 | critique: str 14 | """The critique of the evaluation.""" 15 | 16 | score: str 17 | """A value representing the evaluation result.""" 18 | 19 | 20 | class Result(BaseModel): 21 | evaluation: ResultEvaluation 22 | """The evaluation results.""" 23 | 24 | model_id: str 25 | """The ID of the Atla evaluator model used.""" 26 | 27 | if PYDANTIC_V2: 28 | # allow fields with a `model_` prefix 29 | model_config = ConfigDict(protected_namespaces=tuple()) 30 | 31 | 32 | class Evaluation(BaseModel): 33 | request_id: str 34 | """The ID of the request the response is for.""" 35 | 36 | result: Result 37 | """The result of the evaluation.""" 38 | 39 | status: Optional[Literal["success"]] = None 40 | -------------------------------------------------------------------------------- /src/atla/types/evaluation_create_params.py: -------------------------------------------------------------------------------- 1 | # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 2 | 3 | from __future__ import annotations 4 | 5 | from typing import Iterable, Optional 6 | from typing_extensions import Required, TypedDict 7 | 8 | from .metrics.few_shot_example_param import FewShotExampleParam 9 | 10 | __all__ = ["EvaluationCreateParams"] 11 | 12 | 13 | class EvaluationCreateParams(TypedDict, total=False): 14 | model_id: Required[str] 15 | """The ID or name of the Atla evaluator model to use. 
16 | 17 | This may point to a specific model version or a model family. If a model family 18 | is provided, the default model version for that family will be used. 19 | """ 20 | 21 | model_input: Required[str] 22 | """The input given to a model which produced the `model_output` to be evaluated.""" 23 | 24 | model_output: Required[str] 25 | """The output of the model which is being evaluated. 26 | 27 | This is the `model_output` from the `model_input`. 28 | """ 29 | 30 | evaluation_criteria: Optional[str] 31 | """The criteria used to evaluate the `model_output`. 32 | 33 | Only one of `evaluation_criteria` or `metric_name` can be provided. 34 | """ 35 | 36 | expected_model_output: Optional[str] 37 | """ 38 | An optional reference ("ground-truth" / "gold standard") answer against which to 39 | evaluate the `model_output`. 40 | """ 41 | 42 | few_shot_examples: Iterable[FewShotExampleParam] 43 | """A list of few-shot examples for the evaluation.""" 44 | 45 | metric_name: Optional[str] 46 | """The name of the metric to use for the evaluation. 47 | 48 | Only one of `evaluation_criteria` or `metric_name` can be provided. 49 | """ 50 | 51 | model_context: Optional[str] 52 | """ 53 | Any additional context provided to the model which received the `model_input` 54 | and produced the `model_output`. 55 | """ 56 | 57 | prompt_version: Optional[int] 58 | """The version of the prompt to use for the evaluation. 59 | 60 | If not provided, the active prompt version will be used. 61 | """ 62 | -------------------------------------------------------------------------------- /src/atla/types/metric.py: -------------------------------------------------------------------------------- 1 | # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
2 | 3 | from typing import Dict, List, Optional 4 | from datetime import datetime 5 | from typing_extensions import Literal 6 | 7 | from pydantic import Field as FieldInfo 8 | 9 | from .._models import BaseModel 10 | from .metrics.prompt import Prompt 11 | from .metrics.few_shot_example import FewShotExample 12 | 13 | __all__ = ["Metric"] 14 | 15 | 16 | class Metric(BaseModel): 17 | metric_type: Literal["binary", "likert_1_to_5"] 18 | """The type of metric.""" 19 | 20 | name: str 21 | """The name of the metric. 22 | 23 | Metric names must contain only lowercase letters, numbers, hyphens, or 24 | underscores, and must start with a lowercase letter and end with either a 25 | lowercase letter or number. Metric names must be unique within a project. 26 | """ 27 | 28 | project_id: Optional[str] = None 29 | """The ID of the project that the metric belongs to. 30 | 31 | If the metric is shared, this field will be `null`. 32 | """ 33 | 34 | api_id: Optional[str] = FieldInfo(alias="_id", default=None) 35 | """The ID of the metric in the database.""" 36 | 37 | active_prompt_version: Optional[int] = None 38 | """The version of the prompt that is currently active for the metric.""" 39 | 40 | created_at: Optional[datetime] = None 41 | """The creation time of the metric.""" 42 | 43 | description: Optional[str] = None 44 | """An optional description of the metric.""" 45 | 46 | few_shot_examples: Optional[List[FewShotExample]] = None 47 | """The few-shot examples for the metric. At most 3 examples are allowed.""" 48 | 49 | prompts: Optional[Dict[str, Prompt]] = None 50 | """The prompts for the metric, keyed by version.""" 51 | 52 | required_fields: Optional[ 53 | List[Literal["model_input", "model_output", "model_context", "expected_model_output"]] 54 | ] = None 55 | """The fields that are required for the metric. 56 | 57 | All metrics must require at least `model_input` and `model_output`, which are 58 | the default values. 
59 | """ 60 | 61 | updated_at: Optional[datetime] = None 62 | """The last update time of the metric.""" 63 | -------------------------------------------------------------------------------- /src/atla/types/metric_create_params.py: -------------------------------------------------------------------------------- 1 | # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 2 | 3 | from __future__ import annotations 4 | 5 | from typing import List, Optional 6 | from typing_extensions import Literal, Required, TypedDict 7 | 8 | __all__ = ["MetricCreateParams"] 9 | 10 | 11 | class MetricCreateParams(TypedDict, total=False): 12 | metric_type: Required[Literal["binary", "likert_1_to_5"]] 13 | """The type of metric.""" 14 | 15 | name: Required[str] 16 | """The name of the metric. 17 | 18 | Metric names must contain only lowercase letters, numbers, hyphens, or 19 | underscores, and must start with a lowercase letter and end with either a 20 | lowercase letter or number. Metric names must be unique within a project. 21 | """ 22 | 23 | description: Optional[str] 24 | """An optional description of the metric.""" 25 | 26 | required_fields: List[Literal["model_input", "model_output", "model_context", "expected_model_output"]] 27 | """The fields that are required for the metric. 28 | 29 | All metrics must require at least `model_input` and `model_output`, which are 30 | the default values. 31 | """ 32 | -------------------------------------------------------------------------------- /src/atla/types/metric_create_response.py: -------------------------------------------------------------------------------- 1 | # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
2 | 3 | from typing import Optional 4 | from typing_extensions import Literal 5 | 6 | from .._models import BaseModel 7 | 8 | __all__ = ["MetricCreateResponse"] 9 | 10 | 11 | class MetricCreateResponse(BaseModel): 12 | metric_id: str 13 | """The ID of the created metric.""" 14 | 15 | request_id: str 16 | """The ID of the request the response is for.""" 17 | 18 | status: Optional[Literal["success"]] = None 19 | -------------------------------------------------------------------------------- /src/atla/types/metric_delete_response.py: -------------------------------------------------------------------------------- 1 | # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 2 | 3 | from typing import Optional 4 | from typing_extensions import Literal 5 | 6 | from .._models import BaseModel 7 | 8 | __all__ = ["MetricDeleteResponse"] 9 | 10 | 11 | class MetricDeleteResponse(BaseModel): 12 | request_id: str 13 | """The ID of the request the response is for.""" 14 | 15 | status: Optional[Literal["success"]] = None 16 | -------------------------------------------------------------------------------- /src/atla/types/metric_get_response.py: -------------------------------------------------------------------------------- 1 | # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
2 | 3 | from typing import Optional 4 | from typing_extensions import Literal 5 | 6 | from .metric import Metric 7 | from .._models import BaseModel 8 | 9 | __all__ = ["MetricGetResponse"] 10 | 11 | 12 | class MetricGetResponse(BaseModel): 13 | metric: Metric 14 | """The metric retrieved.""" 15 | 16 | request_id: str 17 | """The ID of the request the response is for.""" 18 | 19 | status: Optional[Literal["success"]] = None 20 | -------------------------------------------------------------------------------- /src/atla/types/metric_list_params.py: -------------------------------------------------------------------------------- 1 | # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 2 | 3 | from __future__ import annotations 4 | 5 | from typing_extensions import TypedDict 6 | 7 | __all__ = ["MetricListParams"] 8 | 9 | 10 | class MetricListParams(TypedDict, total=False): 11 | include_default: bool 12 | """Whether to include default metrics.""" 13 | -------------------------------------------------------------------------------- /src/atla/types/metric_list_response.py: -------------------------------------------------------------------------------- 1 | # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 2 | 3 | from typing import List, Optional 4 | from typing_extensions import Literal 5 | 6 | from .metric import Metric 7 | from .._models import BaseModel 8 | 9 | __all__ = ["MetricListResponse"] 10 | 11 | 12 | class MetricListResponse(BaseModel): 13 | metrics: List[Metric] 14 | """The metrics retrieved.""" 15 | 16 | request_id: str 17 | """The ID of the request the response is for.""" 18 | 19 | status: Optional[Literal["success"]] = None 20 | -------------------------------------------------------------------------------- /src/atla/types/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # File generated from our OpenAPI spec by Stainless. 
See CONTRIBUTING.md for details. 2 | 3 | from __future__ import annotations 4 | 5 | from .prompt import Prompt as Prompt 6 | from .few_shot_example import FewShotExample as FewShotExample 7 | from .prompt_get_response import PromptGetResponse as PromptGetResponse 8 | from .prompt_create_params import PromptCreateParams as PromptCreateParams 9 | from .prompt_list_response import PromptListResponse as PromptListResponse 10 | from .few_shot_example_param import FewShotExampleParam as FewShotExampleParam 11 | from .prompt_create_response import PromptCreateResponse as PromptCreateResponse 12 | from .few_shot_example_set_params import FewShotExampleSetParams as FewShotExampleSetParams 13 | from .few_shot_example_set_response import FewShotExampleSetResponse as FewShotExampleSetResponse 14 | from .prompt_set_active_prompt_version_params import ( 15 | PromptSetActivePromptVersionParams as PromptSetActivePromptVersionParams, 16 | ) 17 | from .prompt_set_active_prompt_version_response import ( 18 | PromptSetActivePromptVersionResponse as PromptSetActivePromptVersionResponse, 19 | ) 20 | -------------------------------------------------------------------------------- /src/atla/types/metrics/few_shot_example.py: -------------------------------------------------------------------------------- 1 | # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
2 | 3 | from typing import Optional 4 | 5 | from ..._compat import PYDANTIC_V2, ConfigDict 6 | from ..._models import BaseModel 7 | 8 | __all__ = ["FewShotExample"] 9 | 10 | 11 | class FewShotExample(BaseModel): 12 | model_input: str 13 | """The input to the model for the few-shot example.""" 14 | 15 | model_output: str 16 | """The output from the model for the few-shot example.""" 17 | 18 | score: str 19 | """The score for the few-shot example.""" 20 | 21 | critique: Optional[str] = None 22 | """The critique for the few-shot example.""" 23 | 24 | expected_model_output: Optional[str] = None 25 | """The expected output from the model for the few-shot example.""" 26 | 27 | model_context: Optional[str] = None 28 | """The context for the few-shot example.""" 29 | 30 | if PYDANTIC_V2: 31 | # allow fields with a `model_` prefix 32 | model_config = ConfigDict(protected_namespaces=tuple()) 33 | -------------------------------------------------------------------------------- /src/atla/types/metrics/few_shot_example_param.py: -------------------------------------------------------------------------------- 1 | # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
2 | 3 | from __future__ import annotations 4 | 5 | from typing import Optional 6 | from typing_extensions import Required, TypedDict 7 | 8 | __all__ = ["FewShotExampleParam"] 9 | 10 | 11 | class FewShotExampleParam(TypedDict, total=False): 12 | model_input: Required[str] 13 | """The input to the model for the few-shot example.""" 14 | 15 | model_output: Required[str] 16 | """The output from the model for the few-shot example.""" 17 | 18 | score: Required[str] 19 | """The score for the few-shot example.""" 20 | 21 | critique: Optional[str] 22 | """The critique for the few-shot example.""" 23 | 24 | expected_model_output: Optional[str] 25 | """The expected output from the model for the few-shot example.""" 26 | 27 | model_context: Optional[str] 28 | """The context for the few-shot example.""" 29 | -------------------------------------------------------------------------------- /src/atla/types/metrics/few_shot_example_set_params.py: -------------------------------------------------------------------------------- 1 | # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 2 | 3 | from __future__ import annotations 4 | 5 | from typing import Iterable 6 | from typing_extensions import Required, TypedDict 7 | 8 | from .few_shot_example_param import FewShotExampleParam 9 | 10 | __all__ = ["FewShotExampleSetParams"] 11 | 12 | 13 | class FewShotExampleSetParams(TypedDict, total=False): 14 | few_shot_examples: Required[Iterable[FewShotExampleParam]] 15 | """The few-shot examples to upsert.""" 16 | -------------------------------------------------------------------------------- /src/atla/types/metrics/few_shot_example_set_response.py: -------------------------------------------------------------------------------- 1 | # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
2 | 3 | from typing import Optional 4 | from typing_extensions import Literal 5 | 6 | from ..._models import BaseModel 7 | 8 | __all__ = ["FewShotExampleSetResponse"] 9 | 10 | 11 | class FewShotExampleSetResponse(BaseModel): 12 | request_id: str 13 | """The ID of the request the response is for.""" 14 | 15 | status: Optional[Literal["success"]] = None 16 | -------------------------------------------------------------------------------- /src/atla/types/metrics/prompt.py: -------------------------------------------------------------------------------- 1 | # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 2 | 3 | from typing import Optional 4 | from datetime import datetime 5 | 6 | from ..._models import BaseModel 7 | 8 | __all__ = ["Prompt"] 9 | 10 | 11 | class Prompt(BaseModel): 12 | content: str 13 | """The content of the prompt.""" 14 | 15 | version: int 16 | """The version of the prompt.""" 17 | 18 | created_at: Optional[datetime] = None 19 | """The creation time of the prompt.""" 20 | 21 | updated_at: Optional[datetime] = None 22 | """The last update time of the prompt.""" 23 | -------------------------------------------------------------------------------- /src/atla/types/metrics/prompt_create_params.py: -------------------------------------------------------------------------------- 1 | # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 2 | 3 | from __future__ import annotations 4 | 5 | from typing_extensions import Required, TypedDict 6 | 7 | __all__ = ["PromptCreateParams"] 8 | 9 | 10 | class PromptCreateParams(TypedDict, total=False): 11 | content: Required[str] 12 | """The content of the prompt to create.""" 13 | -------------------------------------------------------------------------------- /src/atla/types/metrics/prompt_create_response.py: -------------------------------------------------------------------------------- 1 | # File generated from our OpenAPI spec by Stainless. 
See CONTRIBUTING.md for details. 2 | 3 | from typing import Optional 4 | from typing_extensions import Literal 5 | 6 | from ..._models import BaseModel 7 | 8 | __all__ = ["PromptCreateResponse"] 9 | 10 | 11 | class PromptCreateResponse(BaseModel): 12 | request_id: str 13 | """The ID of the request the response is for.""" 14 | 15 | version: int 16 | """The version of the created prompt.""" 17 | 18 | status: Optional[Literal["success"]] = None 19 | -------------------------------------------------------------------------------- /src/atla/types/metrics/prompt_get_response.py: -------------------------------------------------------------------------------- 1 | # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 2 | 3 | from typing import Optional 4 | from typing_extensions import Literal 5 | 6 | from .prompt import Prompt 7 | from ..._models import BaseModel 8 | 9 | __all__ = ["PromptGetResponse"] 10 | 11 | 12 | class PromptGetResponse(BaseModel): 13 | prompt: Prompt 14 | """The prompt retrieved.""" 15 | 16 | request_id: str 17 | """The ID of the request the response is for.""" 18 | 19 | status: Optional[Literal["success"]] = None 20 | -------------------------------------------------------------------------------- /src/atla/types/metrics/prompt_list_response.py: -------------------------------------------------------------------------------- 1 | # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
2 | 3 | from typing import List, Optional 4 | from typing_extensions import Literal 5 | 6 | from .prompt import Prompt 7 | from ..._models import BaseModel 8 | 9 | __all__ = ["PromptListResponse"] 10 | 11 | 12 | class PromptListResponse(BaseModel): 13 | prompts: List[Prompt] 14 | """The prompts retrieved.""" 15 | 16 | request_id: str 17 | """The ID of the request the response is for.""" 18 | 19 | status: Optional[Literal["success"]] = None 20 | -------------------------------------------------------------------------------- /src/atla/types/metrics/prompt_set_active_prompt_version_params.py: -------------------------------------------------------------------------------- 1 | # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 2 | 3 | from __future__ import annotations 4 | 5 | from typing_extensions import Required, TypedDict 6 | 7 | __all__ = ["PromptSetActivePromptVersionParams"] 8 | 9 | 10 | class PromptSetActivePromptVersionParams(TypedDict, total=False): 11 | version: Required[int] 12 | """The version of the prompt to set as active.""" 13 | -------------------------------------------------------------------------------- /src/atla/types/metrics/prompt_set_active_prompt_version_response.py: -------------------------------------------------------------------------------- 1 | # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
2 | 3 | from typing import Optional 4 | from typing_extensions import Literal 5 | 6 | from ..._models import BaseModel 7 | 8 | __all__ = ["PromptSetActivePromptVersionResponse"] 9 | 10 | 11 | class PromptSetActivePromptVersionResponse(BaseModel): 12 | request_id: str 13 | """The ID of the request the response is for.""" 14 | 15 | status: Optional[Literal["success"]] = None 16 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 2 | -------------------------------------------------------------------------------- /tests/api_resources/__init__.py: -------------------------------------------------------------------------------- 1 | # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 2 | -------------------------------------------------------------------------------- /tests/api_resources/chat/__init__.py: -------------------------------------------------------------------------------- 1 | # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 2 | -------------------------------------------------------------------------------- /tests/api_resources/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 2 | -------------------------------------------------------------------------------- /tests/api_resources/metrics/test_few_shot_examples.py: -------------------------------------------------------------------------------- 1 | # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
2 | 3 | from __future__ import annotations 4 | 5 | import os 6 | from typing import Any, cast 7 | 8 | import pytest 9 | 10 | from atla import Atla, AsyncAtla 11 | from tests.utils import assert_matches_type 12 | from atla.types.metrics import FewShotExampleSetResponse 13 | 14 | base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") 15 | 16 | 17 | class TestFewShotExamples: 18 | parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) 19 | 20 | @parametrize 21 | def test_method_set(self, client: Atla) -> None: 22 | few_shot_example = client.metrics.few_shot_examples.set( 23 | metric_id="metric_id", 24 | few_shot_examples=[ 25 | { 26 | "model_input": "Few-shot `model_input`.", 27 | "model_output": "Few-shot `model_output`.", 28 | "score": "1", 29 | } 30 | ], 31 | ) 32 | assert_matches_type(FewShotExampleSetResponse, few_shot_example, path=["response"]) 33 | 34 | @parametrize 35 | def test_raw_response_set(self, client: Atla) -> None: 36 | response = client.metrics.few_shot_examples.with_raw_response.set( 37 | metric_id="metric_id", 38 | few_shot_examples=[ 39 | { 40 | "model_input": "Few-shot `model_input`.", 41 | "model_output": "Few-shot `model_output`.", 42 | "score": "1", 43 | } 44 | ], 45 | ) 46 | 47 | assert response.is_closed is True 48 | assert response.http_request.headers.get("X-Stainless-Lang") == "python" 49 | few_shot_example = response.parse() 50 | assert_matches_type(FewShotExampleSetResponse, few_shot_example, path=["response"]) 51 | 52 | @parametrize 53 | def test_streaming_response_set(self, client: Atla) -> None: 54 | with client.metrics.few_shot_examples.with_streaming_response.set( 55 | metric_id="metric_id", 56 | few_shot_examples=[ 57 | { 58 | "model_input": "Few-shot `model_input`.", 59 | "model_output": "Few-shot `model_output`.", 60 | "score": "1", 61 | } 62 | ], 63 | ) as response: 64 | assert not response.is_closed 65 | assert 
response.http_request.headers.get("X-Stainless-Lang") == "python" 66 | 67 | few_shot_example = response.parse() 68 | assert_matches_type(FewShotExampleSetResponse, few_shot_example, path=["response"]) 69 | 70 | assert cast(Any, response.is_closed) is True 71 | 72 | @parametrize 73 | def test_path_params_set(self, client: Atla) -> None: 74 | with pytest.raises(ValueError, match=r"Expected a non-empty value for `metric_id` but received ''"): 75 | client.metrics.few_shot_examples.with_raw_response.set( 76 | metric_id="", 77 | few_shot_examples=[ 78 | { 79 | "model_input": "Few-shot `model_input`.", 80 | "model_output": "Few-shot `model_output`.", 81 | "score": "1", 82 | } 83 | ], 84 | ) 85 | 86 | 87 | class TestAsyncFewShotExamples: 88 | parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) 89 | 90 | @parametrize 91 | async def test_method_set(self, async_client: AsyncAtla) -> None: 92 | few_shot_example = await async_client.metrics.few_shot_examples.set( 93 | metric_id="metric_id", 94 | few_shot_examples=[ 95 | { 96 | "model_input": "Few-shot `model_input`.", 97 | "model_output": "Few-shot `model_output`.", 98 | "score": "1", 99 | } 100 | ], 101 | ) 102 | assert_matches_type(FewShotExampleSetResponse, few_shot_example, path=["response"]) 103 | 104 | @parametrize 105 | async def test_raw_response_set(self, async_client: AsyncAtla) -> None: 106 | response = await async_client.metrics.few_shot_examples.with_raw_response.set( 107 | metric_id="metric_id", 108 | few_shot_examples=[ 109 | { 110 | "model_input": "Few-shot `model_input`.", 111 | "model_output": "Few-shot `model_output`.", 112 | "score": "1", 113 | } 114 | ], 115 | ) 116 | 117 | assert response.is_closed is True 118 | assert response.http_request.headers.get("X-Stainless-Lang") == "python" 119 | few_shot_example = await response.parse() 120 | assert_matches_type(FewShotExampleSetResponse, few_shot_example, path=["response"]) 121 | 122 | @parametrize 123 | 
async def test_streaming_response_set(self, async_client: AsyncAtla) -> None: 124 | async with async_client.metrics.few_shot_examples.with_streaming_response.set( 125 | metric_id="metric_id", 126 | few_shot_examples=[ 127 | { 128 | "model_input": "Few-shot `model_input`.", 129 | "model_output": "Few-shot `model_output`.", 130 | "score": "1", 131 | } 132 | ], 133 | ) as response: 134 | assert not response.is_closed 135 | assert response.http_request.headers.get("X-Stainless-Lang") == "python" 136 | 137 | few_shot_example = await response.parse() 138 | assert_matches_type(FewShotExampleSetResponse, few_shot_example, path=["response"]) 139 | 140 | assert cast(Any, response.is_closed) is True 141 | 142 | @parametrize 143 | async def test_path_params_set(self, async_client: AsyncAtla) -> None: 144 | with pytest.raises(ValueError, match=r"Expected a non-empty value for `metric_id` but received ''"): 145 | await async_client.metrics.few_shot_examples.with_raw_response.set( 146 | metric_id="", 147 | few_shot_examples=[ 148 | { 149 | "model_input": "Few-shot `model_input`.", 150 | "model_output": "Few-shot `model_output`.", 151 | "score": "1", 152 | } 153 | ], 154 | ) 155 | -------------------------------------------------------------------------------- /tests/api_resources/test_evaluation.py: -------------------------------------------------------------------------------- 1 | # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
2 | 3 | from __future__ import annotations 4 | 5 | import os 6 | from typing import Any, cast 7 | 8 | import pytest 9 | 10 | from atla import Atla, AsyncAtla 11 | from atla.types import Evaluation 12 | from tests.utils import assert_matches_type 13 | 14 | base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") 15 | 16 | 17 | class TestEvaluation: 18 | parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) 19 | 20 | @parametrize 21 | def test_method_create(self, client: Atla) -> None: 22 | evaluation = client.evaluation.create( 23 | model_id="atla-selene-mini-20250127", 24 | model_input="Is it legal to monitor employee emails under European privacy laws?", 25 | model_output="Monitoring employee emails is permissible under European privacy laws like GDPR, provided there is a legitimate purpose.", 26 | ) 27 | assert_matches_type(Evaluation, evaluation, path=["response"]) 28 | 29 | @parametrize 30 | def test_method_create_with_all_params(self, client: Atla) -> None: 31 | evaluation = client.evaluation.create( 32 | model_id="atla-selene-mini-20250127", 33 | model_input="Is it legal to monitor employee emails under European privacy laws?", 34 | model_output="Monitoring employee emails is permissible under European privacy laws like GDPR, provided there is a legitimate purpose.", 35 | evaluation_criteria="Evaluate the answer based on its factual correctness. Assign a score of 1 if the answer is factually correct, otherwise assign a score of 0.", 36 | expected_model_output="Yes, but only under strict conditions. 
European privacy laws, including GDPR, require that monitoring be necessary for a legitimate purpose, employees be informed in advance, and privacy impact be minimized.", 37 | few_shot_examples=[ 38 | { 39 | "model_input": "Can employers require employees to use personal devices for work?", 40 | "model_output": "Employers can require employees to use personal devices for work, but legal and privacy considerations must be addressed.", 41 | "score": "1", 42 | "critique": "The model output is factually correct and accurately describes the Bring Your Own Device (BYOD) policy that an employer may choose to implement while highlighting the relevant legal and privacy considerations.", 43 | "expected_model_output": "Yes, but privacy and security concerns must be addressed. Employers must ensure compliance with data protection laws, inform employees about data handling, and offer alternatives where necessary.", 44 | "model_context": "Employers implementing Bring Your Own Device (BYOD) policies must consider data protection laws and employee privacy rights. Under regulations like GDPR, companies must ensure adequate data security, inform employees of monitoring or data collection practices, and provide alternatives if necessary. Failure to implement safeguards could lead to legal challenges or data breaches.", 45 | } 46 | ], 47 | metric_name="my_metric", 48 | model_context="European privacy laws, including GDPR, allow for the monitoring of employee emails under strict conditions. The employer must demonstrate that the monitoring is necessary for a legitimate purpose, such as protecting company assets or compliance with legal obligations. 
Employees must be informed about the monitoring in advance, and the privacy impact should be assessed to minimize intrusion.", 49 | prompt_version=1, 50 | ) 51 | assert_matches_type(Evaluation, evaluation, path=["response"]) 52 | 53 | @parametrize 54 | def test_raw_response_create(self, client: Atla) -> None: 55 | response = client.evaluation.with_raw_response.create( 56 | model_id="atla-selene-mini-20250127", 57 | model_input="Is it legal to monitor employee emails under European privacy laws?", 58 | model_output="Monitoring employee emails is permissible under European privacy laws like GDPR, provided there is a legitimate purpose.", 59 | ) 60 | 61 | assert response.is_closed is True 62 | assert response.http_request.headers.get("X-Stainless-Lang") == "python" 63 | evaluation = response.parse() 64 | assert_matches_type(Evaluation, evaluation, path=["response"]) 65 | 66 | @parametrize 67 | def test_streaming_response_create(self, client: Atla) -> None: 68 | with client.evaluation.with_streaming_response.create( 69 | model_id="atla-selene-mini-20250127", 70 | model_input="Is it legal to monitor employee emails under European privacy laws?", 71 | model_output="Monitoring employee emails is permissible under European privacy laws like GDPR, provided there is a legitimate purpose.", 72 | ) as response: 73 | assert not response.is_closed 74 | assert response.http_request.headers.get("X-Stainless-Lang") == "python" 75 | 76 | evaluation = response.parse() 77 | assert_matches_type(Evaluation, evaluation, path=["response"]) 78 | 79 | assert cast(Any, response.is_closed) is True 80 | 81 | 82 | class TestAsyncEvaluation: 83 | parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) 84 | 85 | @parametrize 86 | async def test_method_create(self, async_client: AsyncAtla) -> None: 87 | evaluation = await async_client.evaluation.create( 88 | model_id="atla-selene-mini-20250127", 89 | model_input="Is it legal to monitor employee 
emails under European privacy laws?", 90 | model_output="Monitoring employee emails is permissible under European privacy laws like GDPR, provided there is a legitimate purpose.", 91 | ) 92 | assert_matches_type(Evaluation, evaluation, path=["response"]) 93 | 94 | @parametrize 95 | async def test_method_create_with_all_params(self, async_client: AsyncAtla) -> None: 96 | evaluation = await async_client.evaluation.create( 97 | model_id="atla-selene-mini-20250127", 98 | model_input="Is it legal to monitor employee emails under European privacy laws?", 99 | model_output="Monitoring employee emails is permissible under European privacy laws like GDPR, provided there is a legitimate purpose.", 100 | evaluation_criteria="Evaluate the answer based on its factual correctness. Assign a score of 1 if the answer is factually correct, otherwise assign a score of 0.", 101 | expected_model_output="Yes, but only under strict conditions. European privacy laws, including GDPR, require that monitoring be necessary for a legitimate purpose, employees be informed in advance, and privacy impact be minimized.", 102 | few_shot_examples=[ 103 | { 104 | "model_input": "Can employers require employees to use personal devices for work?", 105 | "model_output": "Employers can require employees to use personal devices for work, but legal and privacy considerations must be addressed.", 106 | "score": "1", 107 | "critique": "The model output is factually correct and accurately describes the Bring Your Own Device (BYOD) policy that an employer may choose to implement while highlighting the relevant legal and privacy considerations.", 108 | "expected_model_output": "Yes, but privacy and security concerns must be addressed. 
Employers must ensure compliance with data protection laws, inform employees about data handling, and offer alternatives where necessary.", 109 | "model_context": "Employers implementing Bring Your Own Device (BYOD) policies must consider data protection laws and employee privacy rights. Under regulations like GDPR, companies must ensure adequate data security, inform employees of monitoring or data collection practices, and provide alternatives if necessary. Failure to implement safeguards could lead to legal challenges or data breaches.", 110 | } 111 | ], 112 | metric_name="my_metric", 113 | model_context="European privacy laws, including GDPR, allow for the monitoring of employee emails under strict conditions. The employer must demonstrate that the monitoring is necessary for a legitimate purpose, such as protecting company assets or compliance with legal obligations. Employees must be informed about the monitoring in advance, and the privacy impact should be assessed to minimize intrusion.", 114 | prompt_version=1, 115 | ) 116 | assert_matches_type(Evaluation, evaluation, path=["response"]) 117 | 118 | @parametrize 119 | async def test_raw_response_create(self, async_client: AsyncAtla) -> None: 120 | response = await async_client.evaluation.with_raw_response.create( 121 | model_id="atla-selene-mini-20250127", 122 | model_input="Is it legal to monitor employee emails under European privacy laws?", 123 | model_output="Monitoring employee emails is permissible under European privacy laws like GDPR, provided there is a legitimate purpose.", 124 | ) 125 | 126 | assert response.is_closed is True 127 | assert response.http_request.headers.get("X-Stainless-Lang") == "python" 128 | evaluation = await response.parse() 129 | assert_matches_type(Evaluation, evaluation, path=["response"]) 130 | 131 | @parametrize 132 | async def test_streaming_response_create(self, async_client: AsyncAtla) -> None: 133 | async with async_client.evaluation.with_streaming_response.create( 
134 | model_id="atla-selene-mini-20250127", 135 | model_input="Is it legal to monitor employee emails under European privacy laws?", 136 | model_output="Monitoring employee emails is permissible under European privacy laws like GDPR, provided there is a legitimate purpose.", 137 | ) as response: 138 | assert not response.is_closed 139 | assert response.http_request.headers.get("X-Stainless-Lang") == "python" 140 | 141 | evaluation = await response.parse() 142 | assert_matches_type(Evaluation, evaluation, path=["response"]) 143 | 144 | assert cast(Any, response.is_closed) is True 145 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import os 4 | import logging 5 | from typing import TYPE_CHECKING, Iterator, AsyncIterator 6 | 7 | import pytest 8 | from pytest_asyncio import is_async_test 9 | 10 | from atla import Atla, AsyncAtla 11 | 12 | if TYPE_CHECKING: 13 | from _pytest.fixtures import FixtureRequest # pyright: ignore[reportPrivateImportUsage] 14 | 15 | pytest.register_assert_rewrite("tests.utils") 16 | 17 | logging.getLogger("atla").setLevel(logging.DEBUG) 18 | 19 | 20 | # automatically add `pytest.mark.asyncio()` to all of our async tests 21 | # so we don't have to add that boilerplate everywhere 22 | def pytest_collection_modifyitems(items: list[pytest.Function]) -> None: 23 | pytest_asyncio_tests = (item for item in items if is_async_test(item)) 24 | session_scope_marker = pytest.mark.asyncio(loop_scope="session") 25 | for async_test in pytest_asyncio_tests: 26 | async_test.add_marker(session_scope_marker, append=False) 27 | 28 | 29 | base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") 30 | 31 | api_key = "My API Key" 32 | 33 | 34 | @pytest.fixture(scope="session") 35 | def client(request: FixtureRequest) -> Iterator[Atla]: 36 | strict = getattr(request, 
"param", True) 37 | if not isinstance(strict, bool): 38 | raise TypeError(f"Unexpected fixture parameter type {type(strict)}, expected {bool}") 39 | 40 | with Atla(base_url=base_url, api_key=api_key, _strict_response_validation=strict) as client: 41 | yield client 42 | 43 | 44 | @pytest.fixture(scope="session") 45 | async def async_client(request: FixtureRequest) -> AsyncIterator[AsyncAtla]: 46 | strict = getattr(request, "param", True) 47 | if not isinstance(strict, bool): 48 | raise TypeError(f"Unexpected fixture parameter type {type(strict)}, expected {bool}") 49 | 50 | async with AsyncAtla(base_url=base_url, api_key=api_key, _strict_response_validation=strict) as client: 51 | yield client 52 | -------------------------------------------------------------------------------- /tests/sample_file.txt: -------------------------------------------------------------------------------- 1 | Hello, world! 2 | -------------------------------------------------------------------------------- /tests/test_deepcopy.py: -------------------------------------------------------------------------------- 1 | from atla._utils import deepcopy_minimal 2 | 3 | 4 | def assert_different_identities(obj1: object, obj2: object) -> None: 5 | assert obj1 == obj2 6 | assert id(obj1) != id(obj2) 7 | 8 | 9 | def test_simple_dict() -> None: 10 | obj1 = {"foo": "bar"} 11 | obj2 = deepcopy_minimal(obj1) 12 | assert_different_identities(obj1, obj2) 13 | 14 | 15 | def test_nested_dict() -> None: 16 | obj1 = {"foo": {"bar": True}} 17 | obj2 = deepcopy_minimal(obj1) 18 | assert_different_identities(obj1, obj2) 19 | assert_different_identities(obj1["foo"], obj2["foo"]) 20 | 21 | 22 | def test_complex_nested_dict() -> None: 23 | obj1 = {"foo": {"bar": [{"hello": "world"}]}} 24 | obj2 = deepcopy_minimal(obj1) 25 | assert_different_identities(obj1, obj2) 26 | assert_different_identities(obj1["foo"], obj2["foo"]) 27 | assert_different_identities(obj1["foo"]["bar"], obj2["foo"]["bar"]) 28 | 
assert_different_identities(obj1["foo"]["bar"][0], obj2["foo"]["bar"][0]) 29 | 30 | 31 | def test_simple_list() -> None: 32 | obj1 = ["a", "b", "c"] 33 | obj2 = deepcopy_minimal(obj1) 34 | assert_different_identities(obj1, obj2) 35 | 36 | 37 | def test_nested_list() -> None: 38 | obj1 = ["a", [1, 2, 3]] 39 | obj2 = deepcopy_minimal(obj1) 40 | assert_different_identities(obj1, obj2) 41 | assert_different_identities(obj1[1], obj2[1]) 42 | 43 | 44 | class MyObject: ... 45 | 46 | 47 | def test_ignores_other_types() -> None: 48 | # custom classes 49 | my_obj = MyObject() 50 | obj1 = {"foo": my_obj} 51 | obj2 = deepcopy_minimal(obj1) 52 | assert_different_identities(obj1, obj2) 53 | assert obj1["foo"] is my_obj 54 | 55 | # tuples 56 | obj3 = ("a", "b") 57 | obj4 = deepcopy_minimal(obj3) 58 | assert obj3 is obj4 59 | -------------------------------------------------------------------------------- /tests/test_extract_files.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Sequence 4 | 5 | import pytest 6 | 7 | from atla._types import FileTypes 8 | from atla._utils import extract_files 9 | 10 | 11 | def test_removes_files_from_input() -> None: 12 | query = {"foo": "bar"} 13 | assert extract_files(query, paths=[]) == [] 14 | assert query == {"foo": "bar"} 15 | 16 | query2 = {"foo": b"Bar", "hello": "world"} 17 | assert extract_files(query2, paths=[["foo"]]) == [("foo", b"Bar")] 18 | assert query2 == {"hello": "world"} 19 | 20 | query3 = {"foo": {"foo": {"bar": b"Bar"}}, "hello": "world"} 21 | assert extract_files(query3, paths=[["foo", "foo", "bar"]]) == [("foo[foo][bar]", b"Bar")] 22 | assert query3 == {"foo": {"foo": {}}, "hello": "world"} 23 | 24 | query4 = {"foo": {"bar": b"Bar", "baz": "foo"}, "hello": "world"} 25 | assert extract_files(query4, paths=[["foo", "bar"]]) == [("foo[bar]", b"Bar")] 26 | assert query4 == {"hello": "world", "foo": {"baz": "foo"}} 27 | 28 | 29 | 
def test_multiple_files() -> None: 30 | query = {"documents": [{"file": b"My first file"}, {"file": b"My second file"}]} 31 | assert extract_files(query, paths=[["documents", "", "file"]]) == [ 32 | ("documents[][file]", b"My first file"), 33 | ("documents[][file]", b"My second file"), 34 | ] 35 | assert query == {"documents": [{}, {}]} 36 | 37 | 38 | @pytest.mark.parametrize( 39 | "query,paths,expected", 40 | [ 41 | [ 42 | {"foo": {"bar": "baz"}}, 43 | [["foo", "", "bar"]], 44 | [], 45 | ], 46 | [ 47 | {"foo": ["bar", "baz"]}, 48 | [["foo", "bar"]], 49 | [], 50 | ], 51 | [ 52 | {"foo": {"bar": "baz"}}, 53 | [["foo", "foo"]], 54 | [], 55 | ], 56 | ], 57 | ids=["dict expecting array", "array expecting dict", "unknown keys"], 58 | ) 59 | def test_ignores_incorrect_paths( 60 | query: dict[str, object], 61 | paths: Sequence[Sequence[str]], 62 | expected: list[tuple[str, FileTypes]], 63 | ) -> None: 64 | assert extract_files(query, paths=paths) == expected 65 | -------------------------------------------------------------------------------- /tests/test_files.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import anyio 4 | import pytest 5 | from dirty_equals import IsDict, IsList, IsBytes, IsTuple 6 | 7 | from atla._files import to_httpx_files, async_to_httpx_files 8 | 9 | readme_path = Path(__file__).parent.parent.joinpath("README.md") 10 | 11 | 12 | def test_pathlib_includes_file_name() -> None: 13 | result = to_httpx_files({"file": readme_path}) 14 | print(result) 15 | assert result == IsDict({"file": IsTuple("README.md", IsBytes())}) 16 | 17 | 18 | def test_tuple_input() -> None: 19 | result = to_httpx_files([("file", readme_path)]) 20 | print(result) 21 | assert result == IsList(IsTuple("file", IsTuple("README.md", IsBytes()))) 22 | 23 | 24 | @pytest.mark.asyncio 25 | async def test_async_pathlib_includes_file_name() -> None: 26 | result = await async_to_httpx_files({"file": readme_path}) 27 | 
print(result) 28 | assert result == IsDict({"file": IsTuple("README.md", IsBytes())}) 29 | 30 | 31 | @pytest.mark.asyncio 32 | async def test_async_supports_anyio_path() -> None: 33 | result = await async_to_httpx_files({"file": anyio.Path(readme_path)}) 34 | print(result) 35 | assert result == IsDict({"file": IsTuple("README.md", IsBytes())}) 36 | 37 | 38 | @pytest.mark.asyncio 39 | async def test_async_tuple_input() -> None: 40 | result = await async_to_httpx_files([("file", readme_path)]) 41 | print(result) 42 | assert result == IsList(IsTuple("file", IsTuple("README.md", IsBytes()))) 43 | 44 | 45 | def test_string_not_allowed() -> None: 46 | with pytest.raises(TypeError, match="Expected file types input to be a FileContent type or to be a tuple"): 47 | to_httpx_files( 48 | { 49 | "file": "foo", # type: ignore 50 | } 51 | ) 52 | -------------------------------------------------------------------------------- /tests/test_qs.py: -------------------------------------------------------------------------------- 1 | from typing import Any, cast 2 | from functools import partial 3 | from urllib.parse import unquote 4 | 5 | import pytest 6 | 7 | from atla._qs import Querystring, stringify 8 | 9 | 10 | def test_empty() -> None: 11 | assert stringify({}) == "" 12 | assert stringify({"a": {}}) == "" 13 | assert stringify({"a": {"b": {"c": {}}}}) == "" 14 | 15 | 16 | def test_basic() -> None: 17 | assert stringify({"a": 1}) == "a=1" 18 | assert stringify({"a": "b"}) == "a=b" 19 | assert stringify({"a": True}) == "a=true" 20 | assert stringify({"a": False}) == "a=false" 21 | assert stringify({"a": 1.23456}) == "a=1.23456" 22 | assert stringify({"a": None}) == "" 23 | 24 | 25 | @pytest.mark.parametrize("method", ["class", "function"]) 26 | def test_nested_dotted(method: str) -> None: 27 | if method == "class": 28 | serialise = Querystring(nested_format="dots").stringify 29 | else: 30 | serialise = partial(stringify, nested_format="dots") 31 | 32 | assert 
unquote(serialise({"a": {"b": "c"}})) == "a.b=c" 33 | assert unquote(serialise({"a": {"b": "c", "d": "e", "f": "g"}})) == "a.b=c&a.d=e&a.f=g" 34 | assert unquote(serialise({"a": {"b": {"c": {"d": "e"}}}})) == "a.b.c.d=e" 35 | assert unquote(serialise({"a": {"b": True}})) == "a.b=true" 36 | 37 | 38 | def test_nested_brackets() -> None: 39 | assert unquote(stringify({"a": {"b": "c"}})) == "a[b]=c" 40 | assert unquote(stringify({"a": {"b": "c", "d": "e", "f": "g"}})) == "a[b]=c&a[d]=e&a[f]=g" 41 | assert unquote(stringify({"a": {"b": {"c": {"d": "e"}}}})) == "a[b][c][d]=e" 42 | assert unquote(stringify({"a": {"b": True}})) == "a[b]=true" 43 | 44 | 45 | @pytest.mark.parametrize("method", ["class", "function"]) 46 | def test_array_comma(method: str) -> None: 47 | if method == "class": 48 | serialise = Querystring(array_format="comma").stringify 49 | else: 50 | serialise = partial(stringify, array_format="comma") 51 | 52 | assert unquote(serialise({"in": ["foo", "bar"]})) == "in=foo,bar" 53 | assert unquote(serialise({"a": {"b": [True, False]}})) == "a[b]=true,false" 54 | assert unquote(serialise({"a": {"b": [True, False, None, True]}})) == "a[b]=true,false,true" 55 | 56 | 57 | def test_array_repeat() -> None: 58 | assert unquote(stringify({"in": ["foo", "bar"]})) == "in=foo&in=bar" 59 | assert unquote(stringify({"a": {"b": [True, False]}})) == "a[b]=true&a[b]=false" 60 | assert unquote(stringify({"a": {"b": [True, False, None, True]}})) == "a[b]=true&a[b]=false&a[b]=true" 61 | assert unquote(stringify({"in": ["foo", {"b": {"c": ["d", "e"]}}]})) == "in=foo&in[b][c]=d&in[b][c]=e" 62 | 63 | 64 | @pytest.mark.parametrize("method", ["class", "function"]) 65 | def test_array_brackets(method: str) -> None: 66 | if method == "class": 67 | serialise = Querystring(array_format="brackets").stringify 68 | else: 69 | serialise = partial(stringify, array_format="brackets") 70 | 71 | assert unquote(serialise({"in": ["foo", "bar"]})) == "in[]=foo&in[]=bar" 72 | assert 
unquote(serialise({"a": {"b": [True, False]}})) == "a[b][]=true&a[b][]=false" 73 | assert unquote(serialise({"a": {"b": [True, False, None, True]}})) == "a[b][]=true&a[b][]=false&a[b][]=true" 74 | 75 | 76 | def test_unknown_array_format() -> None: 77 | with pytest.raises(NotImplementedError, match="Unknown array_format value: foo, choose from comma, repeat"): 78 | stringify({"a": ["foo", "bar"]}, array_format=cast(Any, "foo")) 79 | -------------------------------------------------------------------------------- /tests/test_required_args.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import pytest 4 | 5 | from atla._utils import required_args 6 | 7 | 8 | def test_too_many_positional_params() -> None: 9 | @required_args(["a"]) 10 | def foo(a: str | None = None) -> str | None: 11 | return a 12 | 13 | with pytest.raises(TypeError, match=r"foo\(\) takes 1 argument\(s\) but 2 were given"): 14 | foo("a", "b") # type: ignore 15 | 16 | 17 | def test_positional_param() -> None: 18 | @required_args(["a"]) 19 | def foo(a: str | None = None) -> str | None: 20 | return a 21 | 22 | assert foo("a") == "a" 23 | assert foo(None) is None 24 | assert foo(a="b") == "b" 25 | 26 | with pytest.raises(TypeError, match="Missing required argument: 'a'"): 27 | foo() 28 | 29 | 30 | def test_keyword_only_param() -> None: 31 | @required_args(["a"]) 32 | def foo(*, a: str | None = None) -> str | None: 33 | return a 34 | 35 | assert foo(a="a") == "a" 36 | assert foo(a=None) is None 37 | assert foo(a="b") == "b" 38 | 39 | with pytest.raises(TypeError, match="Missing required argument: 'a'"): 40 | foo() 41 | 42 | 43 | def test_multiple_params() -> None: 44 | @required_args(["a", "b", "c"]) 45 | def foo(a: str = "", *, b: str = "", c: str = "") -> str | None: 46 | return f"{a} {b} {c}" 47 | 48 | assert foo(a="a", b="b", c="c") == "a b c" 49 | 50 | error_message = r"Missing required arguments.*" 51 | 52 | with 
pytest.raises(TypeError, match=error_message): 53 | foo() 54 | 55 | with pytest.raises(TypeError, match=error_message): 56 | foo(a="a") 57 | 58 | with pytest.raises(TypeError, match=error_message): 59 | foo(b="b") 60 | 61 | with pytest.raises(TypeError, match=error_message): 62 | foo(c="c") 63 | 64 | with pytest.raises(TypeError, match=r"Missing required argument: 'a'"): 65 | foo(b="a", c="c") 66 | 67 | with pytest.raises(TypeError, match=r"Missing required argument: 'b'"): 68 | foo("a", c="c") 69 | 70 | 71 | def test_multiple_variants() -> None: 72 | @required_args(["a"], ["b"]) 73 | def foo(*, a: str | None = None, b: str | None = None) -> str | None: 74 | return a if a is not None else b 75 | 76 | assert foo(a="foo") == "foo" 77 | assert foo(b="bar") == "bar" 78 | assert foo(a=None) is None 79 | assert foo(b=None) is None 80 | 81 | # TODO: this error message could probably be improved 82 | with pytest.raises( 83 | TypeError, 84 | match=r"Missing required arguments; Expected either \('a'\) or \('b'\) arguments to be given", 85 | ): 86 | foo() 87 | 88 | 89 | def test_multiple_params_multiple_variants() -> None: 90 | @required_args(["a", "b"], ["c"]) 91 | def foo(*, a: str | None = None, b: str | None = None, c: str | None = None) -> str | None: 92 | if a is not None: 93 | return a 94 | if b is not None: 95 | return b 96 | return c 97 | 98 | error_message = r"Missing required arguments; Expected either \('a' and 'b'\) or \('c'\) arguments to be given" 99 | 100 | with pytest.raises(TypeError, match=error_message): 101 | foo(a="foo") 102 | 103 | with pytest.raises(TypeError, match=error_message): 104 | foo(b="bar") 105 | 106 | with pytest.raises(TypeError, match=error_message): 107 | foo() 108 | 109 | assert foo(a=None, b="bar") == "bar" 110 | assert foo(c=None) is None 111 | assert foo(c="foo") == "foo" 112 | -------------------------------------------------------------------------------- /tests/test_response.py: 
-------------------------------------------------------------------------------- 1 | import json 2 | from typing import Any, List, Union, cast 3 | from typing_extensions import Annotated 4 | 5 | import httpx 6 | import pytest 7 | import pydantic 8 | 9 | from atla import Atla, AsyncAtla, BaseModel 10 | from atla._response import ( 11 | APIResponse, 12 | BaseAPIResponse, 13 | AsyncAPIResponse, 14 | BinaryAPIResponse, 15 | AsyncBinaryAPIResponse, 16 | extract_response_type, 17 | ) 18 | from atla._streaming import Stream 19 | from atla._base_client import FinalRequestOptions 20 | 21 | 22 | class ConcreteBaseAPIResponse(APIResponse[bytes]): ... 23 | 24 | 25 | class ConcreteAPIResponse(APIResponse[List[str]]): ... 26 | 27 | 28 | class ConcreteAsyncAPIResponse(APIResponse[httpx.Response]): ... 29 | 30 | 31 | def test_extract_response_type_direct_classes() -> None: 32 | assert extract_response_type(BaseAPIResponse[str]) == str 33 | assert extract_response_type(APIResponse[str]) == str 34 | assert extract_response_type(AsyncAPIResponse[str]) == str 35 | 36 | 37 | def test_extract_response_type_direct_class_missing_type_arg() -> None: 38 | with pytest.raises( 39 | RuntimeError, 40 | match="Expected type to have a type argument at index 0 but it did not", 41 | ): 42 | extract_response_type(AsyncAPIResponse) 43 | 44 | 45 | def test_extract_response_type_concrete_subclasses() -> None: 46 | assert extract_response_type(ConcreteBaseAPIResponse) == bytes 47 | assert extract_response_type(ConcreteAPIResponse) == List[str] 48 | assert extract_response_type(ConcreteAsyncAPIResponse) == httpx.Response 49 | 50 | 51 | def test_extract_response_type_binary_response() -> None: 52 | assert extract_response_type(BinaryAPIResponse) == bytes 53 | assert extract_response_type(AsyncBinaryAPIResponse) == bytes 54 | 55 | 56 | class PydanticModel(pydantic.BaseModel): ... 
def test_response_parse_mismatched_basemodel(client: Atla) -> None:
    """Parsing into a plain `pydantic.BaseModel` raises TypeError (sync)."""
    response = APIResponse(
        raw=httpx.Response(200, content=b"foo"),
        client=client,
        stream=False,
        stream_cls=None,
        cast_to=str,
        options=FinalRequestOptions.construct(method="get", url="/foo"),
    )

    with pytest.raises(
        TypeError,
        match="Pydantic models must subclass our base model type, e.g. `from atla import BaseModel`",
    ):
        response.parse(to=PydanticModel)


@pytest.mark.asyncio
async def test_async_response_parse_mismatched_basemodel(async_client: AsyncAtla) -> None:
    """Parsing into a plain `pydantic.BaseModel` raises TypeError (async)."""
    response = AsyncAPIResponse(
        raw=httpx.Response(200, content=b"foo"),
        client=async_client,
        stream=False,
        stream_cls=None,
        cast_to=str,
        options=FinalRequestOptions.construct(method="get", url="/foo"),
    )

    with pytest.raises(
        TypeError,
        match="Pydantic models must subclass our base model type, e.g. `from atla import BaseModel`",
    ):
        await response.parse(to=PydanticModel)


def test_response_parse_custom_stream(client: Atla) -> None:
    """`parse(to=Stream[int])` yields a stream whose cast type is `int` (sync)."""
    response = APIResponse(
        raw=httpx.Response(200, content=b"foo"),
        client=client,
        stream=True,
        stream_cls=None,
        cast_to=str,
        options=FinalRequestOptions.construct(method="get", url="/foo"),
    )

    stream = response.parse(to=Stream[int])
    assert stream._cast_to == int


@pytest.mark.asyncio
async def test_async_response_parse_custom_stream(async_client: AsyncAtla) -> None:
    """`parse(to=Stream[int])` yields a stream whose cast type is `int` (async)."""
    response = AsyncAPIResponse(
        raw=httpx.Response(200, content=b"foo"),
        client=async_client,
        stream=True,
        stream_cls=None,
        cast_to=str,
        options=FinalRequestOptions.construct(method="get", url="/foo"),
    )

    stream = await response.parse(to=Stream[int])
    assert stream._cast_to == int


class CustomModel(BaseModel):
    """Minimal `atla.BaseModel` used as an explicit parse target."""

    foo: str
    bar: int


def test_response_parse_custom_model(client: Atla) -> None:
    """A JSON body parses into a caller-supplied `BaseModel` subclass (sync)."""
    response = APIResponse(
        raw=httpx.Response(200, content=json.dumps({"foo": "hello!", "bar": 2})),
        client=client,
        stream=False,
        stream_cls=None,
        cast_to=str,
        options=FinalRequestOptions.construct(method="get", url="/foo"),
    )

    obj = response.parse(to=CustomModel)
    assert obj.foo == "hello!"
    assert obj.bar == 2


@pytest.mark.asyncio
async def test_async_response_parse_custom_model(async_client: AsyncAtla) -> None:
    """A JSON body parses into a caller-supplied `BaseModel` subclass (async)."""
    response = AsyncAPIResponse(
        raw=httpx.Response(200, content=json.dumps({"foo": "hello!", "bar": 2})),
        client=async_client,
        stream=False,
        stream_cls=None,
        cast_to=str,
        options=FinalRequestOptions.construct(method="get", url="/foo"),
    )

    obj = await response.parse(to=CustomModel)
    assert obj.foo == "hello!"
    assert obj.bar == 2


def test_response_parse_annotated_type(client: Atla) -> None:
    """An `Annotated[...]` wrapper around the target model is unwrapped (sync)."""
    response = APIResponse(
        raw=httpx.Response(200, content=json.dumps({"foo": "hello!", "bar": 2})),
        client=client,
        stream=False,
        stream_cls=None,
        cast_to=str,
        options=FinalRequestOptions.construct(method="get", url="/foo"),
    )

    obj = response.parse(
        to=cast("type[CustomModel]", Annotated[CustomModel, "random metadata"]),
    )
    assert obj.foo == "hello!"
    assert obj.bar == 2


@pytest.mark.asyncio
async def test_async_response_parse_annotated_type(async_client: AsyncAtla) -> None:
    """An `Annotated[...]` wrapper around the target model is unwrapped (async)."""
    response = AsyncAPIResponse(
        raw=httpx.Response(200, content=json.dumps({"foo": "hello!", "bar": 2})),
        client=async_client,
        stream=False,
        stream_cls=None,
        cast_to=str,
        options=FinalRequestOptions.construct(method="get", url="/foo"),
    )

    obj = await response.parse(
        to=cast("type[CustomModel]", Annotated[CustomModel, "random metadata"]),
    )
    assert obj.foo == "hello!"
    assert obj.bar == 2


# Shared (body, expected) cases for the sync and async bool-parsing tests;
# the mixed-case spellings check that parsing is case-insensitive.
_BOOL_PARSE_CASES = [
    ("false", False),
    ("true", True),
    ("False", False),
    ("True", True),
    ("TrUe", True),
    ("FalSe", False),
]


@pytest.mark.parametrize("content, expected", _BOOL_PARSE_CASES)
def test_response_parse_bool(client: Atla, content: str, expected: bool) -> None:
    """`parse(to=bool)` maps textual true/false bodies to bools (sync)."""
    response = APIResponse(
        raw=httpx.Response(200, content=content),
        client=client,
        stream=False,
        stream_cls=None,
        cast_to=str,
        options=FinalRequestOptions.construct(method="get", url="/foo"),
    )

    result = response.parse(to=bool)
    assert result is expected


@pytest.mark.asyncio
@pytest.mark.parametrize("content, expected", _BOOL_PARSE_CASES)
async def test_async_response_parse_bool(async_client: AsyncAtla, content: str, expected: bool) -> None:
    """`parse(to=bool)` maps textual true/false bodies to bools (async).

    Uses the `async_client` fixture: the sync `client` fixture yields an
    `Atla` instance, which does not match the `AsyncAPIResponse` under test.
    """
    response = AsyncAPIResponse(
        raw=httpx.Response(200, content=content),
        client=async_client,
        stream=False,
        stream_cls=None,
        cast_to=str,
        options=FinalRequestOptions.construct(method="get", url="/foo"),
    )

    result = await response.parse(to=bool)
    assert result is expected


class OtherModel(BaseModel):
    """Second model type, used as the other member of a model union target."""

    a: str
@pytest.mark.parametrize("client", [False], indirect=True) # loose validation 248 | def test_response_parse_expect_model_union_non_json_content(client: Atla) -> None: 249 | response = APIResponse( 250 | raw=httpx.Response(200, content=b"foo", headers={"Content-Type": "application/text"}), 251 | client=client, 252 | stream=False, 253 | stream_cls=None, 254 | cast_to=str, 255 | options=FinalRequestOptions.construct(method="get", url="/foo"), 256 | ) 257 | 258 | obj = response.parse(to=cast(Any, Union[CustomModel, OtherModel])) 259 | assert isinstance(obj, str) 260 | assert obj == "foo" 261 | 262 | 263 | @pytest.mark.asyncio 264 | @pytest.mark.parametrize("async_client", [False], indirect=True) # loose validation 265 | async def test_async_response_parse_expect_model_union_non_json_content(async_client: AsyncAtla) -> None: 266 | response = AsyncAPIResponse( 267 | raw=httpx.Response(200, content=b"foo", headers={"Content-Type": "application/text"}), 268 | client=async_client, 269 | stream=False, 270 | stream_cls=None, 271 | cast_to=str, 272 | options=FinalRequestOptions.construct(method="get", url="/foo"), 273 | ) 274 | 275 | obj = await response.parse(to=cast(Any, Union[CustomModel, OtherModel])) 276 | assert isinstance(obj, str) 277 | assert obj == "foo" 278 | -------------------------------------------------------------------------------- /tests/test_streaming.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Iterator, AsyncIterator 4 | 5 | import httpx 6 | import pytest 7 | 8 | from atla import Atla, AsyncAtla 9 | from atla._streaming import Stream, AsyncStream, ServerSentEvent 10 | 11 | 12 | @pytest.mark.asyncio 13 | @pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) 14 | async def test_basic(sync: bool, client: Atla, async_client: AsyncAtla) -> None: 15 | def body() -> Iterator[bytes]: 16 | yield b"event: completion\n" 17 | yield b'data: 
{"foo":true}\n' 18 | yield b"\n" 19 | 20 | iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client) 21 | 22 | sse = await iter_next(iterator) 23 | assert sse.event == "completion" 24 | assert sse.json() == {"foo": True} 25 | 26 | await assert_empty_iter(iterator) 27 | 28 | 29 | @pytest.mark.asyncio 30 | @pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) 31 | async def test_data_missing_event(sync: bool, client: Atla, async_client: AsyncAtla) -> None: 32 | def body() -> Iterator[bytes]: 33 | yield b'data: {"foo":true}\n' 34 | yield b"\n" 35 | 36 | iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client) 37 | 38 | sse = await iter_next(iterator) 39 | assert sse.event is None 40 | assert sse.json() == {"foo": True} 41 | 42 | await assert_empty_iter(iterator) 43 | 44 | 45 | @pytest.mark.asyncio 46 | @pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) 47 | async def test_event_missing_data(sync: bool, client: Atla, async_client: AsyncAtla) -> None: 48 | def body() -> Iterator[bytes]: 49 | yield b"event: ping\n" 50 | yield b"\n" 51 | 52 | iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client) 53 | 54 | sse = await iter_next(iterator) 55 | assert sse.event == "ping" 56 | assert sse.data == "" 57 | 58 | await assert_empty_iter(iterator) 59 | 60 | 61 | @pytest.mark.asyncio 62 | @pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) 63 | async def test_multiple_events(sync: bool, client: Atla, async_client: AsyncAtla) -> None: 64 | def body() -> Iterator[bytes]: 65 | yield b"event: ping\n" 66 | yield b"\n" 67 | yield b"event: completion\n" 68 | yield b"\n" 69 | 70 | iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client) 71 | 72 | sse = await iter_next(iterator) 73 | assert sse.event == "ping" 74 | assert sse.data == "" 75 | 76 | sse = await 
iter_next(iterator) 77 | assert sse.event == "completion" 78 | assert sse.data == "" 79 | 80 | await assert_empty_iter(iterator) 81 | 82 | 83 | @pytest.mark.asyncio 84 | @pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) 85 | async def test_multiple_events_with_data(sync: bool, client: Atla, async_client: AsyncAtla) -> None: 86 | def body() -> Iterator[bytes]: 87 | yield b"event: ping\n" 88 | yield b'data: {"foo":true}\n' 89 | yield b"\n" 90 | yield b"event: completion\n" 91 | yield b'data: {"bar":false}\n' 92 | yield b"\n" 93 | 94 | iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client) 95 | 96 | sse = await iter_next(iterator) 97 | assert sse.event == "ping" 98 | assert sse.json() == {"foo": True} 99 | 100 | sse = await iter_next(iterator) 101 | assert sse.event == "completion" 102 | assert sse.json() == {"bar": False} 103 | 104 | await assert_empty_iter(iterator) 105 | 106 | 107 | @pytest.mark.asyncio 108 | @pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) 109 | async def test_multiple_data_lines_with_empty_line(sync: bool, client: Atla, async_client: AsyncAtla) -> None: 110 | def body() -> Iterator[bytes]: 111 | yield b"event: ping\n" 112 | yield b"data: {\n" 113 | yield b'data: "foo":\n' 114 | yield b"data: \n" 115 | yield b"data:\n" 116 | yield b"data: true}\n" 117 | yield b"\n\n" 118 | 119 | iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client) 120 | 121 | sse = await iter_next(iterator) 122 | assert sse.event == "ping" 123 | assert sse.json() == {"foo": True} 124 | assert sse.data == '{\n"foo":\n\n\ntrue}' 125 | 126 | await assert_empty_iter(iterator) 127 | 128 | 129 | @pytest.mark.asyncio 130 | @pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) 131 | async def test_data_json_escaped_double_new_line(sync: bool, client: Atla, async_client: AsyncAtla) -> None: 132 | def body() -> Iterator[bytes]: 133 | yield 
b"event: ping\n" 134 | yield b'data: {"foo": "my long\\n\\ncontent"}' 135 | yield b"\n\n" 136 | 137 | iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client) 138 | 139 | sse = await iter_next(iterator) 140 | assert sse.event == "ping" 141 | assert sse.json() == {"foo": "my long\n\ncontent"} 142 | 143 | await assert_empty_iter(iterator) 144 | 145 | 146 | @pytest.mark.asyncio 147 | @pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) 148 | async def test_multiple_data_lines(sync: bool, client: Atla, async_client: AsyncAtla) -> None: 149 | def body() -> Iterator[bytes]: 150 | yield b"event: ping\n" 151 | yield b"data: {\n" 152 | yield b'data: "foo":\n' 153 | yield b"data: true}\n" 154 | yield b"\n\n" 155 | 156 | iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client) 157 | 158 | sse = await iter_next(iterator) 159 | assert sse.event == "ping" 160 | assert sse.json() == {"foo": True} 161 | 162 | await assert_empty_iter(iterator) 163 | 164 | 165 | @pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) 166 | async def test_special_new_line_character( 167 | sync: bool, 168 | client: Atla, 169 | async_client: AsyncAtla, 170 | ) -> None: 171 | def body() -> Iterator[bytes]: 172 | yield b'data: {"content":" culpa"}\n' 173 | yield b"\n" 174 | yield b'data: {"content":" \xe2\x80\xa8"}\n' 175 | yield b"\n" 176 | yield b'data: {"content":"foo"}\n' 177 | yield b"\n" 178 | 179 | iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client) 180 | 181 | sse = await iter_next(iterator) 182 | assert sse.event is None 183 | assert sse.json() == {"content": " culpa"} 184 | 185 | sse = await iter_next(iterator) 186 | assert sse.event is None 187 | assert sse.json() == {"content": " 
"} 188 | 189 | sse = await iter_next(iterator) 190 | assert sse.event is None 191 | assert sse.json() == {"content": "foo"} 192 | 193 | await assert_empty_iter(iterator) 194 | 195 | 196 | @pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) 197 | async def test_multi_byte_character_multiple_chunks( 198 | sync: bool, 199 | client: Atla, 200 | async_client: AsyncAtla, 201 | ) -> None: 202 | def body() -> Iterator[bytes]: 203 | yield b'data: {"content":"' 204 | # bytes taken from the string 'известни' and arbitrarily split 205 | # so that some multi-byte characters span multiple chunks 206 | yield b"\xd0" 207 | yield b"\xb8\xd0\xb7\xd0" 208 | yield b"\xb2\xd0\xb5\xd1\x81\xd1\x82\xd0\xbd\xd0\xb8" 209 | yield b'"}\n' 210 | yield b"\n" 211 | 212 | iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client) 213 | 214 | sse = await iter_next(iterator) 215 | assert sse.event is None 216 | assert sse.json() == {"content": "известни"} 217 | 218 | 219 | async def to_aiter(iter: Iterator[bytes]) -> AsyncIterator[bytes]: 220 | for chunk in iter: 221 | yield chunk 222 | 223 | 224 | async def iter_next(iter: Iterator[ServerSentEvent] | AsyncIterator[ServerSentEvent]) -> ServerSentEvent: 225 | if isinstance(iter, AsyncIterator): 226 | return await iter.__anext__() 227 | 228 | return next(iter) 229 | 230 | 231 | async def assert_empty_iter(iter: Iterator[ServerSentEvent] | AsyncIterator[ServerSentEvent]) -> None: 232 | with pytest.raises((StopAsyncIteration, RuntimeError)): 233 | await iter_next(iter) 234 | 235 | 236 | def make_event_iterator( 237 | content: Iterator[bytes], 238 | *, 239 | sync: bool, 240 | client: Atla, 241 | async_client: AsyncAtla, 242 | ) -> Iterator[ServerSentEvent] | AsyncIterator[ServerSentEvent]: 243 | if sync: 244 | return Stream(cast_to=object, client=client, response=httpx.Response(200, content=content))._iter_events() 245 | 246 | return AsyncStream( 247 | cast_to=object, client=async_client, 
response=httpx.Response(200, content=to_aiter(content)) 248 | )._iter_events() 249 | -------------------------------------------------------------------------------- /tests/test_utils/test_proxy.py: -------------------------------------------------------------------------------- 1 | import operator 2 | from typing import Any 3 | from typing_extensions import override 4 | 5 | from atla._utils import LazyProxy 6 | 7 | 8 | class RecursiveLazyProxy(LazyProxy[Any]): 9 | @override 10 | def __load__(self) -> Any: 11 | return self 12 | 13 | def __call__(self, *_args: Any, **_kwds: Any) -> Any: 14 | raise RuntimeError("This should never be called!") 15 | 16 | 17 | def test_recursive_proxy() -> None: 18 | proxy = RecursiveLazyProxy() 19 | assert repr(proxy) == "RecursiveLazyProxy" 20 | assert str(proxy) == "RecursiveLazyProxy" 21 | assert dir(proxy) == [] 22 | assert type(proxy).__name__ == "RecursiveLazyProxy" 23 | assert type(operator.attrgetter("name.foo.bar.baz")(proxy)).__name__ == "RecursiveLazyProxy" 24 | -------------------------------------------------------------------------------- /tests/test_utils/test_typing.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Generic, TypeVar, cast 4 | 5 | from atla._utils import extract_type_var_from_base 6 | 7 | _T = TypeVar("_T") 8 | _T2 = TypeVar("_T2") 9 | _T3 = TypeVar("_T3") 10 | 11 | 12 | class BaseGeneric(Generic[_T]): ... 13 | 14 | 15 | class SubclassGeneric(BaseGeneric[_T]): ... 16 | 17 | 18 | class BaseGenericMultipleTypeArgs(Generic[_T, _T2, _T3]): ... 19 | 20 | 21 | class SubclassGenericMultipleTypeArgs(BaseGenericMultipleTypeArgs[_T, _T2, _T3]): ... 22 | 23 | 24 | class SubclassDifferentOrderGenericMultipleTypeArgs(BaseGenericMultipleTypeArgs[_T2, _T, _T3]): ... 
25 | 26 | 27 | def test_extract_type_var() -> None: 28 | assert ( 29 | extract_type_var_from_base( 30 | BaseGeneric[int], 31 | index=0, 32 | generic_bases=cast("tuple[type, ...]", (BaseGeneric,)), 33 | ) 34 | == int 35 | ) 36 | 37 | 38 | def test_extract_type_var_generic_subclass() -> None: 39 | assert ( 40 | extract_type_var_from_base( 41 | SubclassGeneric[int], 42 | index=0, 43 | generic_bases=cast("tuple[type, ...]", (BaseGeneric,)), 44 | ) 45 | == int 46 | ) 47 | 48 | 49 | def test_extract_type_var_multiple() -> None: 50 | typ = BaseGenericMultipleTypeArgs[int, str, None] 51 | 52 | generic_bases = cast("tuple[type, ...]", (BaseGenericMultipleTypeArgs,)) 53 | assert extract_type_var_from_base(typ, index=0, generic_bases=generic_bases) == int 54 | assert extract_type_var_from_base(typ, index=1, generic_bases=generic_bases) == str 55 | assert extract_type_var_from_base(typ, index=2, generic_bases=generic_bases) == type(None) 56 | 57 | 58 | def test_extract_type_var_generic_subclass_multiple() -> None: 59 | typ = SubclassGenericMultipleTypeArgs[int, str, None] 60 | 61 | generic_bases = cast("tuple[type, ...]", (BaseGenericMultipleTypeArgs,)) 62 | assert extract_type_var_from_base(typ, index=0, generic_bases=generic_bases) == int 63 | assert extract_type_var_from_base(typ, index=1, generic_bases=generic_bases) == str 64 | assert extract_type_var_from_base(typ, index=2, generic_bases=generic_bases) == type(None) 65 | 66 | 67 | def test_extract_type_var_generic_subclass_different_ordering_multiple() -> None: 68 | typ = SubclassDifferentOrderGenericMultipleTypeArgs[int, str, None] 69 | 70 | generic_bases = cast("tuple[type, ...]", (BaseGenericMultipleTypeArgs,)) 71 | assert extract_type_var_from_base(typ, index=0, generic_bases=generic_bases) == int 72 | assert extract_type_var_from_base(typ, index=1, generic_bases=generic_bases) == str 73 | assert extract_type_var_from_base(typ, index=2, generic_bases=generic_bases) == type(None) 74 | 
-------------------------------------------------------------------------------- /tests/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import os 4 | import inspect 5 | import traceback 6 | import contextlib 7 | from typing import Any, TypeVar, Iterator, cast 8 | from datetime import date, datetime 9 | from typing_extensions import Literal, get_args, get_origin, assert_type 10 | 11 | from atla._types import Omit, NoneType 12 | from atla._utils import ( 13 | is_dict, 14 | is_list, 15 | is_list_type, 16 | is_union_type, 17 | extract_type_arg, 18 | is_annotated_type, 19 | is_type_alias_type, 20 | ) 21 | from atla._compat import PYDANTIC_V2, field_outer_type, get_model_fields 22 | from atla._models import BaseModel 23 | 24 | BaseModelT = TypeVar("BaseModelT", bound=BaseModel) 25 | 26 | 27 | def assert_matches_model(model: type[BaseModelT], value: BaseModelT, *, path: list[str]) -> bool: 28 | for name, field in get_model_fields(model).items(): 29 | field_value = getattr(value, name) 30 | if PYDANTIC_V2: 31 | allow_none = False 32 | else: 33 | # in v1 nullability was structured differently 34 | # https://docs.pydantic.dev/2.0/migration/#required-optional-and-nullable-fields 35 | allow_none = getattr(field, "allow_none", False) 36 | 37 | assert_matches_type( 38 | field_outer_type(field), 39 | field_value, 40 | path=[*path, name], 41 | allow_none=allow_none, 42 | ) 43 | 44 | return True 45 | 46 | 47 | # Note: the `path` argument is only used to improve error messages when `--showlocals` is used 48 | def assert_matches_type( 49 | type_: Any, 50 | value: object, 51 | *, 52 | path: list[str], 53 | allow_none: bool = False, 54 | ) -> None: 55 | if is_type_alias_type(type_): 56 | type_ = type_.__value__ 57 | 58 | # unwrap `Annotated[T, ...]` -> `T` 59 | if is_annotated_type(type_): 60 | type_ = extract_type_arg(type_, 0) 61 | 62 | if allow_none and value is None: 63 | return 64 | 65 | if 
type_ is None or type_ is NoneType: 66 | assert value is None 67 | return 68 | 69 | origin = get_origin(type_) or type_ 70 | 71 | if is_list_type(type_): 72 | return _assert_list_type(type_, value) 73 | 74 | if origin == str: 75 | assert isinstance(value, str) 76 | elif origin == int: 77 | assert isinstance(value, int) 78 | elif origin == bool: 79 | assert isinstance(value, bool) 80 | elif origin == float: 81 | assert isinstance(value, float) 82 | elif origin == bytes: 83 | assert isinstance(value, bytes) 84 | elif origin == datetime: 85 | assert isinstance(value, datetime) 86 | elif origin == date: 87 | assert isinstance(value, date) 88 | elif origin == object: 89 | # nothing to do here, the expected type is unknown 90 | pass 91 | elif origin == Literal: 92 | assert value in get_args(type_) 93 | elif origin == dict: 94 | assert is_dict(value) 95 | 96 | args = get_args(type_) 97 | key_type = args[0] 98 | items_type = args[1] 99 | 100 | for key, item in value.items(): 101 | assert_matches_type(key_type, key, path=[*path, ""]) 102 | assert_matches_type(items_type, item, path=[*path, ""]) 103 | elif is_union_type(type_): 104 | variants = get_args(type_) 105 | 106 | try: 107 | none_index = variants.index(type(None)) 108 | except ValueError: 109 | pass 110 | else: 111 | # special case Optional[T] for better error messages 112 | if len(variants) == 2: 113 | if value is None: 114 | # valid 115 | return 116 | 117 | return assert_matches_type(type_=variants[not none_index], value=value, path=path) 118 | 119 | for i, variant in enumerate(variants): 120 | try: 121 | assert_matches_type(variant, value, path=[*path, f"variant {i}"]) 122 | return 123 | except AssertionError: 124 | traceback.print_exc() 125 | continue 126 | 127 | raise AssertionError("Did not match any variants") 128 | elif issubclass(origin, BaseModel): 129 | assert isinstance(value, type_) 130 | assert assert_matches_model(type_, cast(Any, value), path=path) 131 | elif inspect.isclass(origin) and 
origin.__name__ == "HttpxBinaryResponseContent": 132 | assert value.__class__.__name__ == "HttpxBinaryResponseContent" 133 | else: 134 | assert None, f"Unhandled field type: {type_}" 135 | 136 | 137 | def _assert_list_type(type_: type[object], value: object) -> None: 138 | assert is_list(value) 139 | 140 | inner_type = get_args(type_)[0] 141 | for entry in value: 142 | assert_type(inner_type, entry) # type: ignore 143 | 144 | 145 | @contextlib.contextmanager 146 | def update_env(**new_env: str | Omit) -> Iterator[None]: 147 | old = os.environ.copy() 148 | 149 | try: 150 | for name, value in new_env.items(): 151 | if isinstance(value, Omit): 152 | os.environ.pop(name, None) 153 | else: 154 | os.environ[name] = value 155 | 156 | yield None 157 | finally: 158 | os.environ.clear() 159 | os.environ.update(old) 160 | --------------------------------------------------------------------------------