├── .git-hooks └── check_api_key.sh ├── .github ├── CODEOWNERS ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md ├── dependabot.yml ├── labeler.yml ├── renovate.json ├── settings.yml ├── stale.yml └── workflows │ ├── integration-tests.yaml │ ├── labeler.yml │ ├── lint.yaml │ ├── publish.yaml │ ├── release_version.yaml │ ├── semantic-pr.yml │ └── test.yaml ├── .gitignore ├── .gitleaks.toml ├── .markdownlint.yaml ├── .pre-commit-config.yaml ├── .python-version ├── .yamllint.yaml ├── CHANGELOG.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── ai21 ├── __init__.py ├── ai21_env_config.py ├── clients │ ├── __init__.py │ ├── aws │ │ ├── __init__.py │ │ └── aws_authorization.py │ ├── azure │ │ ├── __init__.py │ │ └── ai21_azure_client.py │ ├── bedrock │ │ ├── __init__.py │ │ ├── _stream_decoder.py │ │ ├── ai21_bedrock_client.py │ │ └── bedrock_model_id.py │ ├── common │ │ ├── __init__.py │ │ ├── auth │ │ │ ├── __init__.py │ │ │ └── gcp_authorization.py │ │ ├── chat_base.py │ │ ├── conversational_rag.py │ │ └── maestro │ │ │ ├── __init__.py │ │ │ ├── maestro.py │ │ │ └── run.py │ ├── launchpad │ │ ├── __init__.py │ │ └── ai21_launchpad_client.py │ ├── studio │ │ ├── __init__.py │ │ ├── ai21_client.py │ │ ├── async_ai21_client.py │ │ └── resources │ │ │ ├── __init__.py │ │ │ ├── batch │ │ │ ├── __init__.py │ │ │ ├── async_batches.py │ │ │ ├── base_batches.py │ │ │ └── batches.py │ │ │ ├── beta │ │ │ ├── __init__.py │ │ │ ├── async_beta.py │ │ │ └── beta.py │ │ │ ├── chat │ │ │ ├── __init__.py │ │ │ ├── async_chat_completions.py │ │ │ ├── base_chat_completions.py │ │ │ └── chat_completions.py │ │ │ ├── constants.py │ │ │ ├── maestro │ │ │ ├── __init__.py │ │ │ ├── maestro.py │ │ │ └── run.py │ │ │ ├── studio_chat.py │ │ │ ├── studio_conversational_rag.py │ │ │ ├── studio_library.py │ │ │ └── studio_resource.py │ └── vertex │ │ ├── __init__.py │ │ └── ai21_vertex_client.py ├── constants.py ├── errors.py ├── files │ ├── __init__.py │ └── downloaded_file.py ├── http_client │ ├── __init__.py │ ├── async_http_client.py │ ├── base_http_client.py │ └── http_client.py ├── logger.py ├── models │ ├── __init__.py │ ├── _pydantic_compatibility.py │ ├── ai21_base_model.py │ ├── chat │ │ ├── __init__.py │ │ ├── chat_completion_chunk.py │ │ ├── chat_completion_response.py │ │ ├── chat_message.py │ │ ├── document_schema.py │ │ ├── function_tool_definition.py │ │ ├── response_format.py │ │ ├── role_type.py │ │ ├── tool_call.py │ │ ├── tool_defintions.py │ │ ├── tool_function.py │ │ └── tool_parameters.py │ ├── chat_message.py │ ├── document_type.py │ ├── logprobs.py │ ├── maestro │ │ ├── __init__.py │ │ └── run.py │ ├── penalty.py │ ├── request_options.py │ ├── responses │ │ ├── __init__.py │ │ ├── batch_response.py │ │ ├── chat_response.py │ │ ├── conversational_rag_response.py │ │ └── file_response.py │ ├── retrieval_strategy.py │ └── usage_info.py ├── pagination │ ├── __init__.py │ ├── async_pagination.py │ ├── base_pagination.py │ └── sync_pagination.py ├── py.typed ├── stream │ ├── __init__.py │ ├── async_stream.py │ ├── stream.py │ └── stream_commons.py ├── tokenizers │ ├── __init__.py │ ├── ai21_tokenizer.py │ └── factory.py ├── types.py ├── utils │ ├── __init__.py │ └── typing.py ├── version.py └── version_utils.py ├── bootstrap.sh ├── examples ├── __init__.py ├── azure │ ├── __init__.py │ ├── async_azure_chat_completions.py │ └── azure_chat_completions.py ├── bedrock │ └── chat │ │ ├── async_chat_completions.py │ │ ├── async_stream_chat_completions.py │ │ ├── chat_completions.py │ │ └── 
stream_chat_completions.py ├── launchpad │ ├── __init__.py │ ├── async_chat_completions.py │ └── chat_completions.py ├── studio │ ├── __init__.py │ ├── async_library.py │ ├── async_tokenization.py │ ├── batches │ │ ├── __init__.py │ │ ├── async_batches.py │ │ └── batches.py │ ├── chat │ │ ├── __init__.py │ │ ├── async_chat_completions.py │ │ ├── async_stream_chat_completions.py │ │ ├── chat_completions.py │ │ ├── chat_documents.py │ │ ├── chat_function_calling.py │ │ ├── chat_function_calling_multiple_tools.py │ │ ├── chat_response_format.py │ │ └── stream_chat_completions.py │ ├── conversational_rag │ │ ├── __init__.py │ │ ├── async_conversational_rag.py │ │ └── conversational_rag.py │ ├── file_utils.py │ ├── library.py │ ├── maestro │ │ ├── __init__.py │ │ ├── async_run.py │ │ └── run.py │ ├── sample_notebooks │ │ ├── .ipynb_checkpoints │ │ │ └── AI21_Client_Code_Snippets-checkpoint.ipynb │ │ └── AI21_Client_Code_Snippets.ipynb │ └── tokenization.py └── vertex │ ├── __init__.py │ ├── async_chat_completions.py │ ├── async_stream_chat_completions.py │ ├── chat_completions.py │ └── stream_chat_completions.py ├── init.sh ├── poetry.lock ├── poetry.toml ├── pyproject.toml ├── setup.py ├── tasks.py └── tests ├── __init__.py ├── integration_tests ├── __init__.py ├── clients │ ├── __init__.py │ ├── bedrock │ │ ├── __init__.py │ │ └── test_chat_completions.py │ ├── resources │ │ └── library_file.txt │ ├── studio │ │ ├── __init__.py │ │ ├── conftest.py │ │ ├── test_chat_completions.py │ │ ├── test_library.py │ │ └── test_maestro.py │ ├── test_bedrock.py │ ├── test_studio.py │ └── test_vertex.py └── skip_helpers.py └── unittests ├── __init__.py ├── clients ├── __init__.py ├── azure │ ├── __init__.py │ ├── test_ai21_azure_client.py │ └── test_chat_completions.py ├── bedrock │ ├── __init__.py │ └── test_chat_completions.py ├── studio │ ├── __init__.py │ ├── resources │ │ ├── __init__.py │ │ ├── chat │ │ │ ├── __init__.py │ │ │ └── test_chat_completions.py │ │ ├── conftest.py │ │ ├── test_async_studio_resource.py │ │ ├── test_chat.py │ │ └── test_studio_resources.py │ ├── test_ai21_client.py │ └── test_async_ai21_client.py └── vertex │ ├── __init__.py │ └── test_chat_completions.py ├── commons.py ├── conftest.py ├── models ├── __init__.py ├── response_mocks.py └── test_serialization.py ├── test_ai21_env_config.py ├── test_ai21_http_client.py ├── test_aws_authorization.py ├── test_aws_stream_decoder.py ├── test_gcp_authorization.py ├── test_http_client.py ├── test_imports.py ├── test_stream.py └── tokenizers ├── test_ai21_tokenizer.py └── test_async_ai21_tokenizer.py /.git-hooks/check_api_key.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Check for `api_key=` in staged changes 4 | if git diff --cached | grep -q -E '\bapi_key=[^"]'; then 5 | echo "❌ Commit blocked: Found 'api_key=' in staged changes." 
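# Abort the commit so the suspected key never enters history; supply keys via environment variables (e.g. AI21_API_KEY) instead of hardcoding them.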
6 | exit 1 # Prevent commit 7 | fi 8 | 9 | exit 0 # Allow commit 10 | -------------------------------------------------------------------------------- /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @reach 2 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: bug 6 | assignees: asafgardin, etang-ai21 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Click on '....' 17 | 3. Scroll down to '....' 18 | 4. See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **Desktop (please complete the following information):** 27 | - OS: [e.g. iOS] 28 | - Version [e.g. 22] 29 | 30 | **Additional context** 31 | Add any other context about the problem here. 32 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 
21 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # yaml-language-server: $schema=https://json.schemastore.org/dependabot-2.0.json 2 | 3 | version: 2 4 | updates: 5 | - package-ecosystem: github-actions 6 | directory: / 7 | schedule: 8 | interval: daily 9 | commit-message: 10 | prefix: chore(deps) 11 | -------------------------------------------------------------------------------- /.github/labeler.yml: -------------------------------------------------------------------------------- 1 | # Add 'Documentation' label to any change to .md files within the entire repository 2 | documentation: 3 | - changed-files: 4 | - any-glob-to-any-file: "**/*.md" 5 | 6 | # Add 'feature' label to any PR where the head branch name starts with `feature` or has a `feature` section in the name 7 | feature: 8 | - head-branch: ["^feat", "feat:"] 9 | 10 | # Add 'fix' label to any PR where the head branch name starts with `bugfix` or has a `fix` section in the name 11 | fix: 12 | - head-branch: ["^bugfix", "fix:"] 13 | 14 | ci: 15 | - head-branch: ["^ci", "ci:"] 16 | - changed-files: 17 | - any-glob-to-any-file: 18 | - .github/* 19 | 20 | aws: 21 | - changed-files: 22 | - any-glob-to-any-file: 23 | - ai21/clients/bedrock/* 24 | - ai21/clients/sagemaker/* 25 | 26 | azure: 27 | - changed-files: 28 | - any-glob-to-any-file: 29 | - ai21/clients/azure/* 30 | 31 | vertex: 32 | - changed-files: 33 | - any-glob-to-any-file: 34 | - ai21/clients/vertex/* 35 | -------------------------------------------------------------------------------- /.github/renovate.json: -------------------------------------------------------------------------------- 1 | { 2 | "extends": [ 3 | "config:base" 4 | ], 5 | "dependencyDashboard": true, 6 | "semanticCommits": "enabled", 7 | "labels": [ 8 | "dependencies" 9 | ], 10 | "regexManagers": [ 11 | { 12 | "fileMatch": [ 13 | "(^|/)\\.pre-commit-config\\.yaml$" 14 | ], 15 | "matchStrings": [ 16 | "\\nminimum_pre_commit_version: (?<currentValue>.*?)\\n" 17 | ], 18 | "depNameTemplate": "pre-commit", 19 | "datasourceTemplate": "pypi" 20 | }, 21 | { 22 | "fileMatch": [ 23 | "(^|/)\\.pre-commit-config\\.yaml$" 24 | ], 25 | "matchStrings": [ 26 | "\\n\\s*entry: (?<depName>[^:]+):(?<currentValue>\\S+)" 27 | ], 28 | "datasourceTemplate": "docker" 29 | } 30 | ] 31 | } 32 | -------------------------------------------------------------------------------- /.github/stale.yml: -------------------------------------------------------------------------------- 1 | # Configuration for probot-stale - https://github.com/probot/stale 2 | 3 | # Number of days of inactivity before an Issue or Pull Request becomes stale 4 | daysUntilStale: 3 5 | 6 | # Number of days of inactivity before an Issue or Pull Request with the stale 7 | # label is closed. 8 | # Set to false to disable. If disabled, issues still need to be closed manually, 9 | # but will remain marked as stale. 10 | daysUntilClose: 3 11 | 12 | # Issues or Pull Requests with these labels will never be considered stale. 13 | # Set to `[]` to disable 14 | exemptLabels: 15 | - dependencies 16 | - draft 17 | - WIP 18 | 19 | # Label to use when marking as stale 20 | staleLabel: stale 21 | 22 | # Comment to post when marking as stale. Set to `false` to disable 23 | markComment: > 24 | This PR has been automatically marked as stale because it has not had 25 | recent activity. It will be closed if no further activity occurs. Thank you 26 | for your contributions.
27 | 28 | # Comment to post when removing the stale label. 29 | # unmarkComment: > 30 | # Your comment here. 31 | 32 | # Comment to post when closing a stale Issue or Pull Request. 33 | closeComment: > 34 | This PR has been automatically closed because it has not had recent activity. 35 | You can reopen it by clicking on `Reopen pull request`. 36 | Thank you for your contributions. 37 | 38 | # Limit the number of actions per hour, from 1-30. Default is 30 39 | limitPerRun: 30 40 | 41 | # Limit to only `issues` or `pulls` 42 | only: pulls 43 | -------------------------------------------------------------------------------- /.github/workflows/lint.yaml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | 3 | on: [push] 4 | 5 | env: 6 | POETRY_VERSION: "1.7.1" 7 | POETRY_URL: https://install.python-poetry.org 8 | 9 | jobs: 10 | lint: 11 | runs-on: ubuntu-latest 12 | strategy: 13 | matrix: 14 | python-version: ["3.10"] 15 | 16 | steps: 17 | - name: Checkout 18 | uses: actions/checkout@v3 19 | - name: Install Poetry 20 | run: | 21 | pipx install poetry==1.8 22 | - name: Set up Python 23 | uses: actions/setup-python@v5 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | cache: poetry 27 | cache-dependency-path: poetry.lock 28 | - name: Set Poetry environment 29 | run: | 30 | poetry env use ${{ matrix.python-version }} 31 | - name: Install dependencies 32 | run: | 33 | poetry install --only dev --all-extras 34 | - name: Lint Python (Black) 35 | run: | 36 | poetry run inv formatter 37 | - name: Lint Python (Ruff) 38 | run: | 39 | poetry run inv lint 40 | - name: Lint Python (isort) 41 | run: | 42 | poetry run inv isort 43 | -------------------------------------------------------------------------------- /.github/workflows/publish.yaml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | name: Publish to PYPI 5 | 6 | on: 7 | release: 8 | types: [published] 9 | 10 | permissions: 11 | contents: read 12 | 13 | jobs: 14 | deploy: 15 | runs-on: ubuntu-latest 16 | strategy: 17 | matrix: 18 | python-version: ["3.10"] 19 | 20 | steps: 21 | - uses: actions/checkout@v3 22 | - name: Install Poetry 23 | run: | 24 | pipx install poetry==1.8 25 | - name: Set up Python 26 | uses: actions/setup-python@v5 27 | with: 28 | python-version: ${{ matrix.python-version }} 29 | cache: poetry 30 | cache-dependency-path: poetry.lock 31 | - name: Set Poetry environment 32 | run: | 33 | poetry env use ${{ matrix.python-version }} 34 | - name: Build package 35 | run: poetry build 36 | - name: Publish package to PYPI 37 | uses: pypa/gh-action-pypi-publish@76f52bc884231f62b9a034ebfe128415bbaabdfc 38 | with: 39 | user: __token__ 40 | password: ${{ secrets.PYPI_API_TOKEN }} 41 | -------------------------------------------------------------------------------- /.github/workflows/release_version.yaml: -------------------------------------------------------------------------------- 1 | name: Semantic Release 2 | 3 | on: 4 | workflow_dispatch: 5 | 6 | jobs: 7 | release: 8 | runs-on: ubuntu-latest 9 | concurrency: release 10 | permissions: 11 | id-token: write 12 | contents: write 13 | 14 | steps: 15 | - uses: actions/checkout@v3 16 | with: 17 | fetch-depth: 0 18 | persist-credentials: false 19 | 20 | 
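# python-semantic-release computes the next version from Conventional Commit messages (the convention is enforced on PR titles by semantic-pr.yml below).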
- name: Python Semantic Release 21 | uses: python-semantic-release/python-semantic-release@v9.21.0 22 | with: 23 | github_token: ${{ secrets.GH_PAT_SEM_REL_NISSIM_AI21_PYTHON_SDK }} 24 | -------------------------------------------------------------------------------- /.github/workflows/semantic-pr.yml: -------------------------------------------------------------------------------- 1 | # yaml-language-server: $schema=https://json.schemastore.org/github-workflow.json 2 | 3 | name: Semantic PR 4 | concurrency: 5 | group: Semantic-PR-${{ github.head_ref }} 6 | cancel-in-progress: true 7 | on: 8 | pull_request_target: 9 | types: 10 | - opened 11 | - edited 12 | - reopened 13 | 14 | jobs: 15 | semantic-pr: 16 | runs-on: ubuntu-20.04 17 | timeout-minutes: 1 18 | steps: 19 | - name: Semantic pull-request 20 | uses: amannn/action-semantic-pull-request@v5.5.3 21 | with: 22 | requireScope: false 23 | wip: true 24 | validateSingleCommit: true 25 | env: 26 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 27 | -------------------------------------------------------------------------------- /.github/workflows/test.yaml: -------------------------------------------------------------------------------- 1 | name: Unittest 2 | 3 | on: [push, pull_request] 4 | 5 | env: 6 | POETRY_VERSION: "1.7.1" 7 | POETRY_URL: https://install.python-poetry.org 8 | 9 | jobs: 10 | lint: 11 | runs-on: ubuntu-latest 12 | strategy: 13 | matrix: 14 | python-version: ["3.10"] 15 | 16 | steps: 17 | - name: Checkout 18 | uses: actions/checkout@v3 19 | - name: Install Poetry 20 | run: | 21 | pipx install poetry==1.8 22 | - name: Set up Python 23 | uses: actions/setup-python@v5 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | cache: poetry 27 | cache-dependency-path: poetry.lock 28 | - name: Set Poetry environment and clear cache 29 | run: | 30 | poetry env use ${{ matrix.python-version }} 31 | poetry cache clear --all pypi 32 | - name: Install dependencies 33 | run: | 34 | poetry install --no-root --only dev --all-extras 35 | - name: Lint Python (Black) 36 | run: | 37 | poetry run inv formatter 38 | - name: Lint Python (Ruff) 39 | run: | 40 | poetry run inv lint 41 | - name: Lint Python (isort) 42 | run: | 43 | poetry run inv isort 44 | unittests: 45 | runs-on: ubuntu-latest 46 | strategy: 47 | matrix: 48 | python-version: ["3.8", "3.9", "3.10", "3.11"] 49 | pydantic-version: ["^1.10", "^2.0"] 50 | steps: 51 | - name: Checkout 52 | uses: actions/checkout@v3 53 | - name: Install Poetry 54 | run: | 55 | pipx install poetry==1.8 56 | - name: Set up Python 57 | uses: actions/setup-python@v5 58 | with: 59 | python-version: ${{ matrix.python-version }} 60 | cache: poetry 61 | cache-dependency-path: poetry.lock 62 | - name: Set Poetry environment and clear cache 63 | run: | 64 | poetry env use ${{ matrix.python-version }} 65 | poetry cache clear --all pypi 66 | - name: Override Pydantic version 67 | run: | 68 | if [[ "${{ matrix.pydantic-version }}" == ^1.* ]]; then 69 | # Since there's a mismatch between pydantic 1.x and the python-semantic-release version 70 | # we have in the project, we need to downgrade its version when running tests on pydantic v1 71 | poetry remove python-semantic-release 72 | poetry add python-semantic-release@8.0.0 73 | fi 74 | poetry add pydantic@${{ matrix.pydantic-version }} 75 | poetry lock --no-update 76 | - name: Install dependencies 77 | run: | 78 | poetry install --all-extras 79 | - name: Run Tests 80 | env: 81 | AI21_API_KEY: ${{ secrets.AI21_API_KEY }} 82 | AWS_ACCESS_KEY_ID: ${{ 
secrets.AWS_ACCESS_KEY_ID }} 83 | AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} 84 | run: | 85 | poetry run pytest tests/unittests/ 86 | - name: Upload pytest test results 87 | uses: actions/upload-artifact@v4 88 | with: 89 | name: pytest-results-${{ matrix.python-version }} 90 | path: junit/test-results-${{ matrix.python-version }}.xml 91 | # Use always() to always run this step to publish test results when there are test failures 92 | if: ${{ always() }} 93 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # https://github.com/github/gitignore/blob/master/Global/macOS.gitignore 2 | # General 3 | .DS_Store 4 | .AppleDouble 5 | .LSOverride 6 | 7 | # Icon must end with two \r 8 | Icon 9 | 10 | # Thumbnails 11 | ._* 12 | 13 | # Files that might appear in the root of a volume 14 | .DocumentRevisions-V100 15 | .fseventsd 16 | .Spotlight-V100 17 | .TemporaryItems 18 | .Trashes 19 | .VolumeIcon.icns 20 | .com.apple.timemachine.donotpresent 21 | 22 | # Directories potentially created on remote AFP share 23 | .AppleDB 24 | .AppleDesktop 25 | Network Trash Folder 26 | Temporary Items 27 | .apdisk 28 | 29 | # https://github.com/github/gitignore/blob/master/Global/Archives.gitignore 30 | # It's better to unpack these files and commit the raw source because 31 | # git has its own built in compression methods. 32 | *.7z 33 | *.jar 34 | *.rar 35 | *.zip 36 | *.gz 37 | *.gzip 38 | *.tgz 39 | *.bzip 40 | *.bzip2 41 | *.bz2 42 | *.xz 43 | *.lzma 44 | *.cab 45 | *.xar 46 | 47 | # Packing-only formats 48 | *.iso 49 | *.tar 50 | 51 | # Package management formats 52 | *.dmg 53 | *.xpi 54 | *.gem 55 | *.egg 56 | *.deb 57 | *.rpm 58 | *.msi 59 | *.msm 60 | *.msp 61 | *.txz 62 | 63 | # https://github.com/github/gitignore/blob/master/Global/JetBrains.gitignore 64 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 65 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 66 | 67 | # IntelliJ 68 | .idea/ 69 | out/ 70 | 71 | # Gradle and Maven with auto-import 72 | # When using Gradle or Maven with auto-import, you should exclude module files, 73 | # since they will be recreated, and may cause churn. Uncomment if using 74 | # auto-import. 
75 | # *.iml 76 | # *.ipr 77 | 78 | # CMake 79 | cmake-build-*/ 80 | 81 | # File-based project format 82 | *.iws 83 | 84 | # mpeltonen/sbt-idea plugin 85 | .idea_modules/ 86 | 87 | # JIRA plugin 88 | atlassian-ide-plugin.xml 89 | 90 | # Crashlytics plugin (for Android Studio and IntelliJ) 91 | com_crashlytics_export_strings.xml 92 | crashlytics.properties 93 | crashlytics-build.properties 94 | fabric.properties 95 | 96 | # https://github.com/github/gitignore/blob/master/Global/VisualStudioCode.gitignore 97 | .vscode/* 98 | !.vscode/settings.json 99 | !.vscode/tasks.json 100 | !.vscode/launch.json 101 | !.vscode/extensions.json 102 | *.code-workspace 103 | 104 | # Local History for Visual Studio Code 105 | .history/ 106 | 107 | # https://github.com/github/gitignore/blob/master/Global/Backup.gitignore 108 | *.bak 109 | *.gho 110 | *.ori 111 | *.orig 112 | *.tmp 113 | 114 | # https://github.com/github/gitignore/blob/master/Global/Diff.gitignore 115 | *.patch 116 | *.diff 117 | 118 | # https://github.com/github/gitignore/blob/master/Global/Patch.gitignore 119 | *.orig 120 | *.rej 121 | __pycache__ 122 | 123 | tests/integration_tests/test-file/* 124 | .vscode/* 125 | .env 126 | -------------------------------------------------------------------------------- /.markdownlint.yaml: -------------------------------------------------------------------------------- 1 | # yaml-language-server: $schema=https://json.schemastore.org/markdownlint.json 2 | 3 | # https://github.com/DavidAnson/markdownlint/blob/main/schema/.markdownlint.yaml 4 | # https://github.com/DavidAnson/markdownlint/blob/main/doc/Rules.md 5 | 6 | # Default state for all rules 7 | default: true 8 | 9 | MD013: 10 | line_length: 300 11 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | minimum_pre_commit_version: 2.20.0 2 | fail_fast: false 3 | default_stages: 4 | - commit 5 | exclude: (.idea|vscode) 6 | repos: 7 | - repo: https://github.com/pre-commit/pre-commit-hooks 8 | rev: v4.4.0 9 | hooks: 10 | - id: check-added-large-files 11 | exclude: (ai21_tokenizer/resources|tests/resources) 12 | - id: check-case-conflict 13 | - id: check-executables-have-shebangs 14 | - id: check-shebang-scripts-are-executable 15 | - id: check-merge-conflict 16 | - id: check-symlinks 17 | - id: detect-private-key 18 | - id: no-commit-to-branch 19 | - repo: https://github.com/pre-commit/pre-commit-hooks 20 | rev: v4.4.0 21 | hooks: 22 | - id: trailing-whitespace 23 | - id: end-of-file-fixer 24 | - id: mixed-line-ending 25 | exclude: (CHANGELOG.md) 26 | - repo: https://github.com/jumanjihouse/pre-commit-hooks 27 | rev: 3.0.0 28 | hooks: 29 | - id: forbid-binary 30 | exclude: (ai21_tokenizer/resources|tests/resources) 31 | - id: git-check 32 | files: "CHANGELOG.md" 33 | - repo: https://github.com/commitizen-tools/commitizen 34 | rev: v3.5.3 35 | hooks: 36 | - id: commitizen 37 | name: Lint commit message 38 | stages: 39 | - commit-msg 40 | - repo: https://github.com/python-jsonschema/check-jsonschema 41 | rev: 0.23.3 42 | hooks: 43 | - id: check-jsonschema 44 | name: Validate Pre-commit 45 | files: .*\.pre-commit-config\.yaml 46 | types: 47 | - yaml 48 | args: 49 | - --schemafile 50 | - https://json.schemastore.org/pre-commit-config.json 51 | - id: check-jsonschema 52 | name: Validate YamlLint configuration 53 | files: .*\.yamllint\.yaml 54 | types: 55 | - yaml 56 | args: 57 | - --schemafile 58 | - 
https://json.schemastore.org/yamllint.json 59 | - id: check-jsonschema 60 | name: Validate Prettier configuration 61 | files: .*\.prettierrc\.yaml 62 | types: 63 | - yaml 64 | args: 65 | - --schemafile 66 | - http://json.schemastore.org/prettierrc 67 | - repo: https://github.com/python-poetry/poetry 68 | rev: 1.5.0 69 | hooks: 70 | - id: poetry-check 71 | - repo: https://github.com/adrienverge/yamllint 72 | rev: v1.32.0 73 | hooks: 74 | - id: yamllint 75 | name: Lint YAML files 76 | args: 77 | - --format 78 | - parsable 79 | - --strict 80 | - repo: https://github.com/shellcheck-py/shellcheck-py 81 | rev: v0.9.0.5 82 | hooks: 83 | - id: shellcheck 84 | name: Check sh files (and patch) 85 | entry: bash -eo pipefail -c 'shellcheck $@ -f diff | patch -p 1' -- 86 | - id: shellcheck 87 | name: Check sh files (and print violations) 88 | - repo: https://github.com/pre-commit/mirrors-prettier 89 | rev: v3.0.0 90 | hooks: 91 | - id: prettier 92 | name: Formatter 93 | exclude: (CHANGELOG.md) 94 | additional_dependencies: 95 | - prettier@2.8.8 96 | - "prettier-plugin-sh@0.12.8" 97 | types_or: 98 | - yaml 99 | - markdown 100 | - shell 101 | - repo: https://github.com/psf/black 102 | rev: 23.7.0 103 | hooks: 104 | - id: black 105 | types: 106 | - python 107 | - repo: https://github.com/astral-sh/ruff-pre-commit 108 | rev: v0.0.280 109 | hooks: 110 | - id: ruff 111 | args: 112 | - --fix 113 | - repo: local 114 | hooks: 115 | - id: hadolint 116 | name: Lint Dockerfiles 117 | language: docker_image 118 | entry: hadolint/hadolint:v2.10.0 hadolint 119 | types: 120 | - dockerfile 121 | - repo: local 122 | hooks: 123 | - id: check-api-key 124 | name: Check for API keys 125 | entry: .git-hooks/check_api_key.sh 126 | language: system 127 | stages: [commit] 128 | -------------------------------------------------------------------------------- /.python-version: -------------------------------------------------------------------------------- 1 | 3.10.6 2 | -------------------------------------------------------------------------------- /.yamllint.yaml: -------------------------------------------------------------------------------- 1 | # yaml-language-server: $schema=https://json.schemastore.org/yamllint.json 2 | 3 | extends: default 4 | 5 | rules: 6 | line-length: 7 | max: 300 8 | document-start: disable 9 | truthy: 10 | check-keys: false 11 | -------------------------------------------------------------------------------- /ai21/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from ai21.ai21_env_config import AI21EnvConfig 4 | from ai21.clients.azure.ai21_azure_client import AI21AzureClient, AsyncAI21AzureClient 5 | from ai21.clients.studio.ai21_client import AI21Client 6 | from ai21.clients.studio.async_ai21_client import AsyncAI21Client 7 | from ai21.errors import ( 8 | AI21APIError, 9 | AI21Error, 10 | APITimeoutError, 11 | MissingApiKeyError, 12 | ModelPackageDoesntExistError, 13 | TooManyRequestsError, 14 | ) 15 | from ai21.logger import setup_logger 16 | from ai21.version import VERSION 17 | 18 | 19 | __version__ = VERSION 20 | setup_logger() 21 | 22 | 23 | def _import_bedrock_client(): 24 | from ai21.clients.bedrock.ai21_bedrock_client import AI21BedrockClient 25 | 26 | return AI21BedrockClient 27 | 28 | 29 | def _import_bedrock_model_id(): 30 | from ai21.clients.bedrock.bedrock_model_id import BedrockModelID 31 | 32 | return BedrockModelID 33 | 34 | 35 | def _import_async_bedrock_client(): 36 | from 
ai21.clients.bedrock.ai21_bedrock_client import AsyncAI21BedrockClient 37 | 38 | return AsyncAI21BedrockClient 39 | 40 | 41 | def _import_vertex_client(): 42 | from ai21.clients.vertex.ai21_vertex_client import AI21VertexClient 43 | 44 | return AI21VertexClient 45 | 46 | 47 | def _import_launchpad_client(): 48 | from ai21.clients.launchpad.ai21_launchpad_client import AI21LaunchpadClient 49 | 50 | return AI21LaunchpadClient 51 | 52 | 53 | def _import_async_launchpad_client(): 54 | from ai21.clients.launchpad.ai21_launchpad_client import AsyncAI21LaunchpadClient 55 | 56 | return AsyncAI21LaunchpadClient 57 | 58 | 59 | def _import_async_vertex_client(): 60 | from ai21.clients.vertex.ai21_vertex_client import AsyncAI21VertexClient 61 | 62 | return AsyncAI21VertexClient 63 | 64 | 65 | def __getattr__(name: str) -> Any: 66 | try: 67 | if name == "AI21BedrockClient": 68 | return _import_bedrock_client() 69 | 70 | if name == "BedrockModelID": 71 | return _import_bedrock_model_id() 72 | 73 | if name == "AsyncAI21BedrockClient": 74 | return _import_async_bedrock_client() 75 | 76 | if name == "AI21VertexClient": 77 | return _import_vertex_client() 78 | 79 | if name == "AsyncAI21VertexClient": 80 | return _import_async_vertex_client() 81 | 82 | if name == "AI21LaunchpadClient": 83 | return _import_launchpad_client() 84 | 85 | if name == "AsyncAI21LaunchpadClient": 86 | return _import_async_launchpad_client() 87 | 88 | except ImportError as e: 89 | raise ImportError('Please install "ai21[AWS]" for Bedrock, or "ai21[Vertex]" for Vertex') from e 90 | 91 | 92 | __all__ = [ 93 | "AI21EnvConfig", 94 | "AI21Client", 95 | "AsyncAI21Client", 96 | "AI21APIError", 97 | "APITimeoutError", 98 | "AI21Error", 99 | "MissingApiKeyError", 100 | "ModelPackageDoesntExistError", 101 | "TooManyRequestsError", 102 | "AI21BedrockClient", 103 | "BedrockModelID", 104 | "AI21AzureClient", 105 | "AsyncAI21AzureClient", 106 | "AsyncAI21BedrockClient", 107 | "AI21VertexClient", 108 | "AsyncAI21VertexClient", 109 | "AI21LaunchpadClient", 110 | "AsyncAI21LaunchpadClient", 111 | ] 112 | -------------------------------------------------------------------------------- /ai21/ai21_env_config.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import logging 4 | import os 5 | from dataclasses import dataclass 6 | from typing import Optional 7 | 8 | from ai21.constants import DEFAULT_API_VERSION, STUDIO_HOST 9 | 10 | # Constants for environment variable keys 11 | _ENV_API_KEY = "AI21_API_KEY" 12 | _ENV_API_VERSION = "AI21_API_VERSION" 13 | _ENV_API_HOST = "AI21_API_HOST" 14 | _ENV_TIMEOUT_SEC = "AI21_TIMEOUT_SEC" 15 | _ENV_NUM_RETRIES = "AI21_NUM_RETRIES" 16 | _ENV_AWS_REGION = "AI21_AWS_REGION" 17 | _ENV_LOG_LEVEL = "AI21_LOG_LEVEL" 18 | 19 | _logger = logging.getLogger(__name__) 20 | 21 | 22 | @dataclass 23 | class _AI21EnvConfig: 24 | _api_key: Optional[str] = None 25 | _api_version: str = DEFAULT_API_VERSION 26 | _api_host: str = STUDIO_HOST 27 | _timeout_sec: Optional[int] = None 28 | _num_retries: Optional[int] = None 29 | _aws_region: Optional[str] = None 30 | _log_level: Optional[str] = None 31 | 32 | @classmethod 33 | def from_env(cls) -> _AI21EnvConfig: 34 | return cls( 35 | _api_key=os.getenv(_ENV_API_KEY), 36 | _api_version=os.getenv(_ENV_API_VERSION, DEFAULT_API_VERSION), 37 | _api_host=os.getenv(_ENV_API_HOST, STUDIO_HOST), 38 | _timeout_sec=os.getenv(_ENV_TIMEOUT_SEC), 39 | _num_retries=os.getenv(_ENV_NUM_RETRIES), 40 | 
_aws_region=os.getenv(_ENV_AWS_REGION, "us-east-1"), 41 | _log_level=os.getenv(_ENV_LOG_LEVEL, "info"), 42 | ) 43 | 44 | @property 45 | def api_key(self) -> str: 46 | self._api_key = os.getenv(_ENV_API_KEY, self._api_key) 47 | return self._api_key 48 | 49 | @property 50 | def api_version(self) -> str: 51 | self._api_version = os.getenv(_ENV_API_VERSION, self._api_version) 52 | return self._api_version 53 | 54 | @property 55 | def api_host(self) -> str: 56 | self._api_host = os.getenv(_ENV_API_HOST, self._api_host) 57 | return self._api_host 58 | 59 | @property 60 | def timeout_sec(self) -> Optional[int]: 61 | timeout_str = os.getenv(_ENV_TIMEOUT_SEC) 62 | 63 | if timeout_str is not None: 64 | self._timeout_sec = int(timeout_str) 65 | 66 | return self._timeout_sec 67 | 68 | @property 69 | def num_retries(self) -> Optional[int]: 70 | retries_str = os.getenv(_ENV_NUM_RETRIES) 71 | 72 | if retries_str is not None: 73 | self._num_retries = int(retries_str) 74 | 75 | return self._num_retries 76 | 77 | @property 78 | def aws_region(self) -> Optional[str]: 79 | self._aws_region = os.getenv(_ENV_AWS_REGION, self._aws_region) 80 | return self._aws_region 81 | 82 | @property 83 | def log_level(self) -> Optional[str]: 84 | self._log_level = os.getenv(_ENV_LOG_LEVEL, self._log_level) 85 | return self._log_level 86 | 87 | def log(self, with_secrets: bool = False) -> None: 88 | env_vars = { 89 | _ENV_API_VERSION: self.api_version, 90 | _ENV_API_HOST: self.api_host, 91 | _ENV_TIMEOUT_SEC: self.timeout_sec, 92 | _ENV_NUM_RETRIES: self.num_retries, 93 | _ENV_AWS_REGION: self.aws_region, 94 | _ENV_LOG_LEVEL: self.log_level, 95 | } 96 | 97 | if with_secrets: 98 | env_vars[_ENV_API_KEY] = self.api_key 99 | 100 | _logger.debug(f"AI21 environment configuration: {env_vars}") 101 | 102 | 103 | AI21EnvConfig = _AI21EnvConfig.from_env() 104 | -------------------------------------------------------------------------------- /ai21/clients/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI21Labs/ai21-python/624fa7cfeedcac0b41e4d7dddc48b77ac12e3bc9/ai21/clients/__init__.py -------------------------------------------------------------------------------- /ai21/clients/aws/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI21Labs/ai21-python/624fa7cfeedcac0b41e4d7dddc48b77ac12e3bc9/ai21/clients/aws/__init__.py -------------------------------------------------------------------------------- /ai21/clients/aws/aws_authorization.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Dict 2 | 3 | from botocore.auth import SigV4Auth 4 | from botocore.awsrequest import AWSRequest 5 | import boto3 6 | 7 | 8 | class AWSAuthorization: 9 | def __init__(self, aws_session: boto3.Session): 10 | self._aws_session = aws_session 11 | 12 | def get_auth_headers( 13 | self, 14 | *, 15 | url: str, 16 | service_name: str, 17 | method: str, 18 | data: Optional[str], 19 | ) -> Dict[str, str]: 20 | request = AWSRequest(method=method, url=url, data=data) 21 | credentials = self._aws_session.get_credentials() 22 | 23 | signer = SigV4Auth( 24 | credentials=credentials, service_name=service_name, region_name=self._aws_session.region_name 25 | ) 26 | signer.add_auth(request) 27 | 28 | prepped = request.prepare() 29 | 30 | return {key: value for key, value in dict(prepped.headers).items() if value is not None} 31 | 
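Usage sketch for the signer above (illustrative only, not a repository file; the endpoint URL and the "bedrock" signing service name are assumptions, and default boto3 credentials are presumed to be configured):

import boto3

from ai21.clients.aws.aws_authorization import AWSAuthorization

session = boto3.Session(region_name="us-east-1")
auth = AWSAuthorization(aws_session=session)

# Produce SigV4 headers for a JSON POST body; attach them to the outgoing HTTP request.
headers = auth.get_auth_headers(
    url="https://bedrock-runtime.us-east-1.amazonaws.com/model/ai21.jamba-1-5-mini-v1:0/invoke",  # assumed endpoint
    service_name="bedrock",  # assumed signing service name
    method="POST",
    data='{"messages": []}',
)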
-------------------------------------------------------------------------------- /ai21/clients/azure/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI21Labs/ai21-python/624fa7cfeedcac0b41e4d7dddc48b77ac12e3bc9/ai21/clients/azure/__init__.py -------------------------------------------------------------------------------- /ai21/clients/bedrock/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI21Labs/ai21-python/624fa7cfeedcac0b41e4d7dddc48b77ac12e3bc9/ai21/clients/bedrock/__init__.py -------------------------------------------------------------------------------- /ai21/clients/bedrock/_stream_decoder.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import json 4 | from functools import lru_cache 5 | from typing import Iterator, AsyncIterator 6 | 7 | import httpx 8 | from botocore.eventstream import EventStreamMessage, EventStreamBuffer 9 | from botocore.model import Shape 10 | from botocore.parsers import EventStreamJSONParser 11 | 12 | from ai21.errors import StreamingDecodeError 13 | from ai21.stream.stream_commons import _SSEDecoderBase 14 | 15 | 16 | _FINISH_REASON_NULL_STR = '"finish_reason":null' 17 | 18 | 19 | @lru_cache(maxsize=None) 20 | def get_response_stream_shape() -> Shape: 21 | from botocore.model import ServiceModel 22 | from botocore.loaders import Loader 23 | 24 | loader = Loader() 25 | bedrock_service_dict = loader.load_service_model("bedrock-runtime", "service-2") 26 | bedrock_service_model = ServiceModel(bedrock_service_dict) 27 | return bedrock_service_model.shape_for("ResponseStream") 28 | 29 | 30 | class _AWSEventStreamDecoder(_SSEDecoderBase): 31 | def __init__(self) -> None: 32 | self._parser = EventStreamJSONParser() 33 | 34 | def iter(self, response: httpx.Response) -> Iterator[str]: 35 | event_stream_buffer = EventStreamBuffer() 36 | previous_item = None 37 | for chunk in response.iter_bytes(): 38 | try: 39 | item = next(self._process_chunks(event_stream_buffer, chunk)) 40 | except StopIteration as e: 41 | raise StreamingDecodeError(chunk=str(chunk), error_message=str(e)) 42 | # For Bedrock metering chunk: 43 | if previous_item is not None: 44 | item = self._build_last_chunk(last_model_chunk=previous_item, bedrock_metrics_chunk=item) 45 | if _FINISH_REASON_NULL_STR not in item and previous_item is None: 46 | previous_item = item 47 | continue 48 | yield item 49 | 50 | async def aiter(self, response: httpx.Response) -> AsyncIterator[str]: 51 | event_stream_buffer = EventStreamBuffer() 52 | previous_item = None 53 | async for chunk in response.aiter_bytes(): 54 | try: 55 | item = next(self._process_chunks(event_stream_buffer, chunk)) 56 | except StopIteration as e: 57 | raise StreamingDecodeError(chunk=str(chunk), error_message=str(e)) 58 | # For Bedrock metering chunk: 59 | if previous_item is not None: 60 | item = self._build_last_chunk(last_model_chunk=previous_item, bedrock_metrics_chunk=item) 61 | if _FINISH_REASON_NULL_STR not in item and previous_item is None: 62 | previous_item = item 63 | continue 64 | yield item 65 | 66 | def _parse_message_from_event(self, event: EventStreamMessage) -> str | None: 67 | response_dict = event.to_response_dict() 68 | parsed_response = self._parser.parse(response_dict, get_response_stream_shape()) 69 | if response_dict["status_code"] != 200: 70 | raise ValueError(f"Bad response code, 
expected 200: {response_dict}") 71 | 72 | chunk = parsed_response.get("chunk") 73 | if not chunk: 74 | return None 75 | 76 | return chunk.get("bytes").decode() # type: ignore[no-any-return] 77 | 78 | def _build_last_chunk(self, last_model_chunk: str, bedrock_metrics_chunk: str) -> str: 79 | chunk_dict = json.loads(last_model_chunk) 80 | bedrock_metrics_dict = json.loads(bedrock_metrics_chunk) 81 | chunk_dict = {**chunk_dict, **bedrock_metrics_dict} 82 | return json.dumps(chunk_dict) 83 | 84 | def _process_chunks(self, event_stream_buffer, chunk) -> Iterator[str]: 85 | try: 86 | event_stream_buffer.add_data(chunk) 87 | for event in event_stream_buffer: 88 | message = self._parse_message_from_event(event) 89 | if message: 90 | yield message 91 | except Exception as e: 92 | raise StreamingDecodeError(chunk=str(chunk), error_message=str(e)) 93 | -------------------------------------------------------------------------------- /ai21/clients/bedrock/bedrock_model_id.py: -------------------------------------------------------------------------------- 1 | class BedrockModelID: 2 | J2_MID_V1 = "ai21.j2-mid-v1" 3 | J2_ULTRA_V1 = "ai21.j2-ultra-v1" 4 | JAMBA_INSTRUCT_V1 = "ai21.jamba-instruct-v1:0" 5 | JAMBA_1_5_MINI = "ai21.jamba-1-5-mini-v1:0" 6 | JAMBA_1_5_LARGE = "ai21.jamba-1-5-large-v1:0" 7 | -------------------------------------------------------------------------------- /ai21/clients/common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI21Labs/ai21-python/624fa7cfeedcac0b41e4d7dddc48b77ac12e3bc9/ai21/clients/common/__init__.py -------------------------------------------------------------------------------- /ai21/clients/common/auth/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI21Labs/ai21-python/624fa7cfeedcac0b41e4d7dddc48b77ac12e3bc9/ai21/clients/common/auth/__init__.py -------------------------------------------------------------------------------- /ai21/clients/common/auth/gcp_authorization.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Optional, Tuple 4 | 5 | import google.auth 6 | from google.auth.credentials import Credentials 7 | from google.auth.transport.requests import Request 8 | from google.auth.exceptions import DefaultCredentialsError 9 | 10 | from ai21.errors import CredentialsError 11 | 12 | 13 | class GCPAuthorization: 14 | def get_gcp_credentials( 15 | self, 16 | project_id: Optional[str] = None, 17 | ) -> Tuple[Credentials, str]: 18 | try: 19 | credentials, loaded_project_id = google.auth.default( 20 | scopes=["https://www.googleapis.com/auth/cloud-platform"], 21 | ) 22 | except DefaultCredentialsError as e: 23 | raise CredentialsError(provider_name="GCP", error_message=str(e)) 24 | 25 | if project_id is not None and project_id != loaded_project_id: 26 | raise ValueError("Mismatch between credentials project id and 'project_id'") 27 | 28 | project_id = project_id or loaded_project_id 29 | 30 | if project_id is None: 31 | raise ValueError("Could not get project_id for GCP project") 32 | 33 | if not isinstance(project_id, str): 34 | raise ValueError(f"Variable project_id must be a string, got {type(project_id)} instead") 35 | 36 | return credentials, project_id 37 | 38 | def _get_gcp_request(self) -> Request: 39 | return Request() 40 | 41 | def refresh_auth(self, credentials: Credentials) -> None: 42 | request 
= self._get_gcp_request() 43 | credentials.refresh(request) 44 | -------------------------------------------------------------------------------- /ai21/clients/common/conversational_rag.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from abc import ABC, abstractmethod 3 | from typing import Any, Dict, List 4 | 5 | from ai21.models.chat import ChatMessage 6 | from ai21.models._pydantic_compatibility import _to_dict 7 | from ai21.models.responses.conversational_rag_response import ConversationalRagResponse 8 | from ai21.models.retrieval_strategy import RetrievalStrategy 9 | from ai21.types import NotGiven, NOT_GIVEN 10 | from ai21.utils.typing import remove_not_given 11 | 12 | 13 | class ConversationalRag(ABC): 14 | _module_name = "conversational-rag" 15 | 16 | @abstractmethod 17 | def create( 18 | self, 19 | messages: List[ChatMessage], 20 | *, 21 | path: str | NotGiven = NOT_GIVEN, 22 | labels: List[str] | NotGiven = NOT_GIVEN, 23 | file_ids: List[str] | NotGiven = NOT_GIVEN, 24 | max_segments: int | NotGiven = NOT_GIVEN, 25 | retrieval_strategy: RetrievalStrategy | str | NotGiven = NOT_GIVEN, 26 | retrieval_similarity_threshold: float | NotGiven = NOT_GIVEN, 27 | max_neighbors: int | NotGiven = NOT_GIVEN, 28 | hybrid_search_alpha: float | NotGiven = NOT_GIVEN, 29 | **kwargs, 30 | ) -> ConversationalRagResponse: 31 | """ 32 | :param messages: List of ChatMessage objects. 33 | :param path: Search only files in the specified path or a child path. 34 | :param labels: Search only files with one of these labels. 35 | :param file_ids: List of file IDs to filter the sources by. 36 | :param max_segments: Maximum number of segments to retrieve. 37 | :param retrieval_strategy: The retrieval strategy to use. 38 | :param retrieval_similarity_threshold: The similarity threshold to use for retrieval. 39 | :param max_neighbors: Maximum number of neighbors to retrieve. 40 | :param hybrid_search_alpha: The alpha value to use for hybrid search. 41 | :param kwargs: Additional keyword arguments. 42 | :return: The response object. 
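Example (illustrative sketch; assumes a client exposing this resource at client.beta.conversational_rag, files already uploaded to the library, and ChatMessage(role=..., content=...) as in ai21.models.chat):
    response = client.beta.conversational_rag.create(
        messages=[ChatMessage(role="user", content="What is our refund policy?")],
        max_segments=5,
    )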
43 | """ 44 | pass 45 | 46 | def _create_body( 47 | self, 48 | messages: List[ChatMessage], 49 | *, 50 | path: str | NotGiven, 51 | labels: List[str] | NotGiven, 52 | file_ids: List[str] | NotGiven, 53 | max_segments: int | NotGiven, 54 | retrieval_strategy: RetrievalStrategy | str | NotGiven, 55 | retrieval_similarity_threshold: float | NotGiven, 56 | max_neighbors: int | NotGiven, 57 | hybrid_search_alpha: float | NotGiven, 58 | **kwargs, 59 | ) -> Dict[str, Any]: 60 | return remove_not_given( 61 | { 62 | "messages": [_to_dict(message) for message in messages], 63 | "path": path, 64 | "labels": labels, 65 | "file_ids": file_ids, 66 | "max_segments": max_segments, 67 | "retrieval_strategy": retrieval_strategy, 68 | "retrieval_similarity_threshold": retrieval_similarity_threshold, 69 | "max_neighbors": max_neighbors, 70 | "hybrid_search_alpha": hybrid_search_alpha, 71 | **kwargs, 72 | } 73 | ) 74 | -------------------------------------------------------------------------------- /ai21/clients/common/maestro/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI21Labs/ai21-python/624fa7cfeedcac0b41e4d7dddc48b77ac12e3bc9/ai21/clients/common/maestro/__init__.py -------------------------------------------------------------------------------- /ai21/clients/common/maestro/maestro.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | 3 | from ai21.clients.common.maestro.run import BaseMaestroRun 4 | 5 | 6 | class BaseMaestro(ABC): 7 | _module_name = "maestro" 8 | runs: BaseMaestroRun 9 | -------------------------------------------------------------------------------- /ai21/clients/common/maestro/run.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from abc import ABC, abstractmethod 4 | from typing import List 5 | 6 | from ai21.models.chat import ChatMessage 7 | from ai21.models.maestro.run import ( 8 | Tool, 9 | ToolResources, 10 | RunResponse, 11 | DEFAULT_RUN_POLL_INTERVAL, 12 | DEFAULT_RUN_POLL_TIMEOUT, 13 | Requirement, 14 | Budget, 15 | OutputOptions, 16 | ) 17 | from ai21.types import NOT_GIVEN, NotGiven 18 | from ai21.utils.typing import remove_not_given 19 | 20 | 21 | class BaseMaestroRun(ABC): 22 | _module_name = "maestro/runs" 23 | 24 | def _create_body( 25 | self, 26 | *, 27 | input: str | List[ChatMessage], 28 | models: List[str] | NotGiven, 29 | tools: List[Tool] | NotGiven, 30 | tool_resources: ToolResources | NotGiven, 31 | requirements: List[Requirement] | NotGiven, 32 | budget: Budget | NotGiven, 33 | include: List[OutputOptions] | NotGiven, 34 | **kwargs, 35 | ) -> dict: 36 | return remove_not_given( 37 | { 38 | "input": input, 39 | "models": models, 40 | "tools": tools, 41 | "tool_resources": tool_resources, 42 | "requirements": requirements, 43 | "budget": budget, 44 | "include": include, 45 | **kwargs, 46 | } 47 | ) 48 | 49 | @abstractmethod 50 | def create( 51 | self, 52 | *, 53 | input: str | List[ChatMessage], 54 | models: List[str] | NotGiven = NOT_GIVEN, 55 | tools: List[Tool] | NotGiven = NOT_GIVEN, 56 | tool_resources: ToolResources | NotGiven = NOT_GIVEN, 57 | requirements: List[Requirement] | NotGiven = NOT_GIVEN, 58 | budget: Budget | NotGiven = NOT_GIVEN, 59 | include: List[OutputOptions] | NotGiven = NOT_GIVEN, 60 | **kwargs, 61 | ) -> RunResponse: 62 | pass 63 | 64 | @abstractmethod 65 | def retrieve(self, run_id: str) -> RunResponse: 66 | pass 67 | 
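# Contract sketch (an assumption read off the signatures below, not documented behavior):
# implementations are expected to call retrieve() every poll_interval seconds until the
# run reaches a terminal status or poll_timeout elapses, returning the final RunResponse.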
68 | @abstractmethod 69 | def _poll_for_status(self, *, run_id: str, poll_interval: float, poll_timeout: float) -> RunResponse: 70 | pass 71 | 72 | @abstractmethod 73 | def create_and_poll( 74 | self, 75 | *, 76 | input: str | List[ChatMessage], 77 | models: List[str] | NotGiven = NOT_GIVEN, 78 | tools: List[Tool] | NotGiven = NOT_GIVEN, 79 | tool_resources: ToolResources | NotGiven = NOT_GIVEN, 80 | requirements: List[Requirement] | NotGiven = NOT_GIVEN, 81 | budget: Budget | NotGiven = NOT_GIVEN, 82 | include: List[OutputOptions] | NotGiven = NOT_GIVEN, 83 | poll_interval_sec: float = DEFAULT_RUN_POLL_INTERVAL, 84 | poll_timeout_sec: float = DEFAULT_RUN_POLL_TIMEOUT, 85 | **kwargs, 86 | ) -> RunResponse: 87 | pass 88 | -------------------------------------------------------------------------------- /ai21/clients/launchpad/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI21Labs/ai21-python/624fa7cfeedcac0b41e4d7dddc48b77ac12e3bc9/ai21/clients/launchpad/__init__.py -------------------------------------------------------------------------------- /ai21/clients/studio/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI21Labs/ai21-python/624fa7cfeedcac0b41e4d7dddc48b77ac12e3bc9/ai21/clients/studio/__init__.py -------------------------------------------------------------------------------- /ai21/clients/studio/ai21_client.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional 2 | 3 | import httpx 4 | 5 | from ai21.ai21_env_config import AI21EnvConfig, _AI21EnvConfig 6 | from ai21.clients.studio.resources.beta.beta import Beta 7 | from ai21.clients.studio.resources.studio_chat import StudioChat 8 | from ai21.clients.studio.resources.studio_library import StudioLibrary 9 | from ai21.http_client.http_client import AI21HTTPClient 10 | from ai21.tokenizers.ai21_tokenizer import AI21Tokenizer 11 | 12 | 13 | class AI21Client(AI21HTTPClient): 14 | """ 15 | Synchronous client that sends requests to the AI21 REST API over HTTP. 16 | """ 17 | 18 | _tokenizer: Optional[AI21Tokenizer] 19 | 20 | def __init__( 21 | self, 22 | api_key: Optional[str] = None, 23 | api_host: Optional[str] = None, 24 | headers: Optional[Dict[str, Any]] = None, 25 | timeout_sec: Optional[float] = None, 26 | num_retries: Optional[int] = None, 27 | via: Optional[str] = None, 28 | http_client: Optional[httpx.Client] = None, 29 | env_config: _AI21EnvConfig = AI21EnvConfig, 30 | **kwargs, 31 | ): 32 | base_url = api_host or env_config.api_host 33 | super().__init__( 34 | api_key=api_key or env_config.api_key, 35 | base_url=base_url, 36 | headers=headers, 37 | timeout_sec=timeout_sec or env_config.timeout_sec, 38 | num_retries=num_retries or env_config.num_retries, 39 | via=via, 40 | client=http_client, 41 | ) 42 | self.chat: StudioChat = StudioChat(self) 43 | self.library = StudioLibrary(self) 44 | self.beta = Beta(self) 45 | -------------------------------------------------------------------------------- /ai21/clients/studio/async_ai21_client.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional 2 | 3 | import httpx 4 | 5 | from ai21.ai21_env_config import AI21EnvConfig, _AI21EnvConfig 6 | from ai21.clients.studio.resources.beta.async_beta import AsyncBeta 7 | from ai21.clients.studio.resources.studio_chat import AsyncStudioChat 8 | 
from ai21.clients.studio.resources.studio_library import AsyncStudioLibrary 9 | from ai21.http_client.async_http_client import AsyncAI21HTTPClient 10 | 11 | 12 | class AsyncAI21Client(AsyncAI21HTTPClient): 13 | """ 14 | Asynchronous client that sends requests to the AI21 REST API over HTTP. 15 | """ 16 | 17 | def __init__( 18 | self, 19 | api_key: Optional[str] = None, 20 | api_host: Optional[str] = None, 21 | headers: Optional[Dict[str, Any]] = None, 22 | timeout_sec: Optional[float] = None, 23 | num_retries: Optional[int] = None, 24 | via: Optional[str] = None, 25 | http_client: Optional[httpx.AsyncClient] = None, 26 | env_config: _AI21EnvConfig = AI21EnvConfig, 27 | **kwargs, 28 | ): 29 | base_url = api_host or env_config.api_host 30 | 31 | super().__init__( 32 | api_key=api_key or env_config.api_key, 33 | base_url=base_url, 34 | headers=headers, 35 | timeout_sec=timeout_sec or env_config.timeout_sec, 36 | num_retries=num_retries or env_config.num_retries, 37 | via=via, 38 | client=http_client, 39 | ) 40 | 41 | self.chat: AsyncStudioChat = AsyncStudioChat(self) 42 | self.library = AsyncStudioLibrary(self) 43 | self.beta = AsyncBeta(self) 44 | -------------------------------------------------------------------------------- /ai21/clients/studio/resources/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI21Labs/ai21-python/624fa7cfeedcac0b41e4d7dddc48b77ac12e3bc9/ai21/clients/studio/resources/__init__.py -------------------------------------------------------------------------------- /ai21/clients/studio/resources/batch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI21Labs/ai21-python/624fa7cfeedcac0b41e4d7dddc48b77ac12e3bc9/ai21/clients/studio/resources/batch/__init__.py -------------------------------------------------------------------------------- /ai21/clients/studio/resources/batch/async_batches.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from os import PathLike 4 | from typing import Any, Dict, Literal 5 | 6 | from ai21.clients.studio.resources.batch.base_batches import BaseBatches 7 | from ai21.clients.studio.resources.studio_resource import AsyncStudioResource 8 | from ai21.files.downloaded_file import DownloadedFile 9 | from ai21.models.responses.batch_response import Batch 10 | from ai21.pagination.async_pagination import AsyncPagination 11 | from ai21.types import NOT_GIVEN, NotGiven 12 | 13 | 14 | class AsyncBatches(AsyncStudioResource, BaseBatches): 15 | async def create( 16 | self, 17 | file: PathLike, 18 | endpoint: Literal["/v1/chat/completions"], 19 | metadata: Dict[str, str] | NotGiven = NOT_GIVEN, 20 | **kwargs: Any, 21 | ) -> Batch: 22 | files = {"file": open(file, "rb")} 23 | body = self._create_body(endpoint=endpoint, metadata=metadata, **kwargs) 24 | 25 | return await self._post(path=f"/{self._module_name}", files=files, body=body, response_cls=dict) 26 | 27 | async def retrieve(self, batch_id: str) -> Batch: 28 | return await self._get(path=f"/{self._module_name}/{batch_id}", response_cls=dict) 29 | 30 | async def list( 31 | self, 32 | after: str | NotGiven = NOT_GIVEN, 33 | limit: int | NotGiven = NOT_GIVEN, 34 | **kwargs: Any, 35 | ) -> AsyncPagination[Batch]: 36 | """List your organization's batches. 37 | 38 | Args: 39 | after: A cursor for pagination. 
Provide a batch ID to fetch results 40 | starting after this batch. Useful for navigating through large 41 | result sets. 42 | 43 | limit: Maximum number of batches to return per page. Value must be 44 | between 1 and 100. Defaults to 20 if not specified. 45 | 46 | Returns: 47 | A paginator object that yields pages of batch results when iterated. 48 | """ 49 | params = self._create_list_params(after=after, limit=limit) 50 | return await self._list( 51 | path=f"/{self._module_name}", 52 | params=params, 53 | pagination_cls=AsyncPagination[Batch], 54 | response_cls=Batch, 55 | **kwargs, 56 | ) 57 | 58 | async def cancel(self, batch_id: str) -> Batch: 59 | return await self._post(path=f"/{self._module_name}/{batch_id}/cancel", response_cls=dict) 60 | 61 | async def get_results( 62 | self, 63 | batch_id: str, 64 | file_type: Literal["output", "error"] | NotGiven = NOT_GIVEN, 65 | force: bool | NotGiven = NOT_GIVEN, 66 | **kwargs: Any, 67 | ) -> DownloadedFile: 68 | return await self._get( 69 | path=f"/{self._module_name}/{batch_id}/results", 70 | params={"file_type": file_type, "force": force}, 71 | response_cls=DownloadedFile, 72 | **kwargs, 73 | ) 74 | -------------------------------------------------------------------------------- /ai21/clients/studio/resources/batch/base_batches.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from abc import ABC 4 | from typing import Any, Dict 5 | 6 | from ai21.types import NOT_GIVEN, NotGiven 7 | from ai21.utils.typing import remove_not_given 8 | 9 | 10 | class BaseBatches(ABC): 11 | _module_name = "batches" 12 | 13 | def _create_body( 14 | self, 15 | endpoint: str, 16 | metadata: Dict[str, str] | NotGiven = NOT_GIVEN, 17 | **kwargs: Any, 18 | ) -> Dict[str, Any]: 19 | return remove_not_given( 20 | { 21 | "endpoint": endpoint, 22 | "metadata": metadata, 23 | **kwargs, 24 | } 25 | ) 26 | 27 | def _create_list_params( 28 | self, 29 | after: str | NotGiven, 30 | limit: int | NotGiven, 31 | ) -> Dict[str, Any]: 32 | return remove_not_given({"after": after, "limit": limit}) 33 | -------------------------------------------------------------------------------- /ai21/clients/studio/resources/batch/batches.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from os import PathLike 4 | from typing import Any, Dict, Literal 5 | 6 | from ai21.clients.studio.resources.batch.base_batches import BaseBatches 7 | from ai21.clients.studio.resources.studio_resource import StudioResource 8 | from ai21.files.downloaded_file import DownloadedFile 9 | from ai21.models.responses.batch_response import Batch 10 | from ai21.pagination.sync_pagination import SyncPagination 11 | from ai21.types import NOT_GIVEN, NotGiven 12 | 13 | 14 | class Batches(StudioResource, BaseBatches): 15 | def create( 16 | self, 17 | file: str | PathLike[str], 18 | endpoint: Literal["/v1/chat/completions"], 19 | metadata: Dict[str, str] | NotGiven = NOT_GIVEN, 20 | **kwargs: Any, 21 | ) -> Batch: 22 | files = {"file": open(file, "rb")} 23 | body = self._create_body(endpoint=endpoint, metadata=metadata, **kwargs) 24 | 25 | return self._post(path=f"/{self._module_name}", files=files, body=body, response_cls=Batch) 26 | 27 | def retrieve(self, batch_id: str) -> Batch: 28 | return self._get(path=f"/{self._module_name}/{batch_id}", response_cls=Batch) 29 | 30 | def list( 31 | self, 32 | after: str | NotGiven = NOT_GIVEN, 33 | limit: int | NotGiven = 
NOT_GIVEN, 34 | **kwargs: Any, 35 | ) -> SyncPagination[Batch]: 36 | """List your organization's batches. 37 | 38 | Args: 39 | after: A cursor for pagination. Provide a batch ID to fetch results 40 | starting after this batch. Useful for navigating through large 41 | result sets. 42 | 43 | limit: Maximum number of batches to return per page. Value must be 44 | between 1 and 100. Defaults to 20 if not specified. 45 | 46 | Returns: 47 | A paginator object that yields pages of batch results when iterated. 48 | """ 49 | params = self._create_list_params(after=after, limit=limit) 50 | return self._list( 51 | path=f"/{self._module_name}", 52 | params=params, 53 | pagination_cls=SyncPagination[Batch], 54 | response_cls=Batch, 55 | **kwargs, 56 | ) 57 | 58 | def cancel(self, batch_id: str) -> Batch: 59 | return self._post(path=f"/{self._module_name}/{batch_id}/cancel", response_cls=Batch) 60 | 61 | def get_results( 62 | self, 63 | batch_id: str, 64 | file_type: Literal["output", "error"] | NotGiven = NOT_GIVEN, 65 | force: bool | NotGiven = NOT_GIVEN, 66 | **kwargs: Any, 67 | ) -> DownloadedFile: 68 | return self._get( 69 | path=f"/{self._module_name}/{batch_id}/results", 70 | params={"file_type": file_type, "force": force}, 71 | response_cls=DownloadedFile, 72 | **kwargs, 73 | ) 74 | -------------------------------------------------------------------------------- /ai21/clients/studio/resources/beta/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI21Labs/ai21-python/624fa7cfeedcac0b41e4d7dddc48b77ac12e3bc9/ai21/clients/studio/resources/beta/__init__.py -------------------------------------------------------------------------------- /ai21/clients/studio/resources/beta/async_beta.py: -------------------------------------------------------------------------------- 1 | from ai21.clients.studio.resources.batch.async_batches import AsyncBatches 2 | from ai21.clients.studio.resources.maestro.maestro import AsyncMaestro 3 | from ai21.clients.studio.resources.studio_conversational_rag import ( 4 | AsyncStudioConversationalRag, 5 | ) 6 | from ai21.clients.studio.resources.studio_resource import AsyncStudioResource 7 | from ai21.http_client.async_http_client import AsyncAI21HTTPClient 8 | 9 | 10 | class AsyncBeta(AsyncStudioResource): 11 | def __init__(self, client: AsyncAI21HTTPClient): 12 | super().__init__(client) 13 | 14 | self.conversational_rag = AsyncStudioConversationalRag(client) 15 | self.maestro = AsyncMaestro(client) 16 | self.batches = AsyncBatches(client) 17 | -------------------------------------------------------------------------------- /ai21/clients/studio/resources/beta/beta.py: -------------------------------------------------------------------------------- 1 | from ai21.clients.studio.resources.batch.batches import Batches 2 | from ai21.clients.studio.resources.maestro.maestro import Maestro 3 | from ai21.clients.studio.resources.studio_conversational_rag import ( 4 | StudioConversationalRag, 5 | ) 6 | from ai21.clients.studio.resources.studio_resource import StudioResource 7 | from ai21.http_client.http_client import AI21HTTPClient 8 | 9 | 10 | class Beta(StudioResource): 11 | def __init__(self, client: AI21HTTPClient): 12 | super().__init__(client) 13 | 14 | self.conversational_rag = StudioConversationalRag(client) 15 | self.maestro = Maestro(client) 16 | self.batches = Batches(client) 17 | --------------------------------------------------------------------------------
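A minimal usage sketch of the batch resources above (not part of the package source). It assumes the studio client exposes these resources as `client.beta.batches`, as beta.py wires them up, that AI21_API_KEY is set in the environment, and that "requests.jsonl" is a hypothetical input file of request payloads:

from ai21 import AI21Client

client = AI21Client()  # picks up AI21_API_KEY from the environment

# Submit a batch built from a JSONL file of request payloads.
batch = client.beta.batches.create(
    file="requests.jsonl",
    endpoint="/v1/chat/completions",
)
print(batch.id, batch.status)

# SyncPagination yields pages (lists of Batch); after each page, the last
# item's id becomes the `after` cursor until an empty page ends the loop.
for page in client.beta.batches.list(limit=20):
    for b in page:
        print(b.id, b.status)

# Download the output file once the batch has completed.
results = client.beta.batches.get_results(batch_id=batch.id, file_type="output")
results.write_to_file("batch_output.jsonl")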
/ai21/clients/studio/resources/chat/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from .chat_completions import ChatCompletions as ChatCompletions 4 | from .async_chat_completions import AsyncChatCompletions as AsyncChatCompletions 5 | -------------------------------------------------------------------------------- /ai21/clients/studio/resources/chat/base_chat_completions.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import warnings 4 | 5 | from abc import ABC 6 | from typing import Any, Dict, List, Literal, Optional, Union 7 | 8 | from ai21.models._pydantic_compatibility import _to_dict 9 | from ai21.models.chat.chat_message import ChatMessageParam 10 | from ai21.models.chat.document_schema import DocumentSchema 11 | from ai21.models.chat.response_format import ResponseFormat 12 | from ai21.models.chat.tool_defintions import ToolDefinition 13 | from ai21.types import NotGiven 14 | from ai21.utils.typing import remove_not_given 15 | 16 | 17 | _MODEL_DEPRECATION_WARNING = """ 18 | The 'jamba-1.5-mini' and 'jamba-1.5-large' models are deprecated and will 19 | be removed in a future version. 20 | Please use jamba-mini-1.6-2025-03 or jamba-large-1.6-2025-03 instead. 21 | """ 22 | 23 | 24 | class BaseChatCompletions(ABC): 25 | _module_name = "chat/completions" 26 | 27 | def _check_model(self, model: Optional[str]) -> str: 28 | if not model: 29 | raise ValueError("model must be provided when calling the 'create' method") 30 | 31 | if model in ["jamba-1.5-mini", "jamba-1.5-large"]: 32 | warnings.warn( 33 | _MODEL_DEPRECATION_WARNING, 34 | DeprecationWarning, 35 | stacklevel=3, 36 | ) 37 | 38 | return model 39 | 40 | def _create_body( 41 | self, 42 | model: str, 43 | messages: List[ChatMessageParam], 44 | max_tokens: Optional[int] | NotGiven, 45 | temperature: Optional[float] | NotGiven, 46 | top_p: Optional[float] | NotGiven, 47 | stop: Optional[Union[str, List[str]]] | NotGiven, 48 | n: Optional[int] | NotGiven, 49 | stream: Literal[False] | Literal[True] | NotGiven, 50 | tools: List[ToolDefinition] | NotGiven, 51 | response_format: ResponseFormat | NotGiven, 52 | documents: List[DocumentSchema] | NotGiven, 53 | **kwargs: Any, 54 | ) -> Dict[str, Any]: 55 | return remove_not_given( 56 | { 57 | "model": model, 58 | "messages": [_to_dict(message) for message in messages], 59 | "temperature": temperature, 60 | "max_tokens": max_tokens, 61 | "top_p": top_p, 62 | "stop": stop, 63 | "n": n, 64 | "stream": stream, 65 | "tools": tools, 66 | "response_format": response_format, 67 | "documents": documents, 68 | **kwargs, 69 | } 70 | ) 71 | -------------------------------------------------------------------------------- /ai21/clients/studio/resources/constants.py: -------------------------------------------------------------------------------- 1 | CHAT_DEFAULT_NUM_RESULTS = 1 2 | CHAT_DEFAULT_TEMPERATURE = 0.7 3 | CHAT_DEFAULT_MAX_TOKENS = 300 4 | CHAT_DEFAULT_MIN_TOKENS = 0 5 | CHAT_DEFAULT_TOP_P = 1.0 6 | CHAT_DEFAULT_TOP_K_RETURN = 0 7 | -------------------------------------------------------------------------------- /ai21/clients/studio/resources/maestro/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI21Labs/ai21-python/624fa7cfeedcac0b41e4d7dddc48b77ac12e3bc9/ai21/clients/studio/resources/maestro/__init__.py
-------------------------------------------------------------------------------- /ai21/clients/studio/resources/maestro/maestro.py: -------------------------------------------------------------------------------- 1 | from ai21.clients.common.maestro.maestro import BaseMaestro 2 | from ai21.clients.studio.resources.maestro.run import MaestroRun, AsyncMaestroRun 3 | from ai21.clients.studio.resources.studio_resource import StudioResource, AsyncStudioResource 4 | from ai21.http_client.async_http_client import AsyncAI21HTTPClient 5 | from ai21.http_client.http_client import AI21HTTPClient 6 | 7 | 8 | class Maestro(StudioResource, BaseMaestro): 9 | def __init__(self, client: AI21HTTPClient): 10 | super().__init__(client) 11 | 12 | self.runs = MaestroRun(client) 13 | 14 | 15 | class AsyncMaestro(AsyncStudioResource, BaseMaestro): 16 | def __init__(self, client: AsyncAI21HTTPClient): 17 | super().__init__(client) 18 | 19 | self.runs = AsyncMaestroRun(client) 20 | -------------------------------------------------------------------------------- /ai21/clients/studio/resources/studio_conversational_rag.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import List 3 | 4 | from ai21.clients.common.conversational_rag import ConversationalRag 5 | from ai21.clients.studio.resources.studio_resource import ( 6 | AsyncStudioResource, 7 | StudioResource, 8 | ) 9 | from ai21.models.chat import ChatMessage 10 | from ai21.models.responses.conversational_rag_response import ConversationalRagResponse 11 | from ai21.models.retrieval_strategy import RetrievalStrategy 12 | from ai21.types import NotGiven, NOT_GIVEN 13 | 14 | 15 | class StudioConversationalRag(StudioResource, ConversationalRag): 16 | def create( 17 | self, 18 | messages: List[ChatMessage], 19 | *, 20 | path: str | NotGiven = NOT_GIVEN, 21 | labels: List[str] | NotGiven = NOT_GIVEN, 22 | file_ids: List[str] | NotGiven = NOT_GIVEN, 23 | max_segments: int | NotGiven = NOT_GIVEN, 24 | retrieval_strategy: RetrievalStrategy | str | NotGiven = NOT_GIVEN, 25 | retrieval_similarity_threshold: float | NotGiven = NOT_GIVEN, 26 | max_neighbors: int | NotGiven = NOT_GIVEN, 27 | hybrid_search_alpha: float | NotGiven = NOT_GIVEN, 28 | **kwargs, 29 | ) -> ConversationalRagResponse: 30 | body = self._create_body( 31 | messages=messages, 32 | path=path, 33 | labels=labels, 34 | file_ids=file_ids, 35 | max_segments=max_segments, 36 | retrieval_strategy=retrieval_strategy, 37 | retrieval_similarity_threshold=retrieval_similarity_threshold, 38 | max_neighbors=max_neighbors, 39 | hybrid_search_alpha=hybrid_search_alpha, 40 | **kwargs, 41 | ) 42 | 43 | return self._post(path=f"/{self._module_name}", body=body, response_cls=ConversationalRagResponse) 44 | 45 | 46 | class AsyncStudioConversationalRag(AsyncStudioResource, ConversationalRag): 47 | async def create( 48 | self, 49 | messages: List[ChatMessage], 50 | *, 51 | path: str | NotGiven = NOT_GIVEN, 52 | labels: List[str] | NotGiven = NOT_GIVEN, 53 | file_ids: List[str] | NotGiven = NOT_GIVEN, 54 | max_segments: int | NotGiven = NOT_GIVEN, 55 | retrieval_strategy: RetrievalStrategy | str | NotGiven = NOT_GIVEN, 56 | retrieval_similarity_threshold: float | NotGiven = NOT_GIVEN, 57 | max_neighbors: int | NotGiven = NOT_GIVEN, 58 | hybrid_search_alpha: float | NotGiven = NOT_GIVEN, 59 | **kwargs, 60 | ) -> ConversationalRagResponse: 61 | body = self._create_body( 62 | messages=messages, 63 | path=path, 64 | labels=labels, 65 | 
file_ids=file_ids, 66 | max_segments=max_segments, 67 | retrieval_strategy=retrieval_strategy, 68 | retrieval_similarity_threshold=retrieval_similarity_threshold, 69 | max_neighbors=max_neighbors, 70 | hybrid_search_alpha=hybrid_search_alpha, 71 | **kwargs, 72 | ) 73 | 74 | return await self._post(path=f"/{self._module_name}", body=body, response_cls=ConversationalRagResponse) 75 | -------------------------------------------------------------------------------- /ai21/clients/vertex/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI21Labs/ai21-python/624fa7cfeedcac0b41e4d7dddc48b77ac12e3bc9/ai21/clients/vertex/__init__.py -------------------------------------------------------------------------------- /ai21/constants.py: -------------------------------------------------------------------------------- 1 | DEFAULT_API_VERSION = "v1" 2 | STUDIO_HOST = "https://api.ai21.com/studio/v1" 3 | -------------------------------------------------------------------------------- /ai21/errors.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | 4 | class AI21APIError(Exception): 5 | def __init__(self, status_code: int, details: Optional[str] = None): 6 | super().__init__(details) 7 | self.details = details 8 | self.status_code = status_code 9 | 10 | def __str__(self) -> str: 11 | return f"Failed with http status code: {self.status_code} ({type(self).__name__}). Details: {self.details}" 12 | 13 | 14 | class BadRequest(AI21APIError): 15 | def __init__(self, details: Optional[str] = None): 16 | super().__init__(400, details) 17 | 18 | 19 | class Unauthorized(AI21APIError): 20 | def __init__(self, details: Optional[str] = None): 21 | super().__init__(401, details) 22 | 23 | 24 | class AccessDenied(AI21APIError): 25 | def __init__(self, details: Optional[str] = None): 26 | super().__init__(403, details) 27 | 28 | 29 | class NotFound(AI21APIError): 30 | def __init__(self, details: Optional[str] = None): 31 | super().__init__(404, details) 32 | 33 | 34 | class APITimeoutError(AI21APIError): 35 | def __init__(self, details: Optional[str] = None): 36 | super().__init__(408, details) 37 | 38 | 39 | class UnprocessableEntity(AI21APIError): 40 | def __init__(self, details: Optional[str] = None): 41 | super().__init__(422, details) 42 | 43 | 44 | class ModelErrorException(AI21APIError): 45 | def __init__(self, details: Optional[str] = None): 46 | super().__init__(424, details) 47 | 48 | 49 | class TooManyRequestsError(AI21APIError): 50 | def __init__(self, details: Optional[str] = None): 51 | super().__init__(429, details) 52 | 53 | 54 | class AI21ServerError(AI21APIError): 55 | def __init__(self, details: Optional[str] = None): 56 | super().__init__(500, details) 57 | 58 | 59 | class ServiceUnavailable(AI21APIError): 60 | def __init__(self, details: Optional[str] = None): 61 | super().__init__(503, details) 62 | 63 | 64 | class AI21Error(Exception): 65 | def __init__(self, message: str): 66 | self.message = message 67 | super().__init__(message) 68 | 69 | def __str__(self) -> str: 70 | return f"{type(self).__name__} {self.message}" 71 | 72 | 73 | class MissingApiKeyError(AI21Error): 74 | def __init__(self): 75 | message = "API key must be supplied either globally in the ai21 namespace, or to be provided in the call args" 76 | super().__init__(message) 77 | self.message = message 78 | 79 | 80 | class ModelPackageDoesntExistError(AI21Error): 81 | def __init__(self, 
model_name: str, region: str, version: Optional[str] = None): 82 | message = f"model_name: {model_name} doesn't exist in region: {region}" 83 | 84 | if version is not None: 85 | message += f" with version: {version}" 86 | 87 | super().__init__(message) 88 | self.message = message 89 | 90 | 91 | class EmptyMandatoryListError(AI21Error): 92 | def __init__(self, key: str): 93 | message = f"Supplied {key} is empty. At least one element should be present in the list" 94 | super().__init__(message) 95 | 96 | 97 | class CredentialsError(AI21Error): 98 | def __init__(self, provider_name: str, error_message: str): 99 | message = f"Could not get default {provider_name} credentials: {error_message}" 100 | super().__init__(message) 101 | 102 | 103 | class StreamingDecodeError(AI21Error): 104 | def __init__(self, chunk: str, error_message: Optional[str] = None): 105 | message = f"Failed to decode chunk: {chunk} in stream. Please check the stream format." 106 | if error_message: 107 | message = f"{message} Error: {error_message}" 108 | super().__init__(message) 109 | 110 | 111 | class InternalDependencyException(AI21APIError): 112 | def __init__(self, details: Optional[str] = None): 113 | super().__init__(530, details) 114 | -------------------------------------------------------------------------------- /ai21/files/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI21Labs/ai21-python/624fa7cfeedcac0b41e4d7dddc48b77ac12e3bc9/ai21/files/__init__.py -------------------------------------------------------------------------------- /ai21/files/downloaded_file.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import os 4 | 5 | from typing import Iterator 6 | 7 | import httpx 8 | 9 | 10 | class DownloadedFile: 11 | def __init__(self, httpx_response: httpx.Response): 12 | self._response = httpx_response 13 | 14 | @property 15 | def content(self) -> bytes: 16 | return self._response.content 17 | 18 | def read(self) -> bytes: 19 | return self._response.read() 20 | 21 | def iter_bytes(self, chunk_size: int | None = None) -> Iterator[bytes]: 22 | return self._response.iter_bytes(chunk_size) 23 | 24 | def iter_text(self, chunk_size: int | None = None) -> Iterator[str]: 25 | return self._response.iter_text(chunk_size) 26 | 27 | def iter_lines(self) -> Iterator[str]: 28 | return self._response.iter_lines() 29 | 30 | def write_to_file(self, path: str | os.PathLike[str]): 31 | with open(path, "wb") as f: 32 | for data in self._response.iter_bytes(): 33 | f.write(data) 34 | -------------------------------------------------------------------------------- /ai21/http_client/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI21Labs/ai21-python/624fa7cfeedcac0b41e4d7dddc48b77ac12e3bc9/ai21/http_client/__init__.py -------------------------------------------------------------------------------- /ai21/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import re 4 | 5 | from ai21.ai21_env_config import AI21EnvConfig 6 | 7 | 8 | _verbose = False 9 | 10 | logger = logging.getLogger("ai21") 11 | httpx_logger = logging.getLogger("httpx") 12 | 13 | 14 | class CensorSecretsFormatter(logging.Formatter): 15 | def format(self, record: logging.LogRecord) -> str: 16 | # Get original log message 17 | message = 
super().format(record) 18 | 19 | if not get_verbose(): 20 | return self._censor_secrets(message) 21 | 22 | return message 23 | 24 | def _censor_secrets(self, message: str) -> str: 25 | # Regular expression to find the Authorization key and its value 26 | pattern = r"('Authorization':\s*'[^']*'|'api-key':\s*'[^']*'|'X-Amz-Security-Token':\s*'[^']*')" 27 | 28 | def replacement(match): 29 | return match.group(0).split(":")[0] + ": '**************'" 30 | 31 | # Substitute the Authorization value with ************** 32 | return re.sub(pattern, replacement, message) 33 | 34 | 35 | def set_verbose(value: bool) -> None: 36 | """ 37 | Use this function to log additional, more sensitive data, such as secrets and environment variables. 38 | Log level will be set to DEBUG if verbose is set to True. 39 | """ 40 | global _verbose 41 | _verbose = value 42 | 43 | set_debug(_verbose) 44 | 45 | AI21EnvConfig.log(with_secrets=value) 46 | 47 | 48 | def set_debug(value: bool) -> None: 49 | """ 50 | Additional way to set log level to DEBUG. 51 | """ 52 | if value: 53 | os.environ["AI21_LOG_LEVEL"] = "debug" 54 | else: 55 | os.environ["AI21_LOG_LEVEL"] = "info" 56 | 57 | setup_logger() 58 | 59 | 60 | def get_verbose() -> bool: 61 | global _verbose 62 | return _verbose 63 | 64 | 65 | def setup_logger() -> None: 66 | handler = logging.StreamHandler() 67 | 68 | handler.setFormatter( 69 | CensorSecretsFormatter(fmt="[%(asctime)s - %(name)s - %(levelname)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S") 70 | ) 71 | 72 | logger.addHandler(handler) 73 | httpx_logger.addHandler(handler) 74 | 75 | if AI21EnvConfig.log_level.lower() == "debug": 76 | logger.setLevel(logging.DEBUG) 77 | httpx_logger.setLevel(logging.DEBUG) 78 | elif AI21EnvConfig.log_level.lower() == "info": 79 | logger.setLevel(logging.INFO) 80 | -------------------------------------------------------------------------------- /ai21/models/__init__.py: -------------------------------------------------------------------------------- 1 | from ai21.models.chat.role_type import RoleType 2 | from ai21.models.chat_message import ChatMessage 3 | from ai21.models.document_type import DocumentType 4 | from ai21.models.penalty import Penalty 5 | from ai21.models.responses.chat_response import ChatOutput, ChatResponse, FinishReason 6 | from ai21.models.responses.conversational_rag_response import ( 7 | ConversationalRagResponse, 8 | ConversationalRagSource, 9 | ) 10 | from ai21.models.responses.file_response import FileResponse 11 | from ai21.models.maestro.run import ( 12 | Requirement, 13 | Budget, 14 | Tool, 15 | ToolResources, 16 | DataSources, 17 | FileSearchResult, 18 | WebSearchResult, 19 | OutputOptions, 20 | ) 21 | 22 | __all__ = [ 23 | "ChatMessage", 24 | "RoleType", 25 | "Penalty", 26 | "DocumentType", 27 | "ChatResponse", 28 | "ChatOutput", 29 | "FinishReason", 30 | "FileResponse", 31 | "ConversationalRagResponse", 32 | "ConversationalRagSource", 33 | "Requirement", 34 | "Budget", 35 | "Tool", 36 | "ToolResources", 37 | "DataSources", 38 | "FileSearchResult", 39 | "WebSearchResult", 40 | "OutputOptions", 41 | ] 42 | -------------------------------------------------------------------------------- /ai21/models/_pydantic_compatibility.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Dict, Any, Type 4 | 5 | from pydantic import VERSION, BaseModel 6 | 7 | IS_PYDANTIC_V2 = VERSION.startswith("2.") 8 | 9 | 10 | def _to_dict(model_object: BaseModel, 
**kwargs) -> Dict[str, Any]: 11 | if IS_PYDANTIC_V2: 12 | return model_object.model_dump(**kwargs) 13 | 14 | return model_object.dict(**kwargs) 15 | 16 | 17 | def _to_json(model_object: BaseModel, **kwargs) -> str: 18 | if IS_PYDANTIC_V2: 19 | return model_object.model_dump_json(**kwargs) 20 | 21 | return model_object.json(**kwargs) 22 | 23 | 24 | def _from_dict(obj: "AI21BaseModel", obj_dict: Any, **kwargs) -> BaseModel: # noqa: F821 25 | if IS_PYDANTIC_V2: 26 | return obj.model_validate(obj_dict, **kwargs) 27 | 28 | return obj.parse_obj(obj_dict, **kwargs) 29 | 30 | 31 | def _from_json(obj: "AI21BaseModel", json_str: str, **kwargs) -> BaseModel: # noqa: F821 32 | if IS_PYDANTIC_V2: 33 | return obj.model_validate_json(json_str, **kwargs) 34 | 35 | return obj.parse_raw(json_str, **kwargs) 36 | 37 | 38 | def _to_schema(model_object: Type[BaseModel], **kwargs) -> Dict[str, Any]: 39 | if IS_PYDANTIC_V2: 40 | return model_object.model_json_schema(**kwargs) 41 | 42 | return model_object.schema(**kwargs) 43 | -------------------------------------------------------------------------------- /ai21/models/ai21_base_model.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from typing import Any, Dict 3 | 4 | from pydantic import BaseModel, ConfigDict 5 | from typing_extensions import Self 6 | 7 | from ai21.models._pydantic_compatibility import _to_dict, _to_json, _from_dict, _from_json, IS_PYDANTIC_V2 8 | 9 | 10 | class AI21BaseModel(BaseModel): 11 | if IS_PYDANTIC_V2: 12 | model_config = ConfigDict( 13 | populate_by_name=True, 14 | protected_namespaces=(), 15 | extra="allow", 16 | ) 17 | else: 18 | 19 | class Config: 20 | from pydantic import Extra 21 | 22 | allow_population_by_field_name = True 23 | extra = Extra.allow 24 | 25 | def to_dict(self, **kwargs) -> Dict[str, Any]: 26 | warnings.warn( 27 | "The 'to_dict' method is deprecated and will be removed in a future version." 28 | " Please use Pydantic's built-in methods instead.", 29 | DeprecationWarning, 30 | stacklevel=2, 31 | ) 32 | kwargs["by_alias"] = kwargs.pop("by_alias", True) 33 | 34 | return _to_dict(self, **kwargs) 35 | 36 | def to_json(self, **kwargs) -> str: 37 | warnings.warn( 38 | "The 'to_json' method is deprecated and will be removed in a future version." 39 | " Please use Pydantic's built-in methods instead.", 40 | DeprecationWarning, 41 | stacklevel=2, 42 | ) 43 | kwargs["by_alias"] = kwargs.pop("by_alias", True) 44 | 45 | return _to_json(self, **kwargs) 46 | 47 | @classmethod 48 | def from_dict(cls, obj: Any, **kwargs) -> Self: 49 | warnings.warn( 50 | "The 'from_dict' method is deprecated and will be removed in a future version." 51 | " Please use Pydantic's built-in methods instead.", 52 | DeprecationWarning, 53 | stacklevel=2, 54 | ) 55 | 56 | return _from_dict(cls, obj, **kwargs) 57 | 58 | @classmethod 59 | def from_json(cls, json_str: str, **kwargs) -> Self: 60 | warnings.warn( 61 | "The 'from_json' method is deprecated and will be removed in a future version." 
62 | " Please use Pydantic's built-in methods instead.", 63 | DeprecationWarning, 64 | stacklevel=2, 65 | ) 66 | 67 | return _from_json(cls, json_str, **kwargs) 68 | -------------------------------------------------------------------------------- /ai21/models/chat/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from .chat_completion_chunk import ChatCompletionChunk, ChoicesChunk, ChoiceDelta 4 | from .chat_completion_response import ChatCompletionResponse 5 | from .chat_completion_response import ChatCompletionResponseChoice 6 | from .chat_message import ChatMessage, AssistantMessage, ToolMessage, UserMessage, SystemMessage, ChatMessageParam 7 | from .document_schema import DocumentSchema 8 | from .function_tool_definition import FunctionToolDefinition 9 | from .response_format import ResponseFormat 10 | from .role_type import RoleType as RoleType 11 | from .tool_call import ToolCall 12 | from .tool_defintions import ToolDefinition 13 | from .tool_function import ToolFunction 14 | from .tool_parameters import ToolParameters 15 | 16 | __all__ = [ 17 | "ChatCompletionResponse", 18 | "ChatCompletionResponseChoice", 19 | "ChatMessage", 20 | "RoleType", 21 | "ChatCompletionChunk", 22 | "ChoicesChunk", 23 | "ChoiceDelta", 24 | "AssistantMessage", 25 | "ToolMessage", 26 | "UserMessage", 27 | "SystemMessage", 28 | "ChatMessageParam", 29 | "DocumentSchema", 30 | "FunctionToolDefinition", 31 | "ResponseFormat", 32 | "ToolCall", 33 | "ToolDefinition", 34 | "ToolFunction", 35 | "ToolParameters", 36 | ] 37 | -------------------------------------------------------------------------------- /ai21/models/chat/chat_completion_chunk.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List 2 | 3 | from ai21.models.ai21_base_model import AI21BaseModel 4 | from ai21.models.logprobs import Logprobs 5 | from ai21.models.usage_info import UsageInfo 6 | 7 | 8 | class ChoiceDelta(AI21BaseModel): 9 | content: Optional[str] = None 10 | role: Optional[str] = None 11 | 12 | 13 | class ChoicesChunk(AI21BaseModel): 14 | index: int 15 | delta: ChoiceDelta 16 | logprobs: Optional[Logprobs] = None 17 | finish_reason: Optional[str] = None 18 | 19 | 20 | class ChatCompletionChunk(AI21BaseModel): 21 | id: str 22 | choices: List[ChoicesChunk] 23 | usage: Optional[UsageInfo] = None 24 | -------------------------------------------------------------------------------- /ai21/models/chat/chat_completion_response.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List 2 | 3 | from ai21.models.ai21_base_model import AI21BaseModel 4 | from ai21.models.logprobs import Logprobs 5 | from ai21.models.usage_info import UsageInfo 6 | from .chat_message import AssistantMessage 7 | 8 | 9 | class ChatCompletionResponseChoice(AI21BaseModel): 10 | index: int 11 | message: AssistantMessage 12 | logprobs: Optional[Logprobs] = None 13 | finish_reason: Optional[str] = None 14 | 15 | 16 | class ChatCompletionResponse(AI21BaseModel): 17 | id: str 18 | choices: List[ChatCompletionResponseChoice] 19 | usage: UsageInfo 20 | -------------------------------------------------------------------------------- /ai21/models/chat/chat_message.py: -------------------------------------------------------------------------------- 1 | from typing_extensions import Literal, List, Optional, Union, TypeAlias 2 | 3 | from ai21.models.ai21_base_model 
import AI21BaseModel 4 | from ai21.models.chat.tool_call import ToolCall 5 | 6 | 7 | class ChatMessage(AI21BaseModel): 8 | role: str 9 | content: str 10 | 11 | 12 | class AssistantMessage(ChatMessage): 13 | role: Literal["assistant"] = "assistant" 14 | tool_calls: Optional[List[ToolCall]] = None 15 | content: Optional[str] = None 16 | 17 | 18 | class ToolMessage(ChatMessage): 19 | role: Literal["tool"] = "tool" 20 | tool_call_id: str 21 | 22 | 23 | class UserMessage(ChatMessage): 24 | role: Literal["user"] = "user" 25 | 26 | 27 | class SystemMessage(ChatMessage): 28 | role: Literal["system"] = "system" 29 | 30 | 31 | ChatMessageParam: TypeAlias = Union[UserMessage, AssistantMessage, ToolMessage, SystemMessage, ChatMessage] 32 | -------------------------------------------------------------------------------- /ai21/models/chat/document_schema.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | from typing_extensions import TypedDict, Required 3 | 4 | 5 | class DocumentSchema(TypedDict, total=False): 6 | content: Required[str] 7 | id: str 8 | metadata: Dict[str, str] 9 | -------------------------------------------------------------------------------- /ai21/models/chat/function_tool_definition.py: -------------------------------------------------------------------------------- 1 | from typing_extensions import TypedDict, Required 2 | 3 | from ai21.models.chat.tool_parameters import ToolParameters 4 | 5 | 6 | class FunctionToolDefinition(TypedDict, total=False): 7 | name: Required[str] 8 | description: str 9 | parameters: ToolParameters 10 | -------------------------------------------------------------------------------- /ai21/models/chat/response_format.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | from typing_extensions import TypedDict, Required 3 | 4 | 5 | class ResponseFormat(TypedDict, total=False): 6 | type: Required[Literal["text", "json_object"]] 7 | -------------------------------------------------------------------------------- /ai21/models/chat/role_type.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class RoleType(str, Enum): 5 | USER = "user" 6 | ASSISTANT = "assistant" 7 | TOOL = "tool" 8 | SYSTEM = "system" 9 | -------------------------------------------------------------------------------- /ai21/models/chat/tool_call.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | 3 | from ai21.models.ai21_base_model import AI21BaseModel 4 | from ai21.models.chat.tool_function import ToolFunction 5 | 6 | 7 | class ToolCall(AI21BaseModel): 8 | id: str 9 | function: ToolFunction 10 | type: Literal["function"] = "function" 11 | -------------------------------------------------------------------------------- /ai21/models/chat/tool_defintions.py: -------------------------------------------------------------------------------- 1 | from typing_extensions import Literal, TypedDict, Required 2 | 3 | from ai21.models.chat import FunctionToolDefinition 4 | 5 | 6 | class ToolDefinition(TypedDict, total=False): 7 | type: Required[Literal["function"]] 8 | function: Required[FunctionToolDefinition] 9 | -------------------------------------------------------------------------------- /ai21/models/chat/tool_function.py: -------------------------------------------------------------------------------- 1 | from 
ai21.models.ai21_base_model import AI21BaseModel 2 | 3 | 4 | class ToolFunction(AI21BaseModel): 5 | name: str 6 | arguments: str 7 | -------------------------------------------------------------------------------- /ai21/models/chat/tool_parameters.py: -------------------------------------------------------------------------------- 1 | from typing_extensions import Literal, Any, Dict, List, TypedDict, Required 2 | 3 | 4 | class ToolParameters(TypedDict, total=False): 5 | type: Literal["object"] 6 | properties: Required[Dict[str, Any]] 7 | required: List[str] 8 | -------------------------------------------------------------------------------- /ai21/models/chat_message.py: -------------------------------------------------------------------------------- 1 | from ai21.models.ai21_base_model import AI21BaseModel 2 | from ai21.models.chat.role_type import RoleType 3 | 4 | 5 | class ChatMessage(AI21BaseModel): 6 | role: RoleType 7 | text: str 8 | -------------------------------------------------------------------------------- /ai21/models/document_type.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class DocumentType(str, Enum): 5 | URL = "URL" 6 | TEXT = "TEXT" 7 | -------------------------------------------------------------------------------- /ai21/models/logprobs.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from ai21.models.ai21_base_model import AI21BaseModel 4 | 5 | 6 | class TopTokenData(AI21BaseModel): 7 | token: str 8 | logprob: float 9 | 10 | 11 | class LogprobsData(AI21BaseModel): 12 | token: str 13 | logprob: float 14 | top_logprobs: List[TopTokenData] 15 | 16 | 17 | class Logprobs(AI21BaseModel): 18 | content: LogprobsData 19 | -------------------------------------------------------------------------------- /ai21/models/maestro/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI21Labs/ai21-python/624fa7cfeedcac0b41e4d7dddc48b77ac12e3bc9/ai21/models/maestro/__init__.py -------------------------------------------------------------------------------- /ai21/models/maestro/run.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, List, Optional, Any, Set, Dict, Type, Union 2 | from typing_extensions import TypedDict 3 | from pydantic import BaseModel 4 | 5 | from ai21.models.ai21_base_model import AI21BaseModel 6 | 7 | Budget = Literal["low", "medium", "high"] 8 | Role = Literal["user", "assistant"] 9 | RunStatus = Literal["completed", "failed", "in_progress", "requires_action"] 10 | ToolType = Literal["file_search", "web_search"] 11 | OutputOptions = Literal["data_sources", "requirements_result"] 12 | PrimitiveTypes = Union[Type[str], Type[int], Type[float], Type[bool]] 13 | PrimitiveLists = Type[List[PrimitiveTypes]] 14 | OutputType = Union[Type[BaseModel], PrimitiveTypes, Dict[str, Any]] 15 | 16 | DEFAULT_RUN_POLL_INTERVAL: float = 1 # seconds 17 | DEFAULT_RUN_POLL_TIMEOUT: float = 120 # seconds 18 | TERMINATED_RUN_STATUSES: Set[RunStatus] = {"completed", "failed", "requires_action"} 19 | 20 | 21 | class Tool(TypedDict): 22 | type: ToolType 23 | 24 | 25 | class FileSearchToolResource(TypedDict, total=False): 26 | retrieval_similarity_threshold: Optional[float] 27 | labels: Optional[List[str]] 28 | labels_filter_mode: Optional[Literal["AND", "OR"]] 29 | labels_filter: Optional[dict] 30 | file_ids: 
Optional[List[str]] 31 | retrieval_strategy: Optional[str] 32 | max_neighbors: Optional[int] 33 | 34 | 35 | class WebSearchToolResource(TypedDict, total=False): 36 | urls: Optional[List[str]] 37 | 38 | 39 | class ToolResources(TypedDict, total=False): 40 | file_search: Optional[FileSearchToolResource] 41 | web_search: Optional[WebSearchToolResource] 42 | 43 | 44 | class Requirement(TypedDict, total=False): 45 | name: str 46 | description: str 47 | is_mandatory: bool 48 | 49 | 50 | class RequirementResultItem(Requirement, total=False): 51 | score: float 52 | reason: Optional[str] 53 | 54 | 55 | class RequirementsResult(TypedDict, total=False): 56 | score: float 57 | finish_reason: str 58 | requirements: List[RequirementResultItem] 59 | 60 | 61 | class FileSearchResult(TypedDict, total=False): 62 | text: Optional[str] 63 | file_id: str 64 | file_name: str 65 | score: float 66 | order: int 67 | 68 | 69 | class WebSearchResult(TypedDict, total=False): 70 | text: str 71 | url: str 72 | score: float 73 | 74 | 75 | class DataSources(TypedDict, total=False): 76 | file_search: Optional[List[FileSearchResult]] 77 | web_search: Optional[List[WebSearchResult]] 78 | 79 | 80 | class RunResponse(AI21BaseModel): 81 | id: str 82 | status: RunStatus 83 | result: Any 84 | data_sources: Optional[DataSources] = None 85 | requirements_result: Optional[RequirementsResult] = None 86 | -------------------------------------------------------------------------------- /ai21/models/penalty.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Optional 4 | 5 | from ai21.models.ai21_base_model import AI21BaseModel 6 | from pydantic import Field 7 | 8 | 9 | class Penalty(AI21BaseModel): 10 | scale: float 11 | apply_to_whitespaces: Optional[bool] = Field(default=None, alias="applyToWhitespaces") 12 | apply_to_punctuation: Optional[bool] = Field(default=None, alias="applyToPunctuation") 13 | apply_to_numbers: Optional[bool] = Field(default=None, alias="applyToNumbers") 14 | apply_to_stopwords: Optional[bool] = Field(default=None, alias="applyToStopwords") 15 | apply_to_emojis: Optional[bool] = Field(default=None, alias="applyToEmojis") 16 | 17 | def to_dict(self, **kwargs): 18 | kwargs["by_alias"] = kwargs.pop("by_alias", True) 19 | kwargs["exclude_none"] = kwargs.pop("exclude_none", True) 20 | 21 | return super().to_dict(**kwargs) 22 | 23 | def to_json(self, **kwargs) -> str: 24 | kwargs["by_alias"] = kwargs.pop("by_alias", True) 25 | kwargs["exclude_none"] = kwargs.pop("exclude_none", True) 26 | 27 | return super().to_json(**kwargs) 28 | -------------------------------------------------------------------------------- /ai21/models/request_options.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Any, Dict, Optional, BinaryIO 3 | 4 | from typing_extensions import Self 5 | 6 | 7 | @dataclass(frozen=True) 8 | class RequestOptions: 9 | url: str 10 | body: Dict[str, Any] 11 | method: str 12 | headers: Dict[str, Any] 13 | timeout: float 14 | path: Optional[str] = None 15 | params: Optional[Dict[str, Any]] = None 16 | stream: bool = False 17 | files: Optional[Dict[str, BinaryIO]] = None 18 | 19 | def replace( 20 | self, 21 | url: Optional[str] = None, 22 | body: Optional[Dict[str, Any]] = None, 23 | method: Optional[str] = None, 24 | headers: Optional[Dict[str, Any]] = None, 25 | timeout: Optional[float] = None, 26 | params: 
Optional[Dict[str, Any]] = None, 27 | stream: Optional[bool] = None, 28 | path: Optional[str] = None, 29 | files: Optional[Dict[str, BinaryIO]] = None, 30 | ) -> Self: 31 | return RequestOptions( 32 | url=url or self.url, 33 | body=body or self.body, 34 | method=method or self.method, 35 | headers=headers or self.headers, 36 | timeout=timeout or self.timeout, 37 | params=params or self.params, 38 | stream=stream if stream is not None else self.stream, 39 | path=path or self.path, 40 | files=files or self.files, 41 | ) 42 | -------------------------------------------------------------------------------- /ai21/models/responses/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI21Labs/ai21-python/624fa7cfeedcac0b41e4d7dddc48b77ac12e3bc9/ai21/models/responses/__init__.py -------------------------------------------------------------------------------- /ai21/models/responses/batch_response.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, Optional 2 | 3 | from ai21.models.ai21_base_model import AI21BaseModel 4 | 5 | 6 | class RequestCounts(AI21BaseModel): 7 | total: int 8 | completed: int 9 | failed: int 10 | 11 | 12 | class Batch(AI21BaseModel): 13 | id: str 14 | """The ID of the batch.""" 15 | status: Literal["FAILED", "PROCESSING", "COMPLETED", "EXPIRED", "CANCELLED", "PENDING", "STALE"] 16 | 17 | """The current status of the batch.""" 18 | endpoint: str 19 | """The AI21 API endpoint used by the batch.""" 20 | 21 | created_at: int 22 | """The Unix timestamp (in seconds) for when the batch was created.""" 23 | 24 | completed_at: Optional[int] = None 25 | """The Unix timestamp (in seconds) for when the batch was completed.""" 26 | 27 | cancelled_at: Optional[int] = None 28 | """The Unix timestamp (in seconds) for when the batch was cancelled.""" 29 | 30 | request_counts: RequestCounts 31 | """The request counts for different statuses within the batch.""" 32 | 33 | metadata: Optional[dict] = None 34 | 35 | output_file: Optional[str] = None 36 | """The path to the file containing the outputs of successfully executed requests.""" 37 | 38 | error_file: Optional[str] = None 39 | """The path to the file containing the outputs of requests with errors.""" 40 | -------------------------------------------------------------------------------- /ai21/models/responses/chat_response.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List 2 | 3 | from pydantic import Field 4 | 5 | from ai21.models.ai21_base_model import AI21BaseModel 6 | from ai21.models.chat.role_type import RoleType 7 | 8 | 9 | class FinishReason(AI21BaseModel): 10 | reason: str 11 | length: Optional[int] = None 12 | sequence: Optional[str] = None 13 | 14 | 15 | class ChatOutput(AI21BaseModel): 16 | text: str 17 | role: RoleType 18 | finish_reason: Optional[FinishReason] = Field(default=None, alias="finishReason") 19 | 20 | 21 | class ChatResponse(AI21BaseModel): 22 | outputs: List[ChatOutput] 23 | -------------------------------------------------------------------------------- /ai21/models/responses/conversational_rag_response.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Optional, List 3 | 4 | from ai21.models.chat import ChatMessage 5 | from ai21.models.ai21_base_model import AI21BaseModel 6 | 7 | 8 | class 
ConversationalRagSource(AI21BaseModel): 9 | text: str 10 | file_id: str 11 | file_name: str 12 | score: float 13 | order: Optional[int] = None 14 | public_url: Optional[str] = None 15 | labels: Optional[List[str]] = None 16 | 17 | 18 | class ConversationalRagResponse(AI21BaseModel): 19 | id: str 20 | choices: List[ChatMessage] 21 | search_queries: Optional[List[str]] 22 | context_retrieved: bool 23 | answer_in_context: bool 24 | sources: List[ConversationalRagSource] 25 | -------------------------------------------------------------------------------- /ai21/models/responses/file_response.py: -------------------------------------------------------------------------------- 1 | from datetime import date 2 | from typing import Optional, List 3 | 4 | from pydantic import Field 5 | 6 | from ai21.models.ai21_base_model import AI21BaseModel 7 | 8 | 9 | class FileResponse(AI21BaseModel): 10 | file_id: str = Field(alias="fileId") 11 | name: str 12 | file_type: str = Field(alias="fileType") 13 | size_bytes: int = Field(alias="sizeBytes") 14 | created_by: str = Field(alias="createdBy") 15 | creation_date: date = Field(alias="creationDate") 16 | last_updated: date = Field(alias="lastUpdated") 17 | status: str 18 | path: Optional[str] = None 19 | labels: Optional[List[str]] = None 20 | public_url: Optional[str] = Field(default=None, alias="publicUrl") 21 | error_code: Optional[int] = Field(default=None, alias="errorCode") 22 | error_message: Optional[str] = Field(default=None, alias="errorMessage") 23 | -------------------------------------------------------------------------------- /ai21/models/retrieval_strategy.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | 3 | 4 | RetrievalStrategy = Literal["default", "segments", "add_neighbors", "full_doc"] 5 | -------------------------------------------------------------------------------- /ai21/models/usage_info.py: -------------------------------------------------------------------------------- 1 | from ai21.models.ai21_base_model import AI21BaseModel 2 | 3 | 4 | class UsageInfo(AI21BaseModel): 5 | prompt_tokens: int 6 | completion_tokens: int 7 | total_tokens: int 8 | -------------------------------------------------------------------------------- /ai21/pagination/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI21Labs/ai21-python/624fa7cfeedcac0b41e4d7dddc48b77ac12e3bc9/ai21/pagination/__init__.py -------------------------------------------------------------------------------- /ai21/pagination/async_pagination.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Any, AsyncIterator, Callable, Dict, List 4 | 5 | from .base_pagination import BasePagination, PageT, cast_page_response 6 | 7 | 8 | class AsyncPagination(BasePagination[PageT]): 9 | """A wrapper class that provides asynchronous pagination functionality for API resources.""" 10 | 11 | def __init__( 12 | self, 13 | request_callback: Callable, 14 | path: str, 15 | response_cls: PageT, 16 | params: Dict[str, Any] | None = None, 17 | ): 18 | super().__init__(request_callback, path, response_cls, params) 19 | self._iterator = self.__paginate__() 20 | 21 | async def __aiter__(self) -> AsyncIterator[List[PageT]]: 22 | async for page in self._iterator: 23 | yield page 24 | 25 | async def __anext__(self) -> List[PageT]: 26 | return await self._iterator.__anext__() 
27 | 28 | async def __paginate__(self) -> AsyncIterator[List[PageT]]: 29 | while True: 30 | results = await self.request_callback(path=self.path, params=self.params, method="GET") 31 | response = cast_page_response(raw_response=results.json(), response_cls=self.response_cls) 32 | 33 | if not response: 34 | break 35 | 36 | self._set_next_page_params(response) 37 | yield response 38 | -------------------------------------------------------------------------------- /ai21/pagination/base_pagination.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Any, Callable, Dict, Generic, List, Protocol, Sequence, TypeVar 4 | 5 | from ai21.models._pydantic_compatibility import _from_dict 6 | 7 | 8 | class HasId(Protocol): 9 | id: str 10 | 11 | 12 | PageT = TypeVar("PageT", bound=HasId) 13 | 14 | 15 | def cast_page_response(raw_response: Dict[str, Any], response_cls: PageT) -> List[PageT]: 16 | """Cast API response to appropriate types.""" 17 | return [_from_dict(obj=response_cls, obj_dict=item) for item in raw_response] 18 | 19 | 20 | class BasePagination(Generic[PageT]): 21 | def __init__( 22 | self, 23 | request_callback: Callable, 24 | path: str, 25 | response_cls: PageT, 26 | params: Dict[str, Any] | None = None, 27 | **kwargs: Any, 28 | ): 29 | self.request_callback = request_callback 30 | self.path = path 31 | self.params = params or {} 32 | self.response_cls = response_cls 33 | self.kwargs = kwargs 34 | 35 | def _set_next_page_params(self, response: Sequence[PageT]) -> None: 36 | """Update params with the ID from the last item in the response.""" 37 | if response: 38 | self.params["after"] = response[-1].id 39 | -------------------------------------------------------------------------------- /ai21/pagination/sync_pagination.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Any, Callable, Dict, Iterator, List 4 | 5 | from .base_pagination import BasePagination, PageT, cast_page_response 6 | 7 | 8 | class SyncPagination(BasePagination[PageT]): 9 | """A wrapper class that provides synchronous pagination functionality for API resources.""" 10 | 11 | def __init__( 12 | self, 13 | request_callback: Callable, 14 | path: str, 15 | response_cls: PageT, 16 | params: Dict[str, Any] | None = None, 17 | **kwargs: Any, 18 | ): 19 | super().__init__(request_callback, path, response_cls, params, **kwargs) 20 | self._iterator = self.__paginate__() 21 | 22 | def __iter__(self) -> Iterator[List[PageT]]: 23 | for page in self._iterator: 24 | yield page 25 | 26 | def __next__(self) -> List[Any]: 27 | return self._iterator.__next__() 28 | 29 | def __paginate__(self) -> Iterator[List[PageT]]: 30 | while True: 31 | results = self.request_callback(path=self.path, params=self.params, method="GET", **self.kwargs) 32 | response = cast_page_response(raw_response=results.json(), response_cls=self.response_cls) 33 | 34 | if not response: 35 | break 36 | 37 | self._set_next_page_params(response) 38 | yield response 39 | -------------------------------------------------------------------------------- /ai21/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI21Labs/ai21-python/624fa7cfeedcac0b41e4d7dddc48b77ac12e3bc9/ai21/py.typed -------------------------------------------------------------------------------- /ai21/stream/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI21Labs/ai21-python/624fa7cfeedcac0b41e4d7dddc48b77ac12e3bc9/ai21/stream/__init__.py -------------------------------------------------------------------------------- /ai21/stream/async_stream.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Generic, AsyncIterator, Optional, Any 4 | from typing_extensions import Self 5 | from types import TracebackType 6 | 7 | from ai21.stream.stream_commons import _T, _SSEDecoder, _SSE_DONE_MSG, get_stream_message 8 | 9 | import httpx 10 | 11 | 12 | class AsyncStream(Generic[_T]): 13 | response: httpx.Response 14 | 15 | def __init__( 16 | self, 17 | *, 18 | cast_to: type[_T], 19 | response: httpx.Response, 20 | streaming_decoder: Optional[Any] = None, 21 | ): 22 | self.response = response 23 | self.cast_to = cast_to 24 | self._decoder = streaming_decoder or _SSEDecoder() 25 | self._iterator = self.__stream__() 26 | 27 | async def __anext__(self) -> _T: 28 | return await self._iterator.__anext__() 29 | 30 | async def __aiter__(self) -> AsyncIterator[_T]: 31 | async for item in self._iterator: 32 | yield item 33 | 34 | async def __stream__(self) -> AsyncIterator[_T]: 35 | iterator = self._decoder.aiter(self.response) 36 | async for chunk in iterator: 37 | if chunk.endswith(_SSE_DONE_MSG): 38 | break 39 | 40 | yield get_stream_message(chunk, self.cast_to) 41 | 42 | # Ensure the entire stream is consumed 43 | async for _chunk in iterator: 44 | ... 45 | 46 | async def __aenter__(self) -> Self: 47 | return self 48 | 49 | async def __aexit__( 50 | self, 51 | exc_type: type[BaseException] | None, 52 | exc: BaseException | None, 53 | exc_tb: TracebackType | None, 54 | ) -> None: 55 | await self.close() 56 | 57 | async def close(self): 58 | await self.response.aclose() 59 | -------------------------------------------------------------------------------- /ai21/stream/stream.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from types import TracebackType 4 | from typing import Generic, Iterator, Optional, Any 5 | 6 | import httpx 7 | from typing_extensions import Self 8 | 9 | from ai21.stream.stream_commons import _T, _SSEDecoder, _SSE_DONE_MSG, get_stream_message 10 | 11 | 12 | class Stream(Generic[_T]): 13 | response: httpx.Response 14 | 15 | def __init__( 16 | self, 17 | *, 18 | cast_to: type[_T], 19 | response: httpx.Response, 20 | streaming_decoder: Optional[Any] = None, 21 | ): 22 | self.response = response 23 | self.cast_to = cast_to 24 | self._decoder = streaming_decoder or _SSEDecoder() 25 | self._iterator = self.__stream__() 26 | 27 | def __next__(self) -> _T: 28 | return self._iterator.__next__() 29 | 30 | def __iter__(self) -> Iterator[_T]: 31 | for item in self._iterator: 32 | yield item 33 | 34 | def __stream__(self) -> Iterator[_T]: 35 | iterator = self._decoder.iter(self.response) 36 | for chunk in iterator: 37 | if chunk.endswith(_SSE_DONE_MSG): 38 | break 39 | 40 | yield get_stream_message(chunk, self.cast_to) 41 | 42 | # Ensure the entire stream is consumed 43 | for _chunk in iterator: 44 | ... 
45 | 46 | def __enter__(self) -> Self: 47 | return self 48 | 49 | def __exit__( 50 | self, 51 | exc_type: type[BaseException] | None, 52 | exc: BaseException | None, 53 | exc_tb: TracebackType | None, 54 | ) -> None: 55 | self.close() 56 | 57 | def close(self): 58 | self.response.close() 59 | -------------------------------------------------------------------------------- /ai21/stream/stream_commons.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import json 4 | from abc import ABC, abstractmethod 5 | from typing import TypeVar, Iterator, AsyncIterator, Optional 6 | 7 | import httpx 8 | 9 | from ai21.errors import StreamingDecodeError 10 | 11 | _T = TypeVar("_T") 12 | _SSE_DATA_PREFIX = "data: " 13 | _SSE_DONE_MSG = "[DONE]" 14 | 15 | 16 | def get_stream_message(chunk: str, cast_to: type[_T]) -> _T: 17 | try: 18 | parsed = json.loads(chunk) 19 | if hasattr(cast_to, "from_dict"): 20 | return cast_to.from_dict(parsed) 21 | else: 22 | return cast_to(**parsed) 23 | except json.JSONDecodeError: 24 | raise StreamingDecodeError(chunk) 25 | 26 | 27 | class _SSEDecoderBase(ABC): 28 | @abstractmethod 29 | def iter(self, response: httpx.Response) -> Iterator[str]: 30 | pass 31 | 32 | @abstractmethod 33 | async def aiter(self, response: httpx.Response) -> AsyncIterator[str]: 34 | pass 35 | 36 | 37 | class _SSEDecoder(_SSEDecoderBase): 38 | def iter(self, response: httpx.Response): 39 | for line in response.iter_lines(): 40 | line = line.strip() 41 | decoded_line = self._decode(line) 42 | 43 | if decoded_line is not None: 44 | yield decoded_line 45 | 46 | async def aiter(self, response: httpx.Response): 47 | async for line in response.aiter_lines(): 48 | line = line.strip() 49 | decoded_line = self._decode(line) 50 | 51 | if decoded_line is not None: 52 | yield decoded_line 53 | 54 | def _decode(self, line: str) -> Optional[str]: 55 | if not line: 56 | return None 57 | 58 | if line.startswith(_SSE_DATA_PREFIX): 59 | return line[len(_SSE_DATA_PREFIX):] 60 | 61 | raise StreamingDecodeError(f"Invalid SSE line: {line}") 62 | -------------------------------------------------------------------------------- /ai21/tokenizers/__init__.py: -------------------------------------------------------------------------------- 1 | from .ai21_tokenizer import AI21Tokenizer 2 | from .factory import get_tokenizer, get_async_tokenizer 3 | 4 | __all__ = ["AI21Tokenizer", "get_tokenizer", "get_async_tokenizer"] 5 | -------------------------------------------------------------------------------- /ai21/tokenizers/ai21_tokenizer.py: -------------------------------------------------------------------------------- 1 | from typing import List, Any 2 | 3 | from ai21_tokenizer import BaseTokenizer, AsyncBaseTokenizer 4 | 5 | 6 | class AI21Tokenizer: 7 | """ 8 | A class that wraps a tokenizer and provides additional functionality. 
9 | """ 10 | 11 | def __init__(self, tokenizer: BaseTokenizer): 12 | self._tokenizer = tokenizer 13 | 14 | def count_tokens(self, text: str) -> int: 15 | encoded_text = self._tokenizer.encode(text) 16 | 17 | return len(encoded_text) 18 | 19 | def tokenize(self, text: str, **kwargs: Any) -> List[str]: 20 | encoded_text = self._tokenizer.encode(text, **kwargs) 21 | 22 | return self._tokenizer.convert_ids_to_tokens(encoded_text, **kwargs) 23 | 24 | def detokenize(self, tokens: List[str], **kwargs: Any) -> str: 25 | token_ids = self._tokenizer.convert_tokens_to_ids(tokens) 26 | 27 | return self._tokenizer.decode(token_ids, **kwargs) 28 | 29 | 30 | class AsyncAI21Tokenizer: 31 | """ 32 | A class that wraps an async tokenizer and provides additional functionality. 33 | """ 34 | 35 | def __init__(self, tokenizer: AsyncBaseTokenizer): 36 | self._tokenizer = tokenizer 37 | 38 | async def count_tokens(self, text: str) -> int: 39 | encoded_text = await self._tokenizer.encode(text) 40 | 41 | return len(encoded_text) 42 | 43 | async def tokenize(self, text: str, **kwargs: Any) -> List[str]: 44 | encoded_text = await self._tokenizer.encode(text, **kwargs) 45 | 46 | return await self._tokenizer.convert_ids_to_tokens(encoded_text, **kwargs) 47 | 48 | async def detokenize(self, tokens: List[str], **kwargs: Any) -> str: 49 | token_ids = await self._tokenizer.convert_tokens_to_ids(tokens) 50 | 51 | return await self._tokenizer.decode(token_ids, **kwargs) 52 | -------------------------------------------------------------------------------- /ai21/tokenizers/factory.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | from ai21_tokenizer import Tokenizer, PreTrainedTokenizers 4 | 5 | from ai21.tokenizers.ai21_tokenizer import AI21Tokenizer, AsyncAI21Tokenizer 6 | 7 | _cached_tokenizers: Dict[str, AI21Tokenizer] = {} 8 | _cached_async_tokenizers: Dict[str, AsyncAI21Tokenizer] = {} 9 | 10 | 11 | def get_tokenizer(name: str = PreTrainedTokenizers.J2_TOKENIZER) -> AI21Tokenizer: 12 | """ 13 | Get the tokenizer instance. 14 | 15 | If the tokenizer instance is not cached, it will be created using the Tokenizer.get_tokenizer() method. 16 | """ 17 | global _cached_tokenizers 18 | 19 | if _cached_tokenizers.get(name) is None: 20 | _cached_tokenizers[name] = AI21Tokenizer(Tokenizer.get_tokenizer(name)) 21 | 22 | return _cached_tokenizers[name] 23 | 24 | 25 | async def get_async_tokenizer(name: str = PreTrainedTokenizers.J2_TOKENIZER) -> AsyncAI21Tokenizer: 26 | """ 27 | Get the async tokenizer instance. 28 | 29 | If the tokenizer instance is not cached, it will be created using the Tokenizer.get_tokenizer() method. 
30 | """ 31 | global _cached_async_tokenizers 32 | 33 | if _cached_async_tokenizers.get(name) is None: 34 | _cached_async_tokenizers[name] = AsyncAI21Tokenizer(await Tokenizer.get_async_tokenizer(name)) 35 | 36 | return _cached_async_tokenizers[name] 37 | -------------------------------------------------------------------------------- /ai21/types.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List, Union 2 | 3 | import httpx 4 | 5 | from typing_extensions import TYPE_CHECKING, Literal, TypeVar 6 | 7 | from ai21.pagination.async_pagination import AsyncPagination 8 | from ai21.pagination.sync_pagination import SyncPagination 9 | from ai21.stream.async_stream import AsyncStream 10 | from ai21.stream.stream import Stream 11 | 12 | 13 | if TYPE_CHECKING: 14 | from ai21.models.ai21_base_model import AI21BaseModel # noqa 15 | 16 | ResponseT = TypeVar("ResponseT", bound=Union["AI21BaseModel", str, httpx.Response, List[Any]]) 17 | StreamT = TypeVar("StreamT", bound=Stream[Any]) 18 | AsyncStreamT = TypeVar("AsyncStreamT", bound=AsyncStream[Any]) 19 | SyncPaginationT = TypeVar("SyncPaginationT", bound=SyncPagination[Any]) 20 | AsyncPaginationT = TypeVar("AsyncPaginationT", bound=AsyncPagination[Any]) 21 | 22 | 23 | # Sentinel class used until PEP 661 is accepted 24 | class NotGiven: 25 | """ 26 | A sentinel singleton class used to distinguish omitted keyword arguments 27 | from those passed in with the value None (which may have different behavior). 28 | 29 | For example: 30 | 31 | ```py 32 | def get(timeout: Union[int, NotGiven, None] = NotGiven()) -> Response: 33 | ... 34 | 35 | 36 | get(timeout=1) # 1s timeout 37 | get(timeout=None) # No timeout 38 | get() # Default timeout behavior, which may not be statically known at the method definition. 39 | ``` 40 | """ 41 | 42 | def __bool__(self) -> Literal[False]: 43 | return False 44 | 45 | def __repr__(self) -> str: 46 | return "NOT_GIVEN" 47 | 48 | 49 | NOT_GIVEN = NotGiven() 50 | -------------------------------------------------------------------------------- /ai21/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI21Labs/ai21-python/624fa7cfeedcac0b41e4d7dddc48b77ac12e3bc9/ai21/utils/__init__.py -------------------------------------------------------------------------------- /ai21/utils/typing.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, get_args, cast 2 | 3 | from ai21.types import NotGiven 4 | 5 | 6 | def is_not_given(value: Any) -> bool: 7 | return isinstance(value, NotGiven) 8 | 9 | 10 | def remove_not_given(body: Dict[str, Any]) -> Dict[str, Any]: 11 | return {k: v for k, v in body.items() if not is_not_given(v)} 12 | 13 | 14 | def to_camel_case(snake_str: str) -> str: 15 | return "".join(x.capitalize() for x in snake_str.lower().split("_")) 16 | 17 | 18 | def to_lower_camel_case(snake_str: str) -> str: 19 | # Convert to CamelCase first, then splice the lower-cased first character 20 | # back on so every component after the first keeps its capitalization.
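    # Illustrative examples (hypothetical inputs):
    #   to_lower_camel_case("public_url") -> "publicUrl"
    #   to_lower_camel_case("file_id") -> "fileId"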
21 | camel_string = to_camel_case(snake_str) 22 | return snake_str[0].lower() + camel_string[1:] 23 | 24 | 25 | def extract_type(type_to_extract: Any) -> type: 26 | args = get_args(type_to_extract) 27 | try: 28 | return cast(type, args[0]) 29 | except IndexError as err: 30 | raise RuntimeError( 31 | f"Expected type {type_to_extract} to have a type argument at index 0 but it did not" 32 | ) from err 33 | -------------------------------------------------------------------------------- /ai21/version.py: -------------------------------------------------------------------------------- 1 | VERSION = "3.3.0" 2 | -------------------------------------------------------------------------------- /ai21/version_utils.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from functools import wraps 3 | 4 | V3_DEPRECATION_MESSAGE = "This model is not supported anymore - Please upgrade to v3.0.0" 5 | 6 | 7 | def deprecated(message): 8 | def decorator(func): 9 | @wraps(func) 10 | def wrapper(*args, **kwargs): 11 | warnings.warn(message, category=DeprecationWarning, stacklevel=2) 12 | return func(*args, **kwargs) 13 | 14 | return wrapper 15 | 16 | return decorator 17 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI21Labs/ai21-python/624fa7cfeedcac0b41e4d7dddc48b77ac12e3bc9/examples/__init__.py -------------------------------------------------------------------------------- /examples/azure/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI21Labs/ai21-python/624fa7cfeedcac0b41e4d7dddc48b77ac12e3bc9/examples/azure/__init__.py -------------------------------------------------------------------------------- /examples/azure/async_azure_chat_completions.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | from ai21 import AsyncAI21AzureClient 4 | from ai21.models.chat import ChatMessage 5 | 6 | 7 | async def chat_completions(): 8 | client = AsyncAI21AzureClient( 9 | base_url="", 10 | api_key="", 11 | ) 12 | 13 | messages = ChatMessage(content="What is the meaning of life?", role="user") 14 | 15 | completion = await client.chat.completions.create( 16 | model="jamba-instruct", 17 | messages=[messages], 18 | ) 19 | 20 | print(completion.to_json()) 21 | 22 | 23 | asyncio.run(chat_completions()) 24 | -------------------------------------------------------------------------------- /examples/azure/azure_chat_completions.py: -------------------------------------------------------------------------------- 1 | from ai21 import AI21AzureClient 2 | 3 | from ai21.models.chat import ChatMessage 4 | 5 | client = AI21AzureClient( 6 | base_url="", 7 | api_key="", 8 | ) 9 | 10 | messages = ChatMessage(content="What is the meaning of life?", role="user") 11 | 12 | completion = client.chat.completions.create( 13 | model="jamba-instruct", 14 | messages=[messages], 15 | ) 16 | 17 | print(completion.to_json()) 18 | -------------------------------------------------------------------------------- /examples/bedrock/chat/async_chat_completions.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from ai21 import AsyncAI21BedrockClient, BedrockModelID 3 | from ai21.models.chat import ChatMessage 4 | 5 | client = 
AsyncAI21BedrockClient(region="us-east-1") # region is optional, as you can use the env variable instead 6 | 7 | messages = [ 8 | ChatMessage(content="You are a helpful assistant", role="system"), 9 | ChatMessage(content="What is the meaning of life?", role="user"), 10 | ] 11 | 12 | 13 | async def main(): 14 | response = await client.chat.completions.create( 15 | messages=messages, 16 | model=BedrockModelID.JAMBA_1_5_MINI, 17 | ) 18 | 19 | print(f"response: {response}") 20 | 21 | 22 | asyncio.run(main()) 23 | -------------------------------------------------------------------------------- /examples/bedrock/chat/async_stream_chat_completions.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from ai21 import AsyncAI21BedrockClient, BedrockModelID 3 | from ai21.models.chat import ChatMessage 4 | 5 | client = AsyncAI21BedrockClient(region="us-east-1") # region is optional, as you can use the env variable instead 6 | 7 | system = "You're a support engineer in a SaaS company" 8 | messages = [ 9 | ChatMessage(content=system, role="system"), 10 | ChatMessage(content="Hello, I need help with a signup process.", role="user"), 11 | ChatMessage(content="Hi Alice, I can help you with that. What seems to be the problem?", role="assistant"), 12 | ChatMessage(content="I am having trouble signing up for your product with my Google account.", role="user"), 13 | ] 14 | 15 | 16 | async def main(): 17 | response = await client.chat.completions.create( 18 | messages=messages, 19 | model=BedrockModelID.JAMBA_1_5_MINI, 20 | max_tokens=100, 21 | stream=True, 22 | ) 23 | 24 | async for chunk in response: 25 | print(chunk.choices[0].delta.content, end="") 26 | 27 | 28 | asyncio.run(main()) 29 | -------------------------------------------------------------------------------- /examples/bedrock/chat/chat_completions.py: -------------------------------------------------------------------------------- 1 | from ai21 import AI21BedrockClient, BedrockModelID 2 | from ai21.models.chat import ChatMessage 3 | 4 | # Bedrock is currently supported only in the us-east-1 region. 5 | # Either set your profile's region to us-east-1 or pass it explicitly: 6 | # client = AI21BedrockClient(region="us-east-1") 7 | # Or create a boto session and pass it: 8 | # import boto3 9 | # session = boto3.Session(region_name="us-east-1") 10 | 11 | system = "You're a support engineer in a SaaS company" 12 | messages = [ 13 | ChatMessage(content=system, role="system"), 14 | ChatMessage(content="Hello, I need help with a signup process.", role="user"), 15 | ChatMessage(content="Hi Alice, I can help you with that.
What seems to be the problem?", role="assistant"), 16 | ChatMessage(content="I am having trouble signing up for your product with my Google account.", role="user"), 17 | ] 18 | 19 | client = AI21BedrockClient() 20 | 21 | response = client.chat.completions.create( 22 | messages=messages, 23 | max_tokens=1000, 24 | temperature=0, 25 | model=BedrockModelID.JAMBA_1_5_MINI, 26 | ) 27 | 28 | print(f"response: {response}") 29 | -------------------------------------------------------------------------------- /examples/bedrock/chat/stream_chat_completions.py: -------------------------------------------------------------------------------- 1 | from ai21 import AI21BedrockClient, BedrockModelID 2 | from ai21.models.chat import ChatMessage 3 | 4 | system = "You're a support engineer in a SaaS company" 5 | messages = [ 6 | ChatMessage(content=system, role="system"), 7 | ChatMessage(content="Hello, I need help with a signup process.", role="user"), 8 | ChatMessage(content="Hi Alice, I can help you with that. What seems to be the problem?", role="assistant"), 9 | ChatMessage(content="I am having trouble signing up for your product with my Google account.", role="user"), 10 | ] 11 | 12 | client = AI21BedrockClient() 13 | 14 | response = client.chat.completions.create( 15 | messages=messages, 16 | model=BedrockModelID.JAMBA_1_5_MINI, 17 | max_tokens=100, 18 | stream=True, 19 | ) 20 | 21 | for chunk in response: 22 | print(chunk.choices[0].delta.content, end="") 23 | -------------------------------------------------------------------------------- /examples/launchpad/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI21Labs/ai21-python/624fa7cfeedcac0b41e4d7dddc48b77ac12e3bc9/examples/launchpad/__init__.py -------------------------------------------------------------------------------- /examples/launchpad/async_chat_completions.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | from ai21 import AsyncAI21LaunchpadClient 4 | from ai21.models.chat import ChatMessage 5 | 6 | 7 | client = AsyncAI21LaunchpadClient(endpoint_id="") 8 | 9 | 10 | async def main(): 11 | messages = ChatMessage(content="What is the meaning of life?", role="user") 12 | 13 | completion = await client.chat.completions.create( 14 | model="jamba-1.6-large", 15 | messages=[messages], 16 | ) 17 | 18 | print(completion) 19 | 20 | 21 | asyncio.run(main()) 22 | -------------------------------------------------------------------------------- /examples/launchpad/chat_completions.py: -------------------------------------------------------------------------------- 1 | from ai21 import AI21LaunchpadClient 2 | from ai21.models.chat import ChatMessage 3 | 4 | 5 | client = AI21LaunchpadClient(endpoint_id="") 6 | 7 | messages = ChatMessage(content="What is the meaning of life?", role="user") 8 | 9 | completion = client.chat.completions.create( 10 | model="jamba-1.6-large", 11 | messages=[messages], 12 | stream=True, 13 | ) 14 | 15 | 16 | print(completion) 17 | -------------------------------------------------------------------------------- /examples/studio/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI21Labs/ai21-python/624fa7cfeedcac0b41e4d7dddc48b77ac12e3bc9/examples/studio/__init__.py -------------------------------------------------------------------------------- /examples/studio/async_library.py: 
-------------------------------------------------------------------------------- 1 | import asyncio 2 | import os 3 | import uuid 4 | 5 | import file_utils 6 | from ai21 import AsyncAI21Client, AI21APIError 7 | 8 | client = AsyncAI21Client() 9 | 10 | 11 | async def validate_file_deleted(file_id: str): 12 | try: 13 | await client.library.files.get(file_id) 14 | except AI21APIError as e: 15 | print(f"File not found. Exception: {e.details}") 16 | 17 | 18 | async def main(): 19 | file_name = f"test-file3-{uuid.uuid4()}.txt" 20 | file_path = os.getcwd() 21 | 22 | path = os.path.join(file_path, file_name) 23 | _SOURCE_TEXT = """Holland is a geographical region and former province on the western coast of the 24 | Netherlands. From the 10th to the 16th century, Holland proper was a unified political 25 | region within the Holy Roman Empire as a county ruled by the counts of Holland. 26 | By the 17th century, the province of Holland had risen to become a maritime and economic power, 27 | dominating the other provinces of the newly independent Dutch Republic.""" 28 | file_utils.create_file(file_path, file_name, content=_SOURCE_TEXT) 29 | 30 | file_id = await client.library.files.create( 31 | file_path=path, 32 | path=file_path, 33 | labels=["label1", "label2"], 34 | public_url="www.example.com", 35 | ) 36 | print(file_id) 37 | 38 | files = await client.library.files.list() 39 | print(files) 40 | uploaded_file = await client.library.files.get(file_id) 41 | print(uploaded_file) 42 | print(uploaded_file.name) 43 | print(uploaded_file.path) 44 | print(uploaded_file.labels) 45 | print(uploaded_file.public_url) 46 | 47 | await client.library.files.update( 48 | file_id, 49 | publicUrl="www.example-updated.com", 50 | labels=["label3", "label4"], 51 | ) 52 | updated_file = await client.library.files.get(file_id) 53 | print(updated_file.name) 54 | print(updated_file.public_url) 55 | print(updated_file.labels) 56 | 57 | await client.library.files.delete(file_id) 58 | try: 59 | uploaded_file = await client.library.files.get(file_id) 60 | except AI21APIError as e: 61 | print(f"File not found. Exception: {e.details}") 62 | 63 | # Cleanup created file 64 | file_utils.delete_file(file_path, file_name) 65 | 66 | 67 | asyncio.run(main()) 68 | -------------------------------------------------------------------------------- /examples/studio/async_tokenization.py: -------------------------------------------------------------------------------- 1 | from ai21.tokenizers import get_async_tokenizer 2 | import asyncio 3 | 4 | prompt = ( 5 | "The following is a conversation between a user of an eCommerce store and a user operation" 6 | " associate called Max. Max is very kind and keen to help." 7 | " The following are important points about the business policies:\n- " 8 | "Delivery takes up to 5 days\n- There is no return option\n\nUser gender:" 9 | " Male.\n\nConversation:\nUser: Hi, had a question\nMax: " 10 | "Hi there, happy to help!\nUser: Is there no way to return a product?" 11 | " I got your blue T-Shirt size small but it doesn't fit.\n" 12 | "Max: I'm sorry to hear that. Unfortunately we don't have a return policy. \n" 13 | "User: That's a shame. \nMax: Is there anything else i can do for you?\n\n" 14 | "##\n\nThe following is a conversation between a user of an eCommerce store and a user operation" 15 | " associate called Max. Max is very kind and keen to help.
The following are important points about" 16 | " the business policies:\n- Delivery takes up to 5 days\n- There is no return option\n\n" 17 | 'User gender: Female.\n\nConversation:\nUser: Hi, I was wondering when you\'ll have the "Blue & White" ' 18 | "t-shirt back in stock?\nMax: Hi, happy to assist! We currently don't have it in stock. Do you want me" 19 | " to send you an email once we do?\nUser: Yes!\nMax: Awesome. What's your email?\nUser: anc@gmail.com\n" 20 | "Max: Great. I'll send you an email as soon as we get it.\n\n##\n\nThe following is a conversation between" 21 | " a user of an eCommerce store and a user operation associate called Max. Max is very kind and keen to help." 22 | " The following are important points about the business policies:\n- Delivery takes up to 5 days\n" 23 | "- There is no return option\n\nUser gender: Female.\n\nConversation:\nUser: Hi, how much time does it" 24 | " take for the product to reach me?\nMax: Hi, happy to assist! It usually takes 5 working" 25 | " days to reach you.\nUser: Got it! thanks. Is there a way to shorten that delivery time if i pay extra?\n" 26 | "Max: I'm sorry, no.\nUser: Got it. How do i know if the White Crisp t-shirt will fit my size?\n" 27 | "Max: The size charts are available on the website.\nUser: Can you tell me what will fit a young women.\n" 28 | "Max: Sure. Can you share her exact size?\n\n##\n\nThe following is a conversation between a user of an" 29 | " eCommerce store and a user operation associate called Max. Max is very kind and keen to help. The following" 30 | " are important points about the business policies:\n- Delivery takes up to 5 days\n" 31 | "- There is no return option\n\nUser gender: Female.\n\nConversation:\n" 32 | "User: Hi, I have a question for you" 33 | ) 34 | 35 | 36 | async def main(): 37 | tokenizer = await get_async_tokenizer(name="jamba-tokenizer") 38 | response = await tokenizer.count_tokens(prompt) 39 | print(response) 40 | 41 | 42 | asyncio.run(main()) 43 | -------------------------------------------------------------------------------- /examples/studio/batches/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI21Labs/ai21-python/624fa7cfeedcac0b41e4d7dddc48b77ac12e3bc9/examples/studio/batches/__init__.py -------------------------------------------------------------------------------- /examples/studio/batches/async_batches.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | from ai21 import AsyncAI21Client 4 | 5 | 6 | async def main(): 7 | client = AsyncAI21Client() 8 | 9 | # Create a batch 10 | batch = await client.batches.create( 11 | file="", 12 | endpoint="/v1/chat/completions", 13 | ) 14 | 15 | print(batch) 16 | 17 | # Retrieve a batch 18 | batch = await client.batches.retrieve(batch.id) 19 | print(batch) 20 | 21 | # List batches 22 | pages = client.batches.list(limit=3) 23 | 24 | async for page in pages: 25 | print([batch.id for batch in page]) 26 | 27 | # Cancel a batch 28 | await client.batches.cancel(batch.id) 29 | 30 | # Get results 31 | results = await client.batches.get_results( 32 | batch_id=batch.id, 33 | file_type="output", 34 | ) 35 | results.write_to_file("") 36 | print(results) 37 | 38 | 39 | if __name__ == "__main__": 40 | asyncio.run(main()) 41 | -------------------------------------------------------------------------------- /examples/studio/batches/batches.py: 
-------------------------------------------------------------------------------- 1 | import json 2 | 3 | from ai21 import AI21Client 4 | 5 | 6 | client = AI21Client()  # reads the API key from the AI21_API_KEY environment variable; never hard-code secrets 7 | 8 | # Create a batch 9 | 10 | batch_requests = [ 11 | { 12 | "custom_id": "request-001", 13 | "method": "POST", 14 | "url": "/v1/chat/completions", 15 | "body": { 16 | "model": "jamba-1.5", 17 | "messages": [{"role": "user", "content": "What is your favorite color?"}], 18 | }, 19 | }, 20 | { 21 | "custom_id": "request-002", 22 | "method": "POST", 23 | "url": "/v1/chat/completions", 24 | "body": { 25 | "model": "jamba-1.5", 26 | "messages": [{"role": "user", "content": "Tell me about your hobbies."}], 27 | }, 28 | }, 29 | { 30 | "custom_id": "request-003", 31 | "method": "POST", 32 | "url": "/v1/chat/completions", 33 | "body": { 34 | "model": "jamba-1.5", 35 | "messages": [{"role": "user", "content": "Tell me about your favorite food."}], 36 | }, 37 | }, 38 | ] 39 | file_name = "batched_file.jsonl" 40 | with open(file_name, "w") as f: 41 | for request in batch_requests: 42 | json.dump(request, f) 43 | f.write("\n") 44 | 45 | batch = client.beta.batches.create( 46 | file=file_name, 47 | endpoint="/v1/chat/completions", 48 | ) 49 | 50 | print(batch) 51 | 52 | # Retrieve a batch 53 | 54 | batch = client.beta.batches.retrieve(batch.id) 55 | print(batch) 56 | 57 | # List batches 58 | 59 | pages = client.beta.batches.list(limit=3) 60 | 61 | for page in pages: 62 | print([batch.id for batch in page]) 63 | 64 | # Cancel a batch 65 | 66 | client.beta.batches.cancel(batch.id) 67 | 68 | 69 | # Get results 70 | 71 | results = client.beta.batches.get_results( 72 | batch_id=batch.id, 73 | file_type="output", 74 | ) 75 | results.write_to_file("") 76 | print(results) 77 | -------------------------------------------------------------------------------- /examples/studio/chat/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI21Labs/ai21-python/624fa7cfeedcac0b41e4d7dddc48b77ac12e3bc9/examples/studio/chat/__init__.py -------------------------------------------------------------------------------- /examples/studio/chat/async_chat_completions.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | from ai21 import AsyncAI21Client 4 | from ai21.models.chat import ChatMessage 5 | 6 | 7 | system = "You're a support engineer in a SaaS company" 8 | messages = [ 9 | ChatMessage(content=system, role="system"), 10 | ChatMessage(content="Hello, I need help with a signup process.", role="user"), 11 | ChatMessage(content="Hi Alice, I can help you with that.
What seems to be the problem?", role="assistant"), 12 | ChatMessage(content="I am having trouble signing up for your product with my Google account.", role="user"), 13 | ] 14 | 15 | client = AsyncAI21Client() 16 | 17 | 18 | async def main(): 19 | response = await client.chat.completions.create( 20 | messages=messages, 21 | model="jamba-mini-1.6-2025-03", 22 | max_tokens=100, 23 | temperature=0.7, 24 | top_p=1.0, 25 | stop=["\n"], 26 | ) 27 | 28 | print(response) 29 | 30 | 31 | asyncio.run(main()) 32 | -------------------------------------------------------------------------------- /examples/studio/chat/async_stream_chat_completions.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | from ai21 import AsyncAI21Client 4 | from ai21.models.chat import ChatMessage 5 | 6 | 7 | system = "You're a support engineer in a SaaS company" 8 | messages = [ 9 | ChatMessage(content=system, role="system"), 10 | ChatMessage(content="Hello, I need help with a signup process.", role="user"), 11 | ChatMessage(content="Hi Alice, I can help you with that. What seems to be the problem?", role="assistant"), 12 | ChatMessage(content="I am having trouble signing up for your product with my Google account.", role="user"), 13 | ] 14 | 15 | client = AsyncAI21Client() 16 | 17 | 18 | async def main(): 19 | response = await client.chat.completions.create( 20 | messages=messages, 21 | model="jamba-mini-1.6-2025-03", 22 | max_tokens=100, 23 | stream=True, 24 | ) 25 | async for chunk in response: 26 | print(chunk.choices[0].delta.content, end="") 27 | 28 | 29 | asyncio.run(main()) 30 | -------------------------------------------------------------------------------- /examples/studio/chat/chat_completions.py: -------------------------------------------------------------------------------- 1 | from ai21 import AI21Client 2 | from ai21.models.chat.chat_message import AssistantMessage, SystemMessage, UserMessage 3 | 4 | 5 | system = "You're a support engineer in a SaaS company" 6 | messages = [ 7 | SystemMessage(content=system, role="system"), 8 | UserMessage(content="Hello, I need help with a signup process.", role="user"), 9 | AssistantMessage(content="Hi Alice, I can help you with that. What seems to be the problem?", role="assistant"), 10 | UserMessage(content="I am having trouble signing up for your product with my Google account.", role="user"), 11 | ] 12 | 13 | client = AI21Client() 14 | 15 | response = client.chat.completions.create( 16 | messages=messages, 17 | model="jamba-mini-1.6-2025-03", 18 | max_tokens=100, 19 | temperature=0.7, 20 | top_p=1.0, 21 | stop=["\n"], 22 | ) 23 | 24 | print(response) 25 | -------------------------------------------------------------------------------- /examples/studio/chat/chat_documents.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | 3 | from ai21 import AI21Client 4 | from ai21.logger import set_verbose 5 | from ai21.models.chat import ChatMessage, DocumentSchema 6 | 7 | 8 | set_verbose(True) 9 | 10 | schnoodel = DocumentSchema( 11 | id=str(uuid.uuid4()), 12 | content="Schnoodel Inc. Annual Report - 2024. Schnoodel Inc., a leader in innovative culinary technology, saw a " 13 | "15% revenue growth this year, reaching $120 million. The launch of SchnoodelChef Pro has significantly " 14 | "contributed, making up 35% of total sales. We've expanded into the Asian market, notably Japan, " 15 | "and increased our global presence. 
Committed to sustainability, we reduced our carbon footprint " 16 | "by 20%. Looking ahead, we plan to integrate more advanced machine learning features and expand " 17 | "into South America.", 18 | metadata={"topic": "revenue"}, 19 | ) 20 | shnokel = DocumentSchema( 21 | id=str(uuid.uuid4()), 22 | content="Shnokel Corp. Annual Report - 2024. Shnokel Corp., a pioneer in renewable energy solutions, " 23 | "reported a 20% increase in revenue this year, reaching $200 million. The successful deployment of " 24 | "our advanced solar panels, SolarFlex, accounted for 40% of our sales. We entered new markets in Europe " 25 | "and have plans to develop wind energy projects next year. Our commitment to reducing environmental " 26 | "impact saw a 25% decrease in operational emissions. Upcoming initiatives include a significant " 27 | "investment in R&D for sustainable technologies.", 28 | metadata={"topic": "revenue"}, 29 | ) 30 | 31 | documents = [schnoodel, shnokel] 32 | 33 | messages = [ 34 | ChatMessage( 35 | role="system", 36 | content="You are a helpful assistant that receives revenue documents and answers related questions", 37 | ), 38 | ChatMessage(role="user", content="Hi, which company earned more during 2024 - Schnoodel or Shnokel?"), 39 | ] 40 | 41 | client = AI21Client() 42 | 43 | response = client.chat.completions.create( 44 | messages=messages, 45 | model="jamba-mini-1.6-2025-03", 46 | documents=documents, 47 | ) 48 | 49 | print(response) 50 | -------------------------------------------------------------------------------- /examples/studio/chat/chat_function_calling.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from ai21 import AI21Client 4 | from ai21.logger import set_verbose 5 | from ai21.models.chat import ( 6 | ChatMessage, 7 | FunctionToolDefinition, 8 | ToolDefinition, 9 | ToolMessage, 10 | ToolParameters, 11 | ) 12 | 13 | 14 | set_verbose(True) 15 | 16 | 17 | def get_order_delivery_date(order_id: str) -> str: 18 | print(f"Retrieving the delivery date for order ID: {order_id} from the database...") 19 | return "2025-05-04" 20 | 21 | 22 | messages = [ 23 | ChatMessage( 24 | role="system", 25 | content="You are a helpful customer support assistant. Use the supplied tools to assist the user.", 26 | ), 27 | ChatMessage(role="user", content="Hi, can you tell me the delivery date for my order?"), 28 | ChatMessage(role="assistant", content="Hi there! I can help with that. Can you please provide your order ID?"), 29 | ChatMessage(role="user", content="i think it is order_12345"), 30 | ] 31 | 32 | tool_definition = ToolDefinition( 33 | type="function", 34 | function=FunctionToolDefinition( 35 | name="get_order_delivery_date", 36 | description="Retrieve the delivery date associated with the specified order ID", 37 | parameters=ToolParameters( 38 | type="object", 39 | properties={"order_id": {"type": "string", "description": "The customer's order ID."}}, 40 | required=["order_id"], 41 | ), 42 | ), 43 | ) 44 | 45 | tools = [tool_definition] 46 | 47 | client = AI21Client() 48 | 49 | response = client.chat.completions.create(messages=messages, model="jamba-large-1.6-2025-03", tools=tools) 50 | 51 | """ Because AI models can be error-prone, it's crucial to ensure that the tool calls align with expectations. 52 | The below code snippet demonstrates how to handle tool calls in the response and invoke the tool function 53 | to get the delivery date for the user's order. After retrieving the delivery date, we pass the response back
After retrieving the delivery date, we pass the response back 54 | to the AI model to continue the conversation, using the ToolMessage object. """ 55 | 56 | assistant_message = response.choices[0].message 57 | messages.append(assistant_message) # Adding the assistant message to the chat history 58 | 59 | delivery_date = None 60 | tool_calls = assistant_message.tool_calls 61 | if tool_calls: 62 | tool_call = tool_calls[0] 63 | if tool_call.function.name == "get_order_delivery_date": 64 | func_arguments = tool_call.function.arguments 65 | func_args_dict = json.loads(func_arguments) 66 | 67 | if "order_id" in func_args_dict: 68 | delivery_date = get_order_delivery_date(func_args_dict["order_id"]) 69 | else: 70 | print("order_id not found in function arguments") 71 | else: 72 | print(f"Unexpected tool call found - {tool_call.function.name}") 73 | else: 74 | print("No tool calls found") 75 | 76 | if delivery_date is not None: 77 | """Continue the conversation by passing the delivery date back to the model""" 78 | 79 | tool_message = ToolMessage(role="tool", tool_call_id=tool_calls[0].id, content=delivery_date) 80 | messages.append(tool_message) 81 | 82 | response = client.chat.completions.create(messages=messages, model="jamba-large-1.6-2025-03", tools=tools) 83 | print(response.choices[0].message.content) 84 | -------------------------------------------------------------------------------- /examples/studio/chat/chat_response_format.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from enum import Enum 4 | 5 | from pydantic import BaseModel 6 | 7 | from ai21 import AI21Client 8 | from ai21.logger import set_verbose 9 | from ai21.models.chat import ChatMessage, ResponseFormat 10 | 11 | 12 | set_verbose(True) 13 | 14 | 15 | class TicketType(Enum): 16 | ADULT = "adult" 17 | CHILD = "child" 18 | 19 | 20 | class ZooTicket(BaseModel): 21 | ticket_type: TicketType 22 | quantity: int 23 | 24 | 25 | class ZooTicketsOrder(BaseModel): 26 | date: str 27 | tickets: list[ZooTicket] 28 | 29 | 30 | messages = [ 31 | ChatMessage( 32 | role="user", 33 | content="Please create a JSON object for ordering zoo tickets for September 22, 2024, " 34 | f"for myself and two kids, based on the following JSON schema: {ZooTicketsOrder.schema()}.", 35 | ) 36 | ] 37 | 38 | client = AI21Client() 39 | 40 | response = client.chat.completions.create( 41 | messages=messages, 42 | model="jamba-large-1.6-2025-03", 43 | max_tokens=800, 44 | temperature=0, 45 | response_format=ResponseFormat(type="json_object"), 46 | ) 47 | 48 | zoo_order_json = json.loads(response.choices[0].message.content) 49 | print(zoo_order_json) 50 | -------------------------------------------------------------------------------- /examples/studio/chat/stream_chat_completions.py: -------------------------------------------------------------------------------- 1 | from ai21 import AI21Client 2 | from ai21.models.chat import ChatMessage 3 | 4 | 5 | system = "You're a support engineer in a SaaS company" 6 | messages = [ 7 | ChatMessage(content=system, role="system"), 8 | ChatMessage(content="Hello, I need help with a signup process.", role="user"), 9 | ChatMessage(content="Hi Alice, I can help you with that. 
What seems to be the problem?", role="assistant"), 10 | ChatMessage(content="I am having trouble signing up for your product with my Google account.", role="user"), 11 | ] 12 | 13 | client = AI21Client() 14 | 15 | response = client.chat.completions.create( 16 | messages=messages, 17 | model="jamba-large-1.6-2025-03", 18 | max_tokens=100, 19 | stream=True, 20 | ) 21 | for chunk in response: 22 | print(chunk.choices[0].delta.content, end="") 23 | -------------------------------------------------------------------------------- /examples/studio/conversational_rag/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI21Labs/ai21-python/624fa7cfeedcac0b41e4d7dddc48b77ac12e3bc9/examples/studio/conversational_rag/__init__.py -------------------------------------------------------------------------------- /examples/studio/conversational_rag/async_conversational_rag.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import os 3 | import uuid 4 | 5 | from ai21 import AsyncAI21Client 6 | from ai21.models.chat import ChatMessage 7 | from examples.studio import file_utils 8 | 9 | messages = [ 10 | ChatMessage(content="", role="user"), 11 | ] 12 | 13 | 14 | async def upload_file(client: AsyncAI21Client): 15 | file_name = f"test-file-{uuid.uuid4()}.txt" 16 | file_path = os.getcwd() 17 | 18 | path = os.path.join(file_path, file_name) 19 | file_utils.create_file(file_path, file_name, "") 20 | 21 | return await client.library.files.create( 22 | file_path=path, 23 | path=file_path, 24 | ) 25 | 26 | 27 | async def delete_file(client: AsyncAI21Client, file_id: str): 28 | await client.library.files.delete(file_id) 29 | 30 | 31 | async def main(): 32 | ai21_client = AsyncAI21Client() 33 | 34 | file_id = await upload_file(ai21_client) 35 | 36 | chat_response = await ai21_client.beta.conversational_rag.create( 37 | messages=messages, 38 | # you may add file IDs from your Studio library in order to question the model about their content 39 | file_ids=[file_id], 40 | ) 41 | 42 | print("Chat Response:", chat_response) 43 | 44 | await delete_file(ai21_client, file_id) 45 | 46 | 47 | if __name__ == "__main__": 48 | asyncio.run(main()) 49 | -------------------------------------------------------------------------------- /examples/studio/conversational_rag/conversational_rag.py: -------------------------------------------------------------------------------- 1 | import os 2 | import uuid 3 | 4 | from ai21 import AI21Client 5 | from ai21.models.chat import ChatMessage 6 | from examples.studio import file_utils 7 | 8 | messages = [ 9 | ChatMessage(content="", role="user"), 10 | ] 11 | 12 | 13 | def upload_file(client: AI21Client): 14 | file_name = f"test-file-{uuid.uuid4()}.txt" 15 | file_path = os.getcwd() 16 | 17 | path = os.path.join(file_path, file_name) 18 | file_utils.create_file(file_path, file_name, "") 19 | 20 | return client.library.files.create( 21 | file_path=path, 22 | path=file_path, 23 | ) 24 | 25 | 26 | def delete_file(client: AI21Client, file_id: str): 27 | client.library.files.delete(file_id) 28 | 29 | 30 | def main(): 31 | ai21_client = AI21Client() 32 | 33 | file_id = upload_file(ai21_client) 34 | 35 | chat_response = ai21_client.beta.conversational_rag.create( 36 | messages=messages, 37 | # you may add file IDs from your Studio library in order to question the model about their content 38 | file_ids=[file_id], 39 | ) 40 | 41 | print("Chat Response:", chat_response) 42 | 43 
| delete_file(ai21_client, file_id) 44 | 45 | 46 | if __name__ == "__main__": 47 | main() 48 | -------------------------------------------------------------------------------- /examples/studio/file_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def create_file(path: str, name: str, content: str): 5 | f = open(os.path.join(path, name), "w") 6 | f.write(content) 7 | f.close() 8 | 9 | 10 | def delete_file(path: str, name: str): 11 | full_file_path = os.path.join(path, name) 12 | if os.path.isfile(full_file_path): 13 | os.remove(full_file_path) 14 | print(f"Deleted file: {full_file_path}") 15 | else: 16 | print(f"Error: file not found: {full_file_path}") 17 | -------------------------------------------------------------------------------- /examples/studio/library.py: -------------------------------------------------------------------------------- 1 | import os 2 | import uuid 3 | 4 | import file_utils 5 | from ai21 import AI21Client, AI21APIError 6 | 7 | client = AI21Client() 8 | 9 | 10 | def validate_file_deleted(): 11 | try: 12 | client.library.files.get(file_id) 13 | except AI21APIError as e: 14 | print(f"File not found. Exception: {e.details}") 15 | 16 | 17 | file_name = f"test-file3-{uuid.uuid4()}.txt" 18 | file_path = os.getcwd() 19 | 20 | path = os.path.join(file_path, file_name) 21 | _SOURCE_TEXT = """Holland is a geographical region and former province on the western coast of the 22 | Netherlands. From the 10th to the 16th century, Holland proper was a unified political 23 | region within the Holy Roman Empire as a county ruled by the counts of Holland. 24 | By the 17th century, the province of Holland had risen to become a maritime and economic power, 25 | dominating the other provinces of the newly independent Dutch Republic.""" 26 | file_utils.create_file(file_path, file_name, content=_SOURCE_TEXT) 27 | 28 | file_id = client.library.files.create( 29 | file_path=path, 30 | path=file_path, 31 | labels=["label1", "label2"], 32 | public_url="www.example.com", 33 | ) 34 | print(file_id) 35 | 36 | files = client.library.files.list() 37 | print(files) 38 | uploaded_file = client.library.files.get(file_id) 39 | print(uploaded_file) 40 | print(uploaded_file.name) 41 | print(uploaded_file.path) 42 | print(uploaded_file.labels) 43 | print(uploaded_file.public_url) 44 | 45 | client.library.files.update( 46 | file_id, 47 | publicUrl="www.example-updated.com", 48 | labels=["label3", "label4"], 49 | ) 50 | updated_file = client.library.files.get(file_id) 51 | print(updated_file.name) 52 | print(updated_file.public_url) 53 | print(updated_file.labels) 54 | 55 | client.library.files.delete(file_id) 56 | try: 57 | uploaded_file = client.library.files.get(file_id) 58 | except AI21APIError as e: 59 | print(f"File not found.
Exception: {e.details}") 60 | 61 | # Cleanup created file 62 | file_utils.delete_file(file_path, file_name) 63 | -------------------------------------------------------------------------------- /examples/studio/maestro/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI21Labs/ai21-python/624fa7cfeedcac0b41e4d7dddc48b77ac12e3bc9/examples/studio/maestro/__init__.py -------------------------------------------------------------------------------- /examples/studio/maestro/async_run.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | from ai21 import AsyncAI21Client 4 | 5 | client = AsyncAI21Client() 6 | 7 | 8 | async def main(): 9 | run_result = await client.beta.maestro.runs.create_and_poll( 10 | input="Write a poem about the ocean", 11 | requirements=[ 12 | { 13 | "name": "length requirement", 14 | "description": "The length of the poem should be less than 1000 characters", 15 | }, 16 | { 17 | "name": "rhyme requirement", 18 | "description": "The poem should rhyme", 19 | }, 20 | ], 21 | include=["requirements_result"], 22 | ) 23 | 24 | print(run_result) 25 | 26 | 27 | if __name__ == "__main__": 28 | asyncio.run(main()) 29 | -------------------------------------------------------------------------------- /examples/studio/maestro/run.py: -------------------------------------------------------------------------------- 1 | from ai21 import AI21Client 2 | 3 | client = AI21Client() 4 | 5 | 6 | def main(): 7 | run_result = client.beta.maestro.runs.create_and_poll( 8 | input="Write a poem about the ocean", 9 | requirements=[ 10 | { 11 | "name": "length requirement", 12 | "description": "The length of the poem should be less than 1000 characters", 13 | }, 14 | { 15 | "name": "rhyme requirement", 16 | "description": "The poem should rhyme", 17 | }, 18 | ], 19 | include=["requirements_result"], 20 | ) 21 | 22 | print(run_result) 23 | 24 | 25 | if __name__ == "__main__": 26 | main() 27 | -------------------------------------------------------------------------------- /examples/studio/tokenization.py: -------------------------------------------------------------------------------- 1 | from ai21.tokenizers import get_tokenizer 2 | 3 | prompt = ( 4 | "The following is a conversation between a user of an eCommerce store and a user operation" 5 | " associate called Max. Max is very kind and keen to help." 6 | " The following are important points about the business policies:\n- " 7 | "Delivery takes up to 5 days\n- There is no return option\n\nUser gender:" 8 | " Male.\n\nConversation:\nUser: Hi, had a question\nMax: " 9 | "Hi there, happy to help!\nUser: Is there no way to return a product?" 10 | " I got your blue T-Shirt size small but it doesn't fit.\n" 11 | "Max: I'm sorry to hear that. Unfortunately we don't have a return policy. \n" 12 | "User: That's a shame. \nMax: Is there anything else i can do for you?\n\n" 13 | "##\n\nThe following is a conversation between a user of an eCommerce store and a user operation" 14 | " associate called Max. Max is very kind and keen to help. The following are important points about" 15 | " the business policies:\n- Delivery takes up to 5 days\n- There is no return option\n\n" 16 | 'User gender: Female.\n\nConversation:\nUser: Hi, I was wondering when you\'ll have the "Blue & White" ' 17 | "t-shirt back in stock?\nMax: Hi, happy to assist! We currently don't have it in stock. 
Do you want me" 18 | " to send you an email once we do?\nUser: Yes!\nMax: Awesome. What's your email?\nUser: anc@gmail.com\n" 19 | "Max: Great. I'll send you an email as soon as we get it.\n\n##\n\nThe following is a conversation between" 20 | " a user of an eCommerce store and a user operation associate called Max. Max is very kind and keen to help." 21 | " The following are important points about the business policies:\n- Delivery takes up to 5 days\n" 22 | "- There is no return option\n\nUser gender: Female.\n\nConversation:\nUser: Hi, how much time does it" 23 | " take for the product to reach me?\nMax: Hi, happy to assist! It usually takes 5 working" 24 | " days to reach you.\nUser: Got it! thanks. Is there a way to shorten that delivery time if i pay extra?\n" 25 | "Max: I'm sorry, no.\nUser: Got it. How do i know if the White Crisp t-shirt will fit my size?\n" 26 | "Max: The size charts are available on the website.\nUser: Can you tell me what will fit a young women.\n" 27 | "Max: Sure. Can you share her exact size?\n\n##\n\nThe following is a conversation between a user of an" 28 | " eCommerce store and a user operation associate called Max. Max is very kind and keen to help. The following" 29 | " are important points about the business policies:\n- Delivery takes up to 5 days\n" 30 | "- There is no return option\n\nUser gender: Female.\n\nConversation:\n" 31 | "User: Hi, I have a question for you" 32 | ) 33 | tokenizer = get_tokenizer(name="jamba-tokenizer") 34 | response = tokenizer.count_tokens(prompt) 35 | print(response) 36 | -------------------------------------------------------------------------------- /examples/vertex/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI21Labs/ai21-python/624fa7cfeedcac0b41e4d7dddc48b77ac12e3bc9/examples/vertex/__init__.py -------------------------------------------------------------------------------- /examples/vertex/async_chat_completions.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | from ai21 import AsyncAI21VertexClient 4 | from ai21.models.chat import ChatMessage 5 | 6 | client = AsyncAI21VertexClient() 7 | 8 | 9 | async def main(): 10 | messages = ChatMessage(content="What is the meaning of life?", role="user") 11 | 12 | completion = await client.chat.completions.create( 13 | model="jamba-1.5-mini", 14 | messages=[messages], 15 | ) 16 | 17 | print(completion) 18 | 19 | 20 | asyncio.run(main()) 21 | -------------------------------------------------------------------------------- /examples/vertex/async_stream_chat_completions.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | from ai21 import AsyncAI21VertexClient 4 | 5 | from ai21.models.chat import ChatMessage 6 | 7 | client = AsyncAI21VertexClient() 8 | 9 | messages = ChatMessage(content="What is the meaning of life?", role="user") 10 | 11 | 12 | async def main(): 13 | completion = await client.chat.completions.create( 14 | model="jamba-1.5-mini", 15 | messages=[messages], 16 | stream=True, 17 | ) 18 | 19 | async for chunk in completion: 20 | print(chunk.choices[0].delta.content, end="") 21 | 22 | 23 | asyncio.run(main()) 24 | -------------------------------------------------------------------------------- /examples/vertex/chat_completions.py: -------------------------------------------------------------------------------- 1 | from ai21 import AI21VertexClient 2 | 3 | from 
ai21.models.chat import ChatMessage 4 | 5 | client = AI21VertexClient() 6 | 7 | messages = ChatMessage(content="What is the meaning of life?", role="user") 8 | 9 | completion = client.chat.completions.create( 10 | model="jamba-1.5-mini", 11 | messages=[messages], 12 | ) 13 | 14 | print(completion) 15 | -------------------------------------------------------------------------------- /examples/vertex/stream_chat_completions.py: -------------------------------------------------------------------------------- 1 | from ai21 import AI21VertexClient 2 | 3 | from ai21.models.chat import ChatMessage 4 | 5 | client = AI21VertexClient() 6 | 7 | messages = ChatMessage(content="What is the meaning of life?", role="user") 8 | 9 | completion = client.chat.completions.create( 10 | model="jamba-1.5-mini", 11 | messages=[messages], 12 | stream=True, 13 | ) 14 | 15 | for chunk in completion: 16 | print(chunk.choices[0].delta.content, end="") 17 | -------------------------------------------------------------------------------- /init.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd "$(dirname "$0")" || exit 4 | 5 | # create .git folder 6 | if [[ ! -d .git ]]; then 7 | git init 8 | fi 9 | 10 | # install the python version specified in .python-version, if not already installed 11 | if pyenv --version; then 12 | pyenv install --skip-existing 13 | fi 14 | 15 | # install poetry if not already installed 16 | if ! poetry --version; then 17 | brew install poetry 18 | fi 19 | 20 | # poetry needs to create the venv with the same python version 21 | poetry env use "$(cat .python-version)" 22 | 23 | # update lock file 24 | poetry lock --no-update 25 | 26 | # install dependencies 27 | poetry install 28 | 29 | # install pre-commit if not already installed 30 | if ! 
pre-commit --version; then 31 | brew install pre-commit 32 | fi 33 | 34 | # install pre-commit hooks 35 | pre-commit install --install-hooks -t pre-commit -t commit-msg 36 | 37 | # shellcheck source=/dev/null 38 | source .venv/bin/activate 39 | -------------------------------------------------------------------------------- /poetry.toml: -------------------------------------------------------------------------------- 1 | [virtualenvs] 2 | in-project = true 3 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 120 3 | target_version = ['py310'] 4 | include = '\.pyi?$' 5 | exclude = ''' 6 | ( 7 | /( 8 | \.eggs # exclude a few common directories in the 9 | | \.git # root of the project 10 | | \.pytest_cache 11 | | \.mypy_cache 12 | | \.venv 13 | | venv 14 | | _build 15 | | build 16 | | dist 17 | | pynguin.egg-info 18 | )/ 19 | ) 20 | ''' 21 | 22 | [tool.isort] 23 | py_version = 310 24 | profile = "black" 25 | combine_as_imports = true 26 | lines_between_types = 1 27 | lines_after_imports = 2 28 | src_paths = [ "ai21", "tests"] 29 | 30 | [tool.coverage.run] 31 | branch = true 32 | source = ["ai21", "tests"] 33 | omit = ["tests/fixtures/*"] 34 | 35 | [tool.coverage.report] 36 | exclude_lines = [ 37 | "pragma: no cover", 38 | "def __repr__", 39 | "def __str__", 40 | "raise AssertionError", 41 | "raise NotImplementedError", 42 | "if __name__ == .__main__.:", 43 | "if TYPE_CHECKING:", 44 | "if typing.TYPE_CHECKING:" 45 | ] 46 | 47 | [tool.poetry] 48 | name = "ai21" 49 | version = "3.3.0" 50 | description = "" 51 | authors = ["AI21 Labs"] 52 | readme = "README.md" 53 | packages = [ 54 | { include = "ai21" } 55 | ] 56 | 57 | [tool.poetry.dependencies] 58 | python = "^3.8" 59 | ai21-tokenizer = ">=0.12.0,<1.0.0" 60 | boto3 = { version = "^1.28.82", optional = true } 61 | typing-extensions = "^4.9.0" 62 | httpx = ">=0.27.0,<1.0.0" 63 | tenacity = "^8.3.0" 64 | google-auth = { version = "^2.31.0", optional = true } 65 | pydantic = ">=1.9.0,<3.0.0" 66 | 67 | 68 | [tool.poetry.group.dev.dependencies] 69 | black = "*" 70 | invoke = "*" 71 | isort = "*" 72 | mypy = "*" 73 | safety = "*" 74 | ruff = "*" 75 | python-semantic-release = "^8.5.0" 76 | pytest = "^7.4.3" 77 | pytest-mock = "^3.12.0" 78 | pytest-asyncio = "^0.21.1" 79 | 80 | [tool.poetry.extras] 81 | AWS = ["boto3"] 82 | Vertex = ["google-auth"] 83 | 84 | [build-system] 85 | requires = ["poetry-core"] 86 | build-backend = "poetry.core.masonry.api" 87 | 88 | [tool.coverage.html] 89 | directory = "cov_html" 90 | 91 | [tool.coverage.xml] 92 | directory = "coverage.xml" 93 | 94 | [tool.commitizen] 95 | name = "cz_customize" 96 | 97 | [tool.commitizen.customize] 98 | schema_pattern = "(build|ci|docs|feat|fix|perf|refactor|style|test|chore|revert|bump):(\\s.*)" 99 | 100 | [tool.semantic_release] 101 | version_toml = [ 102 | "pyproject.toml:tool.poetry.version" 103 | ] 104 | version_variables = [ 105 | "ai21/version.py:VERSION" 106 | ] 107 | match = "(main)" 108 | build_command = "pip install poetry && poetry build" 109 | version_source = "tag" 110 | commit_version_number = true 111 | commit_message = "chore(release): v{version} [skip ci]" 112 | 113 | [tool.semantic_release.branches.main] 114 | match = "(main)" 115 | 116 | [tool.semantic_release.branches."Release Candidates"] 117 | match = "(rc_*)" 118 | prerelease_token = "rc" 119 | prerelease = true 120 | 121 | [tool.ruff] 122 | line-length 
= 120 123 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | 3 | import os 4 | import codecs 5 | 6 | from setuptools import setup, find_packages 7 | from ai21.version import VERSION 8 | 9 | readme_path = os.path.abspath(os.path.dirname(__file__)) 10 | 11 | with codecs.open(os.path.join(readme_path, "README.md"), encoding="utf-8") as fh: 12 | long_description = "\n" + fh.read() 13 | 14 | setup( 15 | name="ai21", 16 | version=VERSION, 17 | author="AI21 Labs", 18 | author_email="support@ai21.com", 19 | long_description_content_type="text/markdown", 20 | long_description=long_description, 21 | packages=find_packages(exclude=["tests", "tests.*"]), 22 | keywords=["python", "sdk", "ai", "ai21", "jurassic", "ai21-python", "llm"], 23 | install_requires=[ 24 | "httpx", 25 | ], 26 | ) 27 | -------------------------------------------------------------------------------- /tasks.py: -------------------------------------------------------------------------------- 1 | """ 2 | tasks to kick-off the project 3 | """ 4 | 5 | from invoke import task 6 | 7 | 8 | @task 9 | def formatter(tsk, fix=False): 10 | """ 11 | Run the Python formatter 12 | """ 13 | auto_fix = "" if fix else "--check --diff" 14 | cmd = " && ".join( 15 | [ 16 | f"python -m black . {auto_fix}", 17 | ] 18 | ) 19 | tsk.run(cmd, echo=True, pty=True) 20 | 21 | 22 | @task 23 | def lint(tsk, fix=False): 24 | """ 25 | Run the Python linter 26 | """ 27 | flags = "--fix" if fix else "" 28 | cmd = " && ".join( 29 | [ 30 | f"ruff *.py {flags} ai21/ tests/", 31 | ] 32 | ) 33 | tsk.run(cmd, echo=True, pty=True) 34 | 35 | 36 | @task(optional=["coverage"], help={"coverage": "[true|false]"}) 37 | def test(tsk, coverage=False): 38 | """ 39 | Run unit tests 40 | """ 41 | cov = "--cov --cov-report=term-missing" if coverage else "" 42 | cmd = f"poetry run pytest {cov}" 43 | tsk.run(cmd, echo=True, pty=True) 44 | 45 | 46 | @task 47 | def audit(tsk): 48 | """ 49 | Run audit check on the dependent packages 50 | """ 51 | cmd = "safety check --full-report" 52 | tsk.run(cmd, echo=True, pty=True) 53 | 54 | 55 | @task 56 | def staticcheck(tsk): 57 | """ 58 | Run static checks on the project's files 59 | """ 60 | cmd = "mypy ai21 tests" 61 | tsk.run(cmd, echo=True, pty=True) 62 | 63 | 64 | @task 65 | def isort(tsk): 66 | """ 67 | Sort imports in the project's files 68 | """ 69 | cmd = "isort ai21 tests" 70 | tsk.run(cmd, echo=True, pty=True) 71 | 72 | 73 | @task 74 | def build(tsk): 75 | """ 76 | generate a package for ai21 77 | """ 78 | cmd = "poetry build" 79 | tsk.run(cmd, echo=True, pty=True) 80 | 81 | 82 | @task 83 | def update(tsk): 84 | """ 85 | update outdated packages 86 | """ 87 | cmd = "poetry update" 88 | tsk.run(cmd, echo=True, pty=True) 89 | 90 | 91 | @task 92 | def outdated(tsk): 93 | """ 94 | list outdated packages 95 | """ 96 | cmd = "poetry show --outdated --top-level" 97 | tsk.run(cmd, echo=True, pty=True) 98 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI21Labs/ai21-python/624fa7cfeedcac0b41e4d7dddc48b77ac12e3bc9/tests/__init__.py -------------------------------------------------------------------------------- /tests/integration_tests/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI21Labs/ai21-python/624fa7cfeedcac0b41e4d7dddc48b77ac12e3bc9/tests/integration_tests/__init__.py -------------------------------------------------------------------------------- /tests/integration_tests/clients/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI21Labs/ai21-python/624fa7cfeedcac0b41e4d7dddc48b77ac12e3bc9/tests/integration_tests/clients/__init__.py -------------------------------------------------------------------------------- /tests/integration_tests/clients/bedrock/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI21Labs/ai21-python/624fa7cfeedcac0b41e4d7dddc48b77ac12e3bc9/tests/integration_tests/clients/bedrock/__init__.py -------------------------------------------------------------------------------- /tests/integration_tests/clients/bedrock/test_chat_completions.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from ai21 import AI21BedrockClient, AsyncAI21BedrockClient, BedrockModelID 4 | from ai21.models._pydantic_compatibility import _to_dict 5 | from ai21.models.chat import ChatMessage 6 | 7 | 8 | _SYSTEM_MSG = "You're a support engineer in a SaaS company" 9 | _MESSAGES = [ 10 | ChatMessage(content=_SYSTEM_MSG, role="system"), 11 | ChatMessage(content="Hello, I need help with a signup process.", role="user"), 12 | ] 13 | 14 | 15 | def test_chat_completions__when_stream__last_chunk_should_hold_bedrock_metrics(): 16 | client = AI21BedrockClient() 17 | response = client.chat.completions.create( 18 | messages=_MESSAGES, 19 | model=BedrockModelID.JAMBA_1_5_MINI, 20 | stream=True, 21 | ) 22 | 23 | last_chunk = list(response)[-1] 24 | chunk_dict = _to_dict(last_chunk) 25 | assert "amazon-bedrock-invocationMetrics" in chunk_dict 26 | 27 | 28 | @pytest.mark.asyncio 29 | async def test__async_chat_completions__when_stream__last_chunk_should_hold_bedrock_metrics(): 30 | client = AsyncAI21BedrockClient() 31 | response = await client.chat.completions.create( 32 | messages=_MESSAGES, 33 | model=BedrockModelID.JAMBA_1_5_MINI, 34 | stream=True, 35 | ) 36 | 37 | last_chunk = [chunk async for chunk in response][-1] 38 | chunk_dict = _to_dict(last_chunk) 39 | assert "amazon-bedrock-invocationMetrics" in chunk_dict 40 | -------------------------------------------------------------------------------- /tests/integration_tests/clients/resources/library_file.txt: -------------------------------------------------------------------------------- 1 | Albert Einstein was a renowned physicist who made significant contributions to the field of theoretical physics. Born on March 14, 1879, in Ulm, in the Kingdom of Württemberg in the German Empire, Einstein's early life showed signs of his later intellectual prowess. 2 | 3 | Einstein attended the Swiss Federal Institute of Technology in Zurich, where he studied physics and mathematics. Despite facing challenges and financial difficulties, he persevered in his studies and graduated in 1900. After graduation, he struggled to secure a teaching position but eventually found work as a patent examiner at the Swiss Patent Office. 4 | 5 | In 1905, often referred to as his "miracle year," Einstein published four groundbreaking papers that transformed the scientific landscape. 
These papers covered the photoelectric effect, Brownian motion, special relativity, and the famous equation E=mc², demonstrating the equivalence of mass and energy. 6 | 7 | His theory of special relativity, published in the paper "On the Electrodynamics of Moving Bodies," challenged traditional notions of space and time. It introduced the concept of spacetime and showed that time is relative, depending on the observer's motion. 8 | 9 | In 1915, Einstein presented the general theory of relativity, providing a new understanding of gravitation. According to general relativity, massive objects like planets and stars cause a curvature in spacetime, influencing the motion of other objects. This theory successfully explained phenomena like the bending of light around massive objects. 10 | 11 | Einstein's work laid the foundation for modern cosmology and astrophysics. His predictions, such as the bending of light by gravity, were later confirmed through experiments and observations. 12 | 13 | Apart from his scientific endeavors, Einstein was an advocate for civil rights, pacifism, and Zionism. He spoke out against discrimination and injustice, using his platform to promote social and political causes. In 1933, Einstein fled Nazi Germany and settled in the United States, where he continued his scientific research. 14 | 15 | Einstein received the Nobel Prize in Physics in 1921 for his explanation of the photoelectric effect. Despite his immense contributions to science, he remained humble and often expressed a deep curiosity about the mysteries of the universe. 16 | 17 | In the latter part of his life, Einstein worked towards a unified field theory, attempting to combine electromagnetism and gravity into a single framework. However, this goal remained elusive, and Einstein's efforts in this direction were not as successful as his earlier work. 18 | 19 | Albert Einstein passed away on April 18, 1955, leaving behind a legacy that continues to shape our understanding of the physical world. His intellectual brilliance, coupled with his commitment to social justice, has made him an enduring symbol of scientific achievement and moral responsibility. The impact of Einstein's ideas extends far beyond the realm of physics, influencing fields as diverse as philosophy, literature, and popular culture. 
20 | -------------------------------------------------------------------------------- /tests/integration_tests/clients/studio/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI21Labs/ai21-python/624fa7cfeedcac0b41e4d7dddc48b77ac12e3bc9/tests/integration_tests/clients/studio/__init__.py -------------------------------------------------------------------------------- /tests/integration_tests/clients/studio/conftest.py: -------------------------------------------------------------------------------- 1 | import time 2 | from pathlib import Path 3 | 4 | import pytest 5 | 6 | from ai21 import AI21Client 7 | 8 | LIBRARY_FILE_TO_UPLOAD = str(Path(__file__).parent.parent / "resources" / "library_file.txt") 9 | DEFAULT_LABELS = ["einstein", "science"] 10 | 11 | 12 | def _wait_for_file_to_process(client: AI21Client, file_id: str, timeout: float = 60): 13 | start_time = time.time() 14 | 15 | elapsed_time = time.time() - start_time 16 | while elapsed_time < timeout: 17 | file_response = client.library.files.get(file_id) 18 | 19 | if file_response.status == "PROCESSED": 20 | return 21 | 22 | elapsed_time = time.time() - start_time 23 | time.sleep(2) 24 | 25 | raise TimeoutError(f"Timeout: {timeout} seconds passed. File processing not completed") 26 | 27 | 28 | def _delete_uploaded_file(client: AI21Client, file_id: str): 29 | _wait_for_file_to_process(client, file_id) 30 | client.library.files.delete(file_id) 31 | 32 | 33 | @pytest.fixture(scope="module") 34 | def file_in_library(): 35 | """ 36 | Uploads a file to the library and deletes it after the test is done. 37 | This happens at module scope, so the file is uploaded only once. 38 | :return: file_id: str 39 | """ 40 | client = AI21Client() 41 | 42 | # Delete any file that might be in the library due to failed tests 43 | files = client.library.files.list() 44 | for file in files: 45 | _delete_uploaded_file(client=client, file_id=file.file_id) 46 | 47 | file_id = client.library.files.create(file_path=LIBRARY_FILE_TO_UPLOAD, labels=DEFAULT_LABELS) 48 | _wait_for_file_to_process(client, file_id) 49 | yield file_id 50 | _delete_uploaded_file(client=client, file_id=file_id) 51 | -------------------------------------------------------------------------------- /tests/integration_tests/clients/studio/test_library.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from time import sleep 3 | import pytest 4 | 5 | from ai21 import AI21Client, AsyncAI21Client 6 | from tests.integration_tests.clients.studio.conftest import LIBRARY_FILE_TO_UPLOAD, DEFAULT_LABELS 7 | 8 | 9 | @pytest.mark.skipif 10 | def test_library__when_upload__should_get_file_id(file_in_library: str): 11 | assert file_in_library is not None 12 | 13 | 14 | @pytest.mark.skipif 15 | def test_library__when_list__should_get_file_id_in_list_of_files(file_in_library: str): 16 | client = AI21Client() 17 | 18 | files = client.library.files.list() 19 | assert files[0].file_id == file_in_library 20 | assert files[0].name == Path(LIBRARY_FILE_TO_UPLOAD).name 21 | 22 | 23 | @pytest.mark.skipif 24 | def test_library__when_get__should_match_file_id(file_in_library: str): 25 | client = AI21Client() 26 | 27 | file_response = client.library.files.get(file_in_library) 28 | assert file_response.file_id == file_in_library 29 | 30 | 31 | @pytest.mark.skipif 32 | def test_library__when_update__should_update_labels_successfully(file_in_library: str): 33 | client =
AI21Client() 34 | 35 | file_response = client.library.files.get(file_in_library) 36 | assert set(file_response.labels) == set(DEFAULT_LABELS) 37 | sleep(2) 38 | 39 | new_labels = DEFAULT_LABELS + ["new_label"] 40 | client.library.files.update(file_in_library, labels=new_labels) 41 | file_response = client.library.files.get(file_in_library) 42 | assert set(file_response.labels) == set(new_labels) 43 | 44 | 45 | @pytest.mark.skipif 46 | @pytest.mark.asyncio 47 | async def test_async_library__when_list__should_get_file_id_in_list_of_files(file_in_library: str): 48 | client = AsyncAI21Client() 49 | 50 | files = await client.library.files.list() 51 | assert files[0].file_id == file_in_library 52 | assert files[0].name == Path(LIBRARY_FILE_TO_UPLOAD).name 53 | 54 | 55 | @pytest.mark.skipif 56 | @pytest.mark.asyncio 57 | async def test_async_library__when_get__should_match_file_id(file_in_library: str): 58 | client = AsyncAI21Client() 59 | 60 | file_response = await client.library.files.get(file_in_library) 61 | assert file_response.file_id == file_in_library 62 | 63 | 64 | @pytest.mark.skipif 65 | @pytest.mark.asyncio 66 | async def test_async_library__when_update__should_update_labels_successfully(file_in_library: str): 67 | client = AsyncAI21Client() 68 | curr_labels = DEFAULT_LABELS + ["new_label"] 69 | file_response = await client.library.files.get(file_in_library) 70 | assert set(file_response.labels) == set(curr_labels) 71 | sleep(2) 72 | 73 | new_labels = curr_labels + ["new_label2"] 74 | await client.library.files.update(file_in_library, labels=new_labels) 75 | file_response = await client.library.files.get(file_in_library) 76 | print(file_response.labels) 77 | assert set(file_response.labels) == set(new_labels) 78 | -------------------------------------------------------------------------------- /tests/integration_tests/clients/studio/test_maestro.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from ai21 import AsyncAI21Client 4 | 5 | 6 | @pytest.mark.asyncio 7 | async def test_maestro__when_upload__should_return_data_sources(): # file_in_library: str): 8 | client = AsyncAI21Client() 9 | result = await client.beta.maestro.runs.create_and_poll( 10 | input="When did Einstein receive a Nobel Prize?", tools=[{"type": "file_search"}], include=["data_sources"] 11 | ) 12 | assert result.status == "completed", "Expected 'completed' status" 13 | assert result.result, "Expected a non-empty answer" 14 | assert result.data_sources, "Expected data sources" 15 | assert len(result.data_sources["file_search"]) > 0, "Expected at least one file search data source" 16 | assert result.data_sources.get("web_search") is None, "Expected no web search data sources" 17 | -------------------------------------------------------------------------------- /tests/integration_tests/clients/test_bedrock.py: -------------------------------------------------------------------------------- 1 | """ 2 | Run this script after setting the AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables 3 | """ 4 | 5 | import subprocess 6 | 7 | from pathlib import Path 8 | 9 | import pytest 10 | 11 | from tests.integration_tests.skip_helpers import should_skip_bedrock_integration_tests 12 | 13 | 14 | BEDROCK_DIR = "bedrock" 15 | 16 | BEDROCK_PATH = Path(__file__).parent.parent.parent.parent / "examples" / BEDROCK_DIR 17 | 18 | 19 | @pytest.mark.skipif(should_skip_bedrock_integration_tests(), reason="No keys supplied for AWS.
Skipping.") 20 | @pytest.mark.parametrize( 21 | argnames=["test_file_name"], 22 | argvalues=[ 23 | ("chat/chat_completions.py",), 24 | ("chat/stream_chat_completions.py",), 25 | ("chat/async_chat_completions.py",), 26 | ("chat/async_stream_chat_completions.py",), 27 | ], 28 | ids=[ 29 | "when_chat_completions__should_return_ok", 30 | "when_stream_chat_completions__should_return_ok", 31 | "when_async_chat_completions__should_return_ok", 32 | "when_async_stream_chat_completions__should_return_ok", 33 | ], 34 | ) 35 | def test_bedrock(test_file_name: str): 36 | file_path = BEDROCK_PATH / test_file_name 37 | print(f"About to run: {file_path}") 38 | exit_code = subprocess.call(["python", file_path]) 39 | assert exit_code == 0, f"failed to run {test_file_name}" 40 | -------------------------------------------------------------------------------- /tests/integration_tests/clients/test_studio.py: -------------------------------------------------------------------------------- 1 | """ 2 | Run this script after setting the AI21_API_KEY environment variable 3 | """ 4 | 5 | import subprocess 6 | 7 | from pathlib import Path 8 | from time import sleep 9 | 10 | import pytest 11 | 12 | from tests.integration_tests.skip_helpers import should_skip_studio_integration_tests 13 | 14 | 15 | STUDIO_PATH = Path(__file__).parent.parent.parent.parent / "examples" / "studio" 16 | 17 | 18 | @pytest.mark.skipif(should_skip_studio_integration_tests(), reason="No key supplied for AI21 Studio. Skipping.") 19 | @pytest.mark.parametrize( 20 | argnames=["test_file_name"], 21 | argvalues=[ 22 | ("tokenization.py",), 23 | ("chat/chat_completions.py",), 24 | ("chat/stream_chat_completions.py",), 25 | ("chat/chat_documents.py",), 26 | ("chat/chat_function_calling.py",), 27 | ("chat/chat_function_calling_multiple_tools.py",), 28 | ("chat/chat_response_format.py",), 29 | ], 30 | ids=[ 31 | "when_tokenization__should_return_ok", 32 | "when_chat_completions__should_return_ok", 33 | "when_stream_chat_completions__should_return_ok", 34 | "when_chat_completions_with_documents__should_return_ok", 35 | "when_chat_completions_with_function_calling__should_return_ok", 36 | "when_chat_completions_with_function_calling_multiple_tools__should_return_ok", 37 | "when_chat_completions_with_response_format__should_return_ok", 38 | ], 39 | ) 40 | def test_studio(test_file_name: str): 41 | file_path = STUDIO_PATH / test_file_name 42 | print(f"About to run: {file_path}") 43 | sleep(0.5) 44 | exit_code = subprocess.call(["python", file_path]) 45 | assert exit_code == 0, f"failed to run {test_file_name}" 46 | 47 | 48 | @pytest.mark.asyncio 49 | @pytest.mark.skipif(should_skip_studio_integration_tests(), reason="No key supplied for AI21 Studio.
Skipping.") 50 | @pytest.mark.parametrize( 51 | argnames=["test_file_name"], 52 | argvalues=[ 53 | ("chat/async_chat_completions.py",), 54 | ("chat/async_stream_chat_completions.py",), 55 | ("conversational_rag/conversational_rag.py",), 56 | ("conversational_rag/async_conversational_rag.py",), 57 | ("maestro/run.py",), 58 | ("maestro/async_run.py",), 59 | ], 60 | ids=[ 61 | "when_chat_completions__should_return_ok", 62 | "when_stream_chat_completions__should_return_ok", 63 | "when_conversational_rag__should_return_ok", 64 | "when_async_conversational_rag__should_return_ok", 65 | "when_maestro_runs__should_return_ok", 66 | "when_maestro_async_runs__should_return_ok", 67 | ], 68 | ) 69 | async def test_async_studio(test_file_name: str): 70 | file_path = STUDIO_PATH / test_file_name 71 | print(f"About to run: {file_path}") 72 | exit_code = subprocess.call(["python", file_path]) 73 | assert exit_code == 0, f"failed to run {test_file_name}" 74 | -------------------------------------------------------------------------------- /tests/integration_tests/clients/test_vertex.py: -------------------------------------------------------------------------------- 1 | """ 2 | Run this script after setting up Google Cloud credentials for Vertex AI 3 | """ 4 | 5 | import subprocess 6 | from pathlib import Path 7 | 8 | import pytest 9 | 10 | from tests.integration_tests.skip_helpers import should_skip_vertex_integration_tests 11 | 12 | VERTEX_DIR = "vertex" 13 | 14 | VERTEX_PATH = Path(__file__).parent.parent.parent.parent / "examples" / VERTEX_DIR 15 | 16 | 17 | @pytest.mark.skipif(should_skip_vertex_integration_tests(), reason="No keys supplied for Vertex. Skipping.") 18 | @pytest.mark.parametrize( 19 | argnames=["test_file_name"], 20 | argvalues=[ 21 | ("chat_completions.py",), 22 | ("stream_chat_completions.py",), 23 | ("async_chat_completions.py",), 24 | ("async_stream_chat_completions.py",), 25 | ], 26 | ids=[ 27 | "when_chat_completions__should_return_ok", 28 | "when_stream_chat_completions__should_return_ok", 29 | "when_async_chat_completions__should_return_ok", 30 | "when_async_stream_chat_completions__should_return_ok", 31 | ], 32 | ) 33 | def test_vertex(test_file_name: str): 34 | file_path = VERTEX_PATH / test_file_name 35 | print(f"About to run: {file_path}") 36 | exit_code = subprocess.call(["python", file_path]) 37 | assert exit_code == 0, f"failed to run {test_file_name}" 38 | -------------------------------------------------------------------------------- /tests/integration_tests/skip_helpers.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def should_skip_bedrock_integration_tests() -> bool: 5 | return os.getenv("AWS_ACCESS_KEY_ID") is None or os.getenv("AWS_SECRET_ACCESS_KEY") is None 6 | 7 | 8 | def should_skip_studio_integration_tests() -> bool: 9 | return os.getenv("AI21_API_KEY") is None 10 | 11 | 12 | def should_skip_vertex_integration_tests() -> bool: 13 | return True 14 | -------------------------------------------------------------------------------- /tests/unittests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI21Labs/ai21-python/624fa7cfeedcac0b41e4d7dddc48b77ac12e3bc9/tests/unittests/__init__.py -------------------------------------------------------------------------------- /tests/unittests/clients/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI21Labs/ai21-python/624fa7cfeedcac0b41e4d7dddc48b77ac12e3bc9/tests/unittests/clients/__init__.py -------------------------------------------------------------------------------- /tests/unittests/clients/azure/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI21Labs/ai21-python/624fa7cfeedcac0b41e4d7dddc48b77ac12e3bc9/tests/unittests/clients/azure/__init__.py -------------------------------------------------------------------------------- /tests/unittests/clients/azure/test_ai21_azure_client.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from ai21 import AI21AzureClient 4 | 5 | 6 | def test__when_with_api_key__should_be_ok(): 7 | client = AI21AzureClient(base_url="https://example.com", api_key="test_api_key") 8 | assert client._api_key == "test_api_key" 9 | assert client._azure_ad_token is None 10 | assert client._azure_ad_token_provider is None 11 | 12 | 13 | def test__when_with_azure_ad_token__should_be_ok(): 14 | client = AI21AzureClient(base_url="https://example.com", azure_ad_token="test_azure_ad_token") 15 | assert client._azure_ad_token == "test_azure_ad_token" 16 | assert client._api_key is None 17 | assert client._azure_ad_token_provider is None 18 | 19 | 20 | def test__when_with_azure_ad_token_provider__should_be_ok(): 21 | def token_provider(): 22 | return "test_azure_ad_token" 23 | 24 | client = AI21AzureClient(base_url="https://example.com", azure_ad_token_provider=token_provider) 25 | assert client._azure_ad_token_provider == token_provider 26 | assert client._api_key is None 27 | assert client._azure_ad_token is None 28 | 29 | 30 | def test__when_without_any_token_or_key__should_raise_error(): 31 | with pytest.raises(ValueError, match="Must provide either api_key or azure_ad_token_provider or azure_ad_token"): 32 | AI21AzureClient(base_url="https://example.com") 33 | -------------------------------------------------------------------------------- /tests/unittests/clients/azure/test_chat_completions.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from ai21 import AI21AzureClient 4 | 5 | 6 | def test__azure_client__when_init_with_no_auth__should_raise_error(): 7 | with pytest.raises(ValueError) as e: 8 | AI21AzureClient(base_url="http://some_endpoint_url") 9 | 10 | assert str(e.value) == "Must provide either api_key or azure_ad_token_provider or azure_ad_token" 11 | -------------------------------------------------------------------------------- /tests/unittests/clients/bedrock/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI21Labs/ai21-python/624fa7cfeedcac0b41e4d7dddc48b77ac12e3bc9/tests/unittests/clients/bedrock/__init__.py -------------------------------------------------------------------------------- /tests/unittests/clients/studio/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI21Labs/ai21-python/624fa7cfeedcac0b41e4d7dddc48b77ac12e3bc9/tests/unittests/clients/studio/__init__.py -------------------------------------------------------------------------------- /tests/unittests/clients/studio/resources/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/AI21Labs/ai21-python/624fa7cfeedcac0b41e4d7dddc48b77ac12e3bc9/tests/unittests/clients/studio/resources/__init__.py -------------------------------------------------------------------------------- /tests/unittests/clients/studio/resources/chat/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI21Labs/ai21-python/624fa7cfeedcac0b41e4d7dddc48b77ac12e3bc9/tests/unittests/clients/studio/resources/chat/__init__.py -------------------------------------------------------------------------------- /tests/unittests/clients/studio/resources/test_async_studio_resource.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, TypeVar 2 | 3 | import pytest 4 | 5 | from ai21.clients.studio.resources.studio_resource import AsyncStudioResource 6 | from ai21.http_client.async_http_client import AsyncAI21HTTPClient 7 | from ai21.models.ai21_base_model import AI21BaseModel 8 | from tests.unittests.clients.studio.resources.conftest import ( 9 | get_chat_completions, 10 | get_studio_chat, 11 | ) 12 | 13 | 14 | _BASE_URL = "https://test.api.ai21.com/studio/v1" 15 | 16 | T = TypeVar("T", bound=AsyncStudioResource) 17 | 18 | 19 | class TestAsyncStudioResources: 20 | @pytest.mark.asyncio 21 | @pytest.mark.parametrize( 22 | ids=[ 23 | "async_studio_chat", 24 | "async_chat_completions", 25 | ], 26 | argnames=[ 27 | "studio_resource", 28 | "function_body", 29 | "url_suffix", 30 | "expected_body", 31 | "expected_httpx_response", 32 | "expected_response", 33 | ], 34 | argvalues=[ 35 | (get_studio_chat(is_async=True)), 36 | (get_chat_completions(is_async=True)), 37 | ], 38 | ) 39 | async def test__create__should_return_response( 40 | self, 41 | studio_resource: Callable[[AsyncAI21HTTPClient], T], 42 | function_body, 43 | url_suffix: str, 44 | expected_body, 45 | expected_httpx_response, 46 | expected_response: AI21BaseModel, 47 | mock_async_ai21_studio_client: AsyncAI21HTTPClient, 48 | ): 49 | mock_async_ai21_studio_client.execute_http_request.return_value = expected_httpx_response 50 | 51 | resource = studio_resource(mock_async_ai21_studio_client) 52 | 53 | actual_response = await resource.create( 54 | **function_body, 55 | ) 56 | 57 | assert actual_response == expected_response 58 | mock_async_ai21_studio_client.execute_http_request.assert_called_with( 59 | method="POST", 60 | path=f"/{url_suffix}", 61 | body=expected_body, 62 | params={}, 63 | stream=False, 64 | files=None, 65 | ) 66 | -------------------------------------------------------------------------------- /tests/unittests/clients/studio/resources/test_chat.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from ai21 import AI21Client, AsyncAI21Client 3 | from ai21.models.chat import ChatMessage 4 | 5 | _DUMMY_API_KEY = "dummy_api_key" 6 | 7 | 8 | def test_chat_create__when_bad_import_to_chat_message__raise_error(): 9 | with pytest.raises(ValueError) as e: 10 | AI21Client(api_key=_DUMMY_API_KEY).chat.create( 11 | model="j2-ultra", messages=[ChatMessage(role="user", content="Hello")], system="System Test" 12 | ) 13 | 14 | assert ( 15 | e.value.args[0] 16 | == "Please use the ChatMessage class from ai21.models instead of ai21.models.chat when working with chat" 17 | ) 18 | 19 | 20 | @pytest.mark.asyncio 21 | async def test_async_chat_create__when_bad_import_to_chat_message__raise_error(): 22 | with pytest.raises(ValueError) as e: 23 | 
await AsyncAI21Client(api_key=_DUMMY_API_KEY).chat.create( 24 | model="j2-ultra", messages=[ChatMessage(role="user", content="Hello")], system="System Test" 25 | ) 26 | 27 | assert ( 28 | e.value.args[0] 29 | == "Please use the ChatMessage class from ai21.models instead of ai21.models.chat when working with chat" 30 | ) 31 | -------------------------------------------------------------------------------- /tests/unittests/clients/studio/resources/test_studio_resources.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, TypeVar 2 | 3 | import pytest 4 | 5 | from ai21.clients.studio.resources.studio_resource import StudioResource 6 | from ai21.http_client.http_client import AI21HTTPClient 7 | from ai21.models.ai21_base_model import AI21BaseModel 8 | from tests.unittests.clients.studio.resources.conftest import ( 9 | get_chat_completions, 10 | get_studio_chat, 11 | ) 12 | 13 | 14 | _BASE_URL = "https://test.api.ai21.com/studio/v1" 15 | T = TypeVar("T", bound=StudioResource) 16 | 17 | 18 | class TestStudioResources: 19 | @pytest.mark.parametrize( 20 | ids=[ 21 | "studio_chat", 22 | "chat_completions", 23 | ], 24 | argnames=[ 25 | "studio_resource", 26 | "function_body", 27 | "url_suffix", 28 | "expected_body", 29 | "expected_httpx_response", 30 | "expected_response", 31 | ], 32 | argvalues=[ 33 | (get_studio_chat()), 34 | (get_chat_completions()), 35 | ], 36 | ) 37 | def test__create__should_return_response( 38 | self, 39 | studio_resource: Callable[[AI21HTTPClient], T], 40 | function_body, 41 | url_suffix: str, 42 | expected_body, 43 | expected_httpx_response, 44 | expected_response: AI21BaseModel, 45 | mock_ai21_studio_client: AI21HTTPClient, 46 | ): 47 | mock_ai21_studio_client.execute_http_request.return_value = expected_httpx_response 48 | 49 | resource = studio_resource(mock_ai21_studio_client) 50 | 51 | actual_response = resource.create( 52 | **function_body, 53 | ) 54 | 55 | assert actual_response == expected_response 56 | mock_ai21_studio_client.execute_http_request.assert_called_with( 57 | method="POST", 58 | path=f"/{url_suffix}", 59 | body=expected_body, 60 | params={}, 61 | stream=False, 62 | files=None, 63 | ) 64 | -------------------------------------------------------------------------------- /tests/unittests/clients/studio/test_ai21_client.py: -------------------------------------------------------------------------------- 1 | from ai21 import AI21Client, AI21EnvConfig 2 | 3 | 4 | def test_ai21_client__when_pass_api_host__should_leave_as_is(): 5 | base_url = "https://dont-modify-me.com" 6 | client = AI21Client(api_host=base_url) 7 | assert client._base_url == base_url 8 | 9 | 10 | def test_ai21_client__when_not_pass_api_host__should_be_studio_host(): 11 | client = AI21Client() 12 | assert client._base_url == AI21EnvConfig.api_host 13 | 14 | 15 | def test_ai21_client__when_pass_ai21_with_suffix__should_not_modify(): 16 | ai21_url = "https://api.ai21.com/studio/v1" 17 | client = AI21Client(api_host=ai21_url) 18 | assert client._base_url == ai21_url 19 | -------------------------------------------------------------------------------- /tests/unittests/clients/studio/test_async_ai21_client.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from ai21 import AI21EnvConfig, AsyncAI21Client 4 | 5 | 6 | @pytest.mark.asyncio 7 | async def test_async_ai21_client__when_pass_api_host__should_leave_as_is(): 8 | base_url = "https://dont-modify-me.com" 9 | client =
AsyncAI21Client(api_host=base_url) 10 | assert client._base_url == base_url 11 | 12 | 13 | @pytest.mark.asyncio 14 | async def test_async_ai21_client__when_not_pass_api_host__should_be_studio_host(): 15 | client = AsyncAI21Client() 16 | assert client._base_url == AI21EnvConfig.api_host 17 | 18 | 19 | @pytest.mark.asyncio 20 | async def test_async_ai21_client__when_pass_ai21_with_suffix__should_not_modify(): 21 | ai21_url = "https://api.ai21.com/studio/v1" 22 | client = AsyncAI21Client(api_host=ai21_url) 23 | assert client._base_url == ai21_url 24 | -------------------------------------------------------------------------------- /tests/unittests/clients/vertex/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI21Labs/ai21-python/624fa7cfeedcac0b41e4d7dddc48b77ac12e3bc9/tests/unittests/clients/vertex/__init__.py -------------------------------------------------------------------------------- /tests/unittests/commons.py: -------------------------------------------------------------------------------- 1 | FAKE_CHAT_COMPLETION_RESPONSE_DICT = { 2 | "id": "cmpl-392a6a33e5204aa7a2070be4d0ddbc0a", 3 | "choices": [ 4 | { 5 | "index": 0, 6 | "message": { 7 | "role": "assistant", 8 | "content": "Test", 9 | }, 10 | "logprobs": None, 11 | "finishReason": "stop", 12 | } 13 | ], 14 | "usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}, 15 | } 16 | 17 | FAKE_AUTH_HEADERS = {"Authorization": "Bearer fake-token"} 18 | -------------------------------------------------------------------------------- /tests/unittests/conftest.py: -------------------------------------------------------------------------------- 1 | import boto3 2 | import httpx 3 | import pytest 4 | 5 | from google.auth.credentials import Credentials 6 | from google.auth.transport.requests import Request 7 | 8 | from ai21.clients.common.auth.gcp_authorization import GCPAuthorization 9 | from ai21.clients.vertex.ai21_vertex_client import AI21VertexClient 10 | 11 | 12 | @pytest.fixture 13 | def dummy_api_host() -> str: 14 | return "http://test_host" 15 | 16 | 17 | @pytest.fixture 18 | def mock_httpx_client(mocker) -> httpx.Client: 19 | return mocker.Mock(spec=httpx.Client) 20 | 21 | 22 | @pytest.fixture 23 | def mock_httpx_async_client(mocker) -> httpx.AsyncClient: 24 | return mocker.AsyncMock(spec=httpx.AsyncClient) 25 | 26 | 27 | @pytest.fixture 28 | def mock_httpx_response(mocker) -> httpx.Response: 29 | return mocker.Mock(spec=httpx.Response) 30 | 31 | 32 | @pytest.fixture 33 | def mock_boto_session(mocker) -> boto3.Session: 34 | boto_mocker = mocker.Mock(spec=boto3.Session) 35 | boto_mocker.region_name = "us-east-1" 36 | return boto_mocker 37 | 38 | 39 | @pytest.fixture 40 | def mock_gcp_credentials(mocker) -> Credentials: 41 | return mocker.Mock(spec=Credentials) 42 | 43 | 44 | @pytest.fixture 45 | def mock_gcp_request(mocker) -> Request: 46 | return mocker.Mock(spec=Request) 47 | 48 | 49 | @pytest.fixture 50 | def mock_ai21_vertex_client(mocker) -> AI21VertexClient: 51 | ai21_vertex_client_mock = mocker.Mock(spec=AI21VertexClient) 52 | ai21_vertex_client_mock._gcp_auth = mocker.Mock(spec=GCPAuthorization) 53 | return ai21_vertex_client_mock 54 | -------------------------------------------------------------------------------- /tests/unittests/models/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI21Labs/ai21-python/624fa7cfeedcac0b41e4d7dddc48b77ac12e3bc9/tests/unittests/models/__init__.py -------------------------------------------------------------------------------- /tests/unittests/models/response_mocks.py: -------------------------------------------------------------------------------- 1 | from ai21.models import ChatOutput, ChatResponse, FinishReason, RoleType 2 | from ai21.models.chat import ChatCompletionResponse, ChatCompletionResponseChoice 3 | from ai21.models.chat.chat_message import AssistantMessage 4 | from ai21.models.usage_info import UsageInfo 5 | 6 | 7 | def get_chat_response(): 8 | expected_dict = { 9 | "outputs": [ 10 | { 11 | "text": "It's a big question, and the answer is different for everyone. Some people", 12 | "role": "assistant", 13 | "finishReason": {"reason": "length", "length": 10, "sequence": None}, 14 | }, 15 | ], 16 | } 17 | 18 | output = ChatOutput( 19 | text="It's a big question, and the answer is different for everyone. Some people", 20 | role=RoleType.ASSISTANT, 21 | finish_reason=FinishReason(reason="length", length=10), 22 | ) 23 | chat_response = ChatResponse(outputs=[output]) 24 | 25 | return chat_response, expected_dict, ChatResponse 26 | 27 | 28 | def get_chat_completions_response(): 29 | expected_dict = { 30 | "id": "123", 31 | "choices": [ 32 | { 33 | "index": 0, 34 | "message": { 35 | "role": "assistant", 36 | "content": "I apologize for any inconvenience you're experiencing. Can you please provide me with " 37 | "more information about the issue you're facing? For example, are you receiving an " 38 | "error message when you try to sign up with your Google account? If so, what does the " 39 | "error message say?", 40 | "tool_calls": None, 41 | }, 42 | "logprobs": None, 43 | "finish_reason": "stop", 44 | } 45 | ], 46 | "usage": {"prompt_tokens": 105, "completion_tokens": 61, "total_tokens": 166}, 47 | } 48 | 49 | choice = ChatCompletionResponseChoice( 50 | index=0, 51 | message=AssistantMessage( 52 | role="assistant", 53 | content="I apologize for any inconvenience you're experiencing. Can you please provide me with more " 54 | "information about the issue you're facing? For example, are you receiving an error message when " 55 | "you try to sign up with your Google account? 
If so, what does the error message say?", 56 | ), 57 | finish_reason="stop", 58 | ) 59 | 60 | chat_completions_response = ChatCompletionResponse( 61 | id="123", 62 | choices=[choice], 63 | usage=UsageInfo(prompt_tokens=105, completion_tokens=61, total_tokens=166), 64 | ) 65 | 66 | return chat_completions_response, expected_dict, ChatCompletionResponse 67 | -------------------------------------------------------------------------------- /tests/unittests/models/test_serialization.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | import pytest 4 | 5 | from ai21.models import Penalty 6 | from ai21.models._pydantic_compatibility import _from_dict, _to_dict 7 | from ai21.models.ai21_base_model import IS_PYDANTIC_V2, AI21BaseModel 8 | from tests.unittests.models.response_mocks import ( 9 | get_chat_completions_response, 10 | get_chat_response, 11 | ) 12 | 13 | 14 | def test_penalty__to_dict__when_has_none_fields__should_filter_them_out(): 15 | penalty = Penalty(scale=0.5, apply_to_whitespaces=True) 16 | assert penalty.to_dict() == {"scale": 0.5, "applyToWhitespaces": True} 17 | 18 | 19 | def test_penalty__to_json__when_has_none_fields__should_filter_them_out(): 20 | penalty = Penalty(scale=0.5, apply_to_whitespaces=True) 21 | if IS_PYDANTIC_V2: 22 | assert penalty.to_json() == '{"scale":0.5,"applyToWhitespaces":true}' 23 | else: 24 | assert penalty.to_json() == '{"scale": 0.5, "applyToWhitespaces": true}' 25 | 26 | 27 | def test_penalty__from_dict__should_return_instance_with_given_values(): 28 | penalty = Penalty.from_dict({"scale": 0.5, "applyToWhitespaces": True}) 29 | assert penalty.scale == 0.5 30 | assert penalty.apply_to_whitespaces is True 31 | assert penalty.apply_to_emojis is None 32 | 33 | 34 | def test_penalty__from_json__should_return_instance_with_given_values(): 35 | penalty = Penalty.from_json('{"scale":0.5,"applyToWhitespaces":true}') 36 | assert penalty.scale == 0.5 37 | assert penalty.apply_to_whitespaces is True 38 | assert penalty.apply_to_emojis is None 39 | 40 | 41 | @pytest.mark.parametrize( 42 | ids=[ 43 | "chat_response", 44 | "chat_completions_response", 45 | ], 46 | argnames=[ 47 | "response_obj", 48 | "expected_dict", 49 | "response_cls", 50 | ], 51 | argvalues=[ 52 | (get_chat_response()), 53 | (get_chat_completions_response()), 54 | ], 55 | ) 56 | def test_to_dict__should_serialize_to_dict__( 57 | response_obj: AI21BaseModel, expected_dict: Dict[str, Any], response_cls: Any 58 | ): 59 | assert response_obj.to_dict() == expected_dict 60 | assert _to_dict(model_object=response_obj, by_alias=True) == expected_dict 61 | 62 | 63 | @pytest.mark.parametrize( 64 | ids=[ 65 | "chat_response", 66 | "chat_completions_response", 67 | ], 68 | argnames=[ 69 | "response_obj", 70 | "expected_dict", 71 | "response_cls", 72 | ], 73 | argvalues=[ 74 | (get_chat_response()), 75 | (get_chat_completions_response()), 76 | ], 77 | ) 78 | def test_from_dict__should_deserialize_from_dict__( 79 | response_obj: AI21BaseModel, 80 | expected_dict: Dict[str, Any], 81 | response_cls: Any, 82 | ): 83 | assert response_cls.from_dict(expected_dict) == response_obj 84 | assert _from_dict(obj=response_obj, obj_dict=expected_dict) == response_obj 85 | -------------------------------------------------------------------------------- /tests/unittests/test_ai21_env_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | from contextlib import contextmanager 3 | 4 | from ai21 import
AI21Client 5 | 6 | _FAKE_API_KEY = "fake-key" 7 | os.environ["AI21_API_KEY"] = _FAKE_API_KEY 8 | 9 | 10 | @contextmanager 11 | def set_env_var(key: str, value: str): 12 | os.environ[key] = value 13 | yield 14 | del os.environ[key] 15 | 16 | 17 | def test_env_config__when_set_via_init_and_env__should_be_taken_from_init(): 18 | client = AI21Client() 19 | assert client._api_key == _FAKE_API_KEY 20 | 21 | init_api_key = "init-key" 22 | client2 = AI21Client(api_key=init_api_key) 23 | 24 | assert client2._api_key == init_api_key 25 | 26 | 27 | def test_env_config__when_set_twice__should_be_updated(): 28 | client = AI21Client() 29 | 30 | assert client._api_key == _FAKE_API_KEY 31 | 32 | new_api_key = "new-key" 33 | 34 | with set_env_var("AI21_API_KEY", new_api_key): 35 | client2 = AI21Client() 36 | assert client2._api_key == new_api_key 37 | 38 | 39 | def test_env_config__when_set_int__should_be_set(): 40 | with set_env_var("AI21_TIMEOUT_SEC", "1"): 41 | client = AI21Client() 42 | 43 | assert client._timeout_sec == 1 44 | -------------------------------------------------------------------------------- /tests/unittests/test_aws_authorization.py: -------------------------------------------------------------------------------- 1 | import boto3 2 | 3 | from ai21.clients.aws.aws_authorization import AWSAuthorization 4 | 5 | 6 | class MockCredentials: 7 | def __init__(self, access_key, secret_key, token): 8 | self.access_key = access_key 9 | self.secret_key = secret_key 10 | self.token = token 11 | self.method = "env" 12 | 13 | 14 | def test_prepare_auth_headers__(mock_boto_session: boto3.Session): 15 | auth_headers_keys = ["X-Amz-Date", "X-Amz-Security-Token", "Authorization"] 16 | mock_boto_session.get_credentials.return_value = MockCredentials( 17 | access_key="some-key", secret_key="some-secret", token="some-token" 18 | ) 19 | aws_authorization = AWSAuthorization(aws_session=mock_boto_session) 20 | headers = aws_authorization.get_auth_headers( 21 | url="https://dummy.com", 22 | service_name="bedrock", 23 | method="POST", 24 | data='{"foo": "bar"}', 25 | ) 26 | 27 | for auth_header in auth_headers_keys: 28 | assert auth_header in headers 29 | -------------------------------------------------------------------------------- /tests/unittests/test_aws_stream_decoder.py: -------------------------------------------------------------------------------- 1 | from typing import AsyncIterable, Iterable 2 | 3 | import httpx 4 | import pytest 5 | 6 | from ai21.clients.bedrock._stream_decoder import _AWSEventStreamDecoder 7 | from ai21.errors import StreamingDecodeError 8 | from ai21.models.chat import ChatCompletionChunk 9 | from ai21.stream.async_stream import AsyncStream 10 | from ai21.stream.stream import Stream 11 | 12 | 13 | def byte_stream() -> Iterable[bytes]: 14 | for i in range(10): 15 | yield ( 16 | b"\x00\x00\x01\x80\x00\x00\x00K\xfe\x96$F\x0b:event-type\x07\x00\x05chunk\r:content-type\x07\x00" 17 | b"\x10application/json\r:message-type\x07\x00\x05event{" 18 | b'"bytes":"eyJpZCI6ImNtcGwtOTgxZjdmMTc2YWQ0NDE0NTliOTRlNDVlZTI5MmEzMjEiLCJjaG9pY2VzIjpbeyJpbmRleC' 19 | b"I6MCwiZGVsdGEiOnsicm9sZSI6ImFzc2lzdGFudCJ9LCJmaW5pc2hfcmVhc29uIjpudWxsfV0sInVzYWdlIjp7InByb21wd" 20 | b'F90b2tlbnMiOjQ0LCJ0b3RhbF90b2tlbnMiOjQ0LCJjb21wbGV0aW9uX3Rva2VucyI6MH19","p":"abcdefghijklmnopq' 21 | b'rstuv"}5\xca\xa7\x98' 22 | ) 23 | 24 | 25 | async def async_byte_stream() -> AsyncIterable[bytes]: 26 | for i in range(10): 27 | yield ( 28 | 
b"\x00\x00\x01\x80\x00\x00\x00K\xfe\x96$F\x0b:event-type\x07\x00\x05chunk\r:content-type\x07\x00" 29 | b"\x10application/json\r:message-type\x07\x00\x05event{" 30 | b'"bytes":"eyJpZCI6ImNtcGwtOTgxZjdmMTc2YWQ0NDE0NTliOTRlNDVlZTI5MmEzMjEiLCJjaG9pY2VzIjpbeyJpbmRleC' 31 | b"I6MCwiZGVsdGEiOnsicm9sZSI6ImFzc2lzdGFudCJ9LCJmaW5pc2hfcmVhc29uIjpudWxsfV0sInVzYWdlIjp7InByb21wd" 32 | b'F90b2tlbnMiOjQ0LCJ0b3RhbF90b2tlbnMiOjQ0LCJjb21wbGV0aW9uX3Rva2VucyI6MH19","p":"abcdefghijklmnopq' 33 | b'rstuv"}5\xca\xa7\x98' 34 | ) 35 | 36 | 37 | def byte_bad_stream_json_format() -> AsyncIterable[bytes]: 38 | msg = "data: not a json format\r\n" 39 | yield msg.encode("utf-8") 40 | 41 | 42 | async def async_byte_bad_stream_json_format() -> AsyncIterable[bytes]: 43 | msg = "data: not a json format\r\n" 44 | yield msg.encode("utf-8") 45 | 46 | 47 | def test_stream_object_when_json_string_ok__should_be_ok(): 48 | stream = byte_stream() 49 | response = httpx.Response(status_code=200, content=stream) 50 | stream_obj = Stream[ChatCompletionChunk]( 51 | response=response, cast_to=ChatCompletionChunk, streaming_decoder=_AWSEventStreamDecoder() 52 | ) 53 | 54 | chunk_counter = 0 55 | for i, chunk in enumerate(stream_obj): 56 | assert isinstance(chunk, ChatCompletionChunk) 57 | chunk_counter += 1 58 | 59 | assert chunk_counter == 10 60 | 61 | 62 | @pytest.mark.asyncio 63 | async def test_async_stream_object_when_json_string_ok__should_be_ok(): 64 | stream = async_byte_stream() 65 | response = httpx.Response(status_code=200, content=stream) 66 | stream_obj = AsyncStream[ChatCompletionChunk]( 67 | response=response, cast_to=ChatCompletionChunk, streaming_decoder=_AWSEventStreamDecoder() 68 | ) 69 | 70 | chunk_counter = 0 71 | async for chunk in stream_obj: 72 | assert isinstance(chunk, ChatCompletionChunk) 73 | chunk_counter += 1 74 | 75 | assert chunk_counter == 10 76 | 77 | 78 | def test_stream_object_when_bad_json__should_raise_error(): 79 | stream = byte_bad_stream_json_format() 80 | response = httpx.Response(status_code=200, content=stream) 81 | stream_obj = Stream[ChatCompletionChunk]( 82 | response=response, cast_to=ChatCompletionChunk, streaming_decoder=_AWSEventStreamDecoder() 83 | ) 84 | 85 | with pytest.raises(StreamingDecodeError): 86 | for _ in stream_obj: 87 | pass 88 | 89 | 90 | @pytest.mark.asyncio 91 | async def test_async_stream_object_when_bad_json__should_raise_error(): 92 | stream = async_byte_bad_stream_json_format() 93 | response = httpx.Response(status_code=200, content=stream) 94 | stream_obj = AsyncStream[ChatCompletionChunk]( 95 | response=response, cast_to=ChatCompletionChunk, streaming_decoder=_AWSEventStreamDecoder() 96 | ) 97 | 98 | with pytest.raises(StreamingDecodeError): 99 | async for _ in stream_obj: 100 | pass 101 | -------------------------------------------------------------------------------- /tests/unittests/test_imports.py: -------------------------------------------------------------------------------- 1 | # The following line should not be removed, as it is used to test the imports in runtime 2 | # noinspection PyUnresolvedReferences 3 | from ai21 import * # noqa: F403 4 | from ai21 import __all__ 5 | 6 | 7 | EXPECTED_ALL = [ 8 | "AI21EnvConfig", 9 | "AI21Client", 10 | "AsyncAI21Client", 11 | "AI21APIError", 12 | "APITimeoutError", 13 | "AI21Error", 14 | "MissingApiKeyError", 15 | "ModelPackageDoesntExistError", 16 | "TooManyRequestsError", 17 | "AI21BedrockClient", 18 | "BedrockModelID", 19 | "AI21AzureClient", 20 | "AsyncAI21AzureClient", 21 | "AsyncAI21BedrockClient", 22 | 
"AI21VertexClient", 23 | "AsyncAI21VertexClient", 24 | "AI21LaunchpadClient", 25 | "AsyncAI21LaunchpadClient", 26 | ] 27 | 28 | 29 | def test_all_imports() -> None: 30 | assert sorted(EXPECTED_ALL) == sorted(__all__) 31 | -------------------------------------------------------------------------------- /tests/unittests/test_stream.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import AsyncIterable 3 | 4 | import httpx 5 | import pytest 6 | 7 | from ai21.errors import StreamingDecodeError 8 | from ai21.models.ai21_base_model import AI21BaseModel 9 | from ai21.stream.stream import Stream 10 | from ai21.stream.async_stream import AsyncStream 11 | 12 | 13 | class StubStreamObject(AI21BaseModel): 14 | id: str 15 | name: str 16 | 17 | 18 | def byte_stream() -> AsyncIterable[bytes]: 19 | for i in range(10): 20 | data = {"id": f"some-{i}", "name": f"some-name-{i}"} 21 | msg = f"data: {json.dumps(data)}\r\n" 22 | yield msg.encode("utf-8") 23 | 24 | 25 | async def async_byte_stream() -> AsyncIterable[bytes]: 26 | for i in range(10): 27 | data = {"id": f"some-{i}", "name": f"some-name-{i}"} 28 | msg = f"data: {json.dumps(data)}\r\n" 29 | yield msg.encode("utf-8") 30 | 31 | 32 | async def async_byte_bad_stream_prefix() -> AsyncIterable[bytes]: 33 | msg = "bad_stream: {}\r\n" 34 | yield msg.encode("utf-8") 35 | 36 | 37 | def byte_bad_stream_prefix() -> AsyncIterable[bytes]: 38 | msg = "bad_stream: {}\r\n" 39 | yield msg.encode("utf-8") 40 | 41 | 42 | def byte_bad_stream_json_format() -> AsyncIterable[bytes]: 43 | msg = "data: not a json format\r\n" 44 | yield msg.encode("utf-8") 45 | 46 | 47 | async def async_byte_bad_stream_json_format() -> AsyncIterable[bytes]: 48 | msg = "data: not a json format\r\n" 49 | yield msg.encode("utf-8") 50 | 51 | 52 | def test_stream_object_when_json_string_ok__should_be_ok(): 53 | stream = byte_stream() 54 | response = httpx.Response(status_code=200, content=stream) 55 | stream_obj = Stream[StubStreamObject](response=response, cast_to=StubStreamObject) 56 | 57 | for i, chunk in enumerate(stream_obj): 58 | assert isinstance(chunk, StubStreamObject) 59 | assert chunk.name == f"some-name-{i}" 60 | assert chunk.id == f"some-{i}" 61 | 62 | 63 | @pytest.mark.parametrize( 64 | ids=[ 65 | "bad_stream_data_prefix", 66 | "bad_stream_json_format", 67 | ], 68 | argnames=["stream"], 69 | argvalues=[ 70 | (byte_bad_stream_prefix(),), 71 | (byte_bad_stream_json_format(),), 72 | ], 73 | ) 74 | def test_stream_object_when_bad_json__should_raise_error(stream): 75 | response = httpx.Response(status_code=200, content=stream) 76 | stream_obj = Stream[StubStreamObject](response=response, cast_to=StubStreamObject) 77 | 78 | with pytest.raises(StreamingDecodeError): 79 | for _ in stream_obj: 80 | pass 81 | 82 | 83 | @pytest.mark.asyncio 84 | async def test_async_stream_object_when_json_string_ok__should_be_ok(): 85 | stream = async_byte_stream() 86 | response = httpx.Response(status_code=200, content=stream) 87 | stream_obj = AsyncStream[StubStreamObject](response=response, cast_to=StubStreamObject) 88 | 89 | index = 0 90 | async for chunk in stream_obj: 91 | assert isinstance(chunk, StubStreamObject) 92 | assert chunk.name == f"some-name-{index}" 93 | assert chunk.id == f"some-{index}" 94 | index += 1 95 | 96 | 97 | @pytest.mark.asyncio 98 | @pytest.mark.parametrize( 99 | ids=[ 100 | "bad_stream_data_prefix", 101 | "bad_stream_json_format", 102 | ], 103 | argnames=["stream"], 104 | argvalues=[ 105 | 
(async_byte_bad_stream_prefix(),), 106 | (async_byte_bad_stream_json_format(),), 107 | ], 108 | ) 109 | async def test_async_stream_object_when_bad_json__should_raise_error(stream): 110 | response = httpx.Response(status_code=200, content=stream) 111 | stream_obj = AsyncStream[StubStreamObject](response=response, cast_to=StubStreamObject) 112 | 113 | with pytest.raises(StreamingDecodeError): 114 | async for _ in stream_obj: 115 | pass 116 | -------------------------------------------------------------------------------- /tests/unittests/tokenizers/test_ai21_tokenizer.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import pytest 4 | from ai21.tokenizers.factory import get_tokenizer 5 | 6 | 7 | class TestAI21Tokenizer: 8 | @pytest.mark.parametrize( 9 | ids=[ 10 | "when_j2_tokenizer", 11 | "when_jamba_instruct_tokenizer", 12 | ], 13 | argnames=["tokenizer_name", "expected_tokens"], 14 | argvalues=[ 15 | ("j2-tokenizer", 8), 16 | ("jamba-tokenizer", 9), 17 | ], 18 | ) 19 | def test__count_tokens__should_return_number_of_tokens(self, tokenizer_name: str, expected_tokens: int): 20 | tokenizer = get_tokenizer(tokenizer_name) 21 | 22 | actual_number_of_tokens = tokenizer.count_tokens("Text to Tokenize - Hello world!") 23 | 24 | assert actual_number_of_tokens == expected_tokens 25 | 26 | @pytest.mark.parametrize( 27 | ids=[ 28 | "when_j2_tokenizer", 29 | "when_jamba_instruct_tokenizer", 30 | ], 31 | argnames=["tokenizer_name", "expected_tokens"], 32 | argvalues=[ 33 | ("j2-tokenizer", ["▁Text", "▁to", "▁Token", "ize", "▁-", "▁Hello", "▁world", "!"]), 34 | ( 35 | "jamba-tokenizer", 36 | ["<|startoftext|>", "Text", "▁to", "▁Token", "ize", "▁-", "▁Hello", "▁world", "!"], 37 | ), 38 | ], 39 | ) 40 | def test__tokenize__should_return_list_of_tokens(self, tokenizer_name: str, expected_tokens: List[str]): 41 | tokenizer = get_tokenizer(tokenizer_name) 42 | 43 | actual_tokens = tokenizer.tokenize("Text to Tokenize - Hello world!") 44 | 45 | assert actual_tokens == expected_tokens 46 | 47 | @pytest.mark.parametrize( 48 | ids=[ 49 | "when_j2_tokenizer", 50 | "when_jamba_instruct_tokenizer", 51 | ], 52 | argnames=["tokenizer_name"], 53 | argvalues=[ 54 | ("j2-tokenizer",), 55 | ("jamba-tokenizer",), 56 | ], 57 | ) 58 | def test__detokenize__should_return_original_text(self, tokenizer_name: str): 59 | tokenizer = get_tokenizer(tokenizer_name) 60 | original_text = "Text to Tokenize - Hello world!"
61 | actual_tokens = tokenizer.tokenize(original_text) 62 | detokenized_text = tokenizer.detokenize(actual_tokens) 63 | 64 | assert original_text == detokenized_text 65 | 66 | def test__tokenizer__should_be_singleton__when_called_twice(self): 67 | tokenizer1 = get_tokenizer() 68 | tokenizer2 = get_tokenizer() 69 | 70 | assert tokenizer1 is tokenizer2 71 | 72 | def test__get_tokenizer__when_called_with_different_tokenizer_name__should_return_different_tokenizer(self): 73 | tokenizer1 = get_tokenizer("j2-tokenizer") 74 | tokenizer2 = get_tokenizer("jamba-tokenizer") 75 | 76 | assert tokenizer1._tokenizer is not tokenizer2._tokenizer 77 | 78 | def test__get_tokenizer__when_tokenizer_name_not_supported__should_raise_error(self): 79 | with pytest.raises(ValueError): 80 | get_tokenizer("some-tokenizer") 81 | -------------------------------------------------------------------------------- /tests/unittests/tokenizers/test_async_ai21_tokenizer.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import pytest 4 | from ai21.tokenizers.factory import get_async_tokenizer 5 | 6 | 7 | class TestAsyncAI21Tokenizer: 8 | @pytest.mark.asyncio 9 | @pytest.mark.parametrize( 10 | ids=[ 11 | "when_j2_tokenizer", 12 | "when_jamba_instruct_tokenizer", 13 | ], 14 | argnames=["tokenizer_name", "expected_tokens"], 15 | argvalues=[ 16 | ("j2-tokenizer", 8), 17 | ("jamba-tokenizer", 9), 18 | ], 19 | ) 20 | async def test__count_tokens__should_return_number_of_tokens(self, tokenizer_name: str, expected_tokens: int): 21 | tokenizer = await get_async_tokenizer(tokenizer_name) 22 | 23 | actual_number_of_tokens = await tokenizer.count_tokens("Text to Tokenize - Hello world!") 24 | 25 | assert actual_number_of_tokens == expected_tokens 26 | 27 | @pytest.mark.asyncio 28 | @pytest.mark.parametrize( 29 | ids=[ 30 | "when_j2_tokenizer", 31 | "when_jamba_instruct_tokenizer", 32 | ], 33 | argnames=["tokenizer_name", "expected_tokens"], 34 | argvalues=[ 35 | ("j2-tokenizer", ["▁Text", "▁to", "▁Token", "ize", "▁-", "▁Hello", "▁world", "!"]), 36 | ( 37 | "jamba-tokenizer", 38 | ["<|startoftext|>", "Text", "▁to", "▁Token", "ize", "▁-", "▁Hello", "▁world", "!"], 39 | ), 40 | ], 41 | ) 42 | async def test__tokenize__should_return_list_of_tokens(self, tokenizer_name: str, expected_tokens: List[str]): 43 | tokenizer = await get_async_tokenizer(tokenizer_name) 44 | 45 | actual_tokens = await tokenizer.tokenize("Text to Tokenize - Hello world!") 46 | 47 | assert actual_tokens == expected_tokens 48 | 49 | @pytest.mark.asyncio 50 | @pytest.mark.parametrize( 51 | ids=[ 52 | "when_j2_tokenizer", 53 | "when_jamba_instruct_tokenizer", 54 | ], 55 | argnames=["tokenizer_name"], 56 | argvalues=[ 57 | ("j2-tokenizer",), 58 | ("jamba-tokenizer",), 59 | ], 60 | ) 61 | async def test__detokenize__should_return_original_text(self, tokenizer_name: str): 62 | tokenizer = await get_async_tokenizer(tokenizer_name) 63 | original_text = "Text to Tokenize - Hello world!"
64 | actual_tokens = await tokenizer.tokenize(original_text) 65 | detokenized_text = await tokenizer.detokenize(actual_tokens) 66 | 67 | assert original_text == detokenized_text 68 | 69 | @pytest.mark.asyncio 70 | async def test__tokenizer__should_be_singleton__when_called_twice(self): 71 | tokenizer1 = await get_async_tokenizer() 72 | tokenizer2 = await get_async_tokenizer() 73 | 74 | assert tokenizer1 is tokenizer2 75 | 76 | @pytest.mark.asyncio 77 | async def test__get_tokenizer__when_called_with_different_tokenizer_name__should_return_different_tokenizer(self): 78 | tokenizer1 = await get_async_tokenizer("j2-tokenizer") 79 | tokenizer2 = await get_async_tokenizer("jamba-tokenizer") 80 | 81 | assert tokenizer1._tokenizer is not tokenizer2._tokenizer 82 | 83 | @pytest.mark.asyncio 84 | async def test__get_tokenizer__when_tokenizer_name_not_supported__should_raise_error(self): 85 | with pytest.raises(ValueError): 86 | await get_async_tokenizer("some-tokenizer") 87 | --------------------------------------------------------------------------------
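
For orientation, most of the suites above ultimately drive one call shape through the SDK. The sketch below shows that shape, assuming AI21_API_KEY is exported in the environment; the model id "jamba-mini" is an illustrative placeholder and is not taken from this repository.

# Minimal sketch of the chat-completions flow exercised by the tests above.
# Assumes AI21_API_KEY is set; "jamba-mini" is a placeholder model id.
from ai21 import AI21Client
from ai21.models.chat import ChatMessage

client = AI21Client()
response = client.chat.completions.create(
    model="jamba-mini",  # placeholder; substitute a model available to your account
    messages=[ChatMessage(role="user", content="Hello, I need help with a signup process.")],
)
print(response.choices[0].message.content)

The async and streaming variants follow the same shape: AsyncAI21Client with await, or stream=True to iterate over ChatCompletionChunk objects, as the bedrock and stream tests above demonstrate.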