├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.yml │ ├── config.yml │ └── feature_request.yml ├── PULL_REQUEST_TEMPLATE.md ├── dependabot.yml └── workflows │ ├── integration-test.yml │ ├── pr-and-push.yml │ ├── pypi-publish-on-release.yml │ └── test-lint.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CONTRIBUTING.md ├── LICENSE ├── NOTICE ├── README.md ├── STYLE_GUIDE.md ├── pyproject.toml ├── src └── strands │ ├── __init__.py │ ├── agent │ ├── __init__.py │ ├── agent.py │ ├── agent_result.py │ ├── conversation_manager │ │ ├── __init__.py │ │ ├── conversation_manager.py │ │ ├── null_conversation_manager.py │ │ ├── sliding_window_conversation_manager.py │ │ └── summarizing_conversation_manager.py │ └── state.py │ ├── event_loop │ ├── __init__.py │ ├── event_loop.py │ └── streaming.py │ ├── experimental │ ├── __init__.py │ └── hooks │ │ ├── __init__.py │ │ └── events.py │ ├── handlers │ ├── __init__.py │ └── callback_handler.py │ ├── hooks │ ├── __init__.py │ ├── events.py │ ├── registry.py │ └── rules.md │ ├── models │ ├── __init__.py │ ├── anthropic.py │ ├── bedrock.py │ ├── litellm.py │ ├── llamaapi.py │ ├── mistral.py │ ├── model.py │ ├── ollama.py │ ├── openai.py │ └── writer.py │ ├── multiagent │ ├── __init__.py │ ├── a2a │ │ ├── __init__.py │ │ ├── executor.py │ │ └── server.py │ ├── base.py │ ├── graph.py │ └── swarm.py │ ├── py.typed │ ├── session │ ├── __init__.py │ ├── file_session_manager.py │ ├── repository_session_manager.py │ ├── s3_session_manager.py │ ├── session_manager.py │ └── session_repository.py │ ├── telemetry │ ├── __init__.py │ ├── config.py │ ├── metrics.py │ ├── metrics_constants.py │ └── tracer.py │ ├── tools │ ├── __init__.py │ ├── decorator.py │ ├── executor.py │ ├── loader.py │ ├── mcp │ │ ├── __init__.py │ │ ├── mcp_agent_tool.py │ │ ├── mcp_client.py │ │ └── mcp_types.py │ ├── registry.py │ ├── structured_output.py │ ├── tools.py │ └── watcher.py │ └── types │ ├── __init__.py │ ├── collections.py │ ├── content.py │ ├── 
event_loop.py │ ├── exceptions.py │ ├── guardrails.py │ ├── media.py │ ├── session.py │ ├── streaming.py │ ├── tools.py │ └── traces.py ├── tests ├── __init__.py ├── conftest.py ├── fixtures │ ├── mock_hook_provider.py │ ├── mock_session_repository.py │ └── mocked_model_provider.py └── strands │ ├── __init__.py │ ├── agent │ ├── __init__.py │ ├── test_agent.py │ ├── test_agent_hooks.py │ ├── test_agent_result.py │ ├── test_agent_state.py │ ├── test_conversation_manager.py │ └── test_summarizing_conversation_manager.py │ ├── event_loop │ ├── __init__.py │ ├── test_event_loop.py │ └── test_streaming.py │ ├── experimental │ ├── __init__.py │ └── hooks │ │ ├── __init__.py │ │ ├── test_events.py │ │ └── test_hook_registry.py │ ├── handlers │ ├── __init__.py │ └── test_callback_handler.py │ ├── models │ ├── __init__.py │ ├── test_anthropic.py │ ├── test_bedrock.py │ ├── test_litellm.py │ ├── test_llamaapi.py │ ├── test_mistral.py │ ├── test_model.py │ ├── test_ollama.py │ ├── test_openai.py │ └── test_writer.py │ ├── multiagent │ ├── __init__.py │ ├── a2a │ │ ├── __init__.py │ │ ├── conftest.py │ │ ├── test_executor.py │ │ └── test_server.py │ ├── test_base.py │ ├── test_graph.py │ └── test_swarm.py │ ├── session │ ├── __init__.py │ ├── test_file_session_manager.py │ ├── test_repository_session_manager.py │ └── test_s3_session_manager.py │ ├── telemetry │ ├── test_config.py │ ├── test_metrics.py │ └── test_tracer.py │ ├── tools │ ├── __init__.py │ ├── mcp │ │ ├── __init__.py │ │ ├── test_mcp_agent_tool.py │ │ └── test_mcp_client.py │ ├── test_decorator.py │ ├── test_executor.py │ ├── test_loader.py │ ├── test_registry.py │ ├── test_structured_output.py │ ├── test_tools.py │ └── test_watcher.py │ └── types │ └── test_session.py └── tests_integ ├── __init__.py ├── conftest.py ├── echo_server.py ├── models ├── __init__.py ├── conformance.py ├── providers.py ├── test_model_anthropic.py ├── test_model_bedrock.py ├── test_model_cohere.py ├── test_model_litellm.py ├── 
test_model_llamaapi.py ├── test_model_mistral.py ├── test_model_ollama.py ├── test_model_openai.py └── test_model_writer.py ├── test_agent_async.py ├── test_bedrock_cache_point.py ├── test_bedrock_guardrails.py ├── test_context_overflow.py ├── test_function_tools.py ├── test_hot_tool_reload_decorator.py ├── test_mcp_client.py ├── test_multiagent_graph.py ├── test_multiagent_swarm.py ├── test_session.py ├── test_stream_agent.py ├── test_summarizing_conversation_manager_integration.py └── yellow.png /.github/ISSUE_TEMPLATE/bug_report.yml: -------------------------------------------------------------------------------- 1 | name: Bug Report 2 | description: Report a bug in the Strands Agents SDK 3 | title: "[BUG] " 4 | labels: ["bug", "triage"] 5 | assignees: [] 6 | body: 7 | - type: markdown 8 | attributes: 9 | value: | 10 | Thanks for taking the time to fill out this bug report for Strands SDK! 11 | - type: checkboxes 12 | id: "checks" 13 | attributes: 14 | label: "Checks" 15 | options: 16 | - label: "I have updated to the latest minor and patch version of Strands" 17 | required: true 18 | - label: "I have checked the documentation and this is not expected behavior" 19 | required: true 20 | - label: "I have searched [./issues](./issues?q=) and there are no duplicates of my issue" 21 | required: true 22 | - type: input 23 | id: strands-version 24 | attributes: 25 | label: Strands Version 26 | description: Which version of Strands are you using? 27 | placeholder: e.g., 0.5.2 28 | validations: 29 | required: true 30 | - type: input 31 | id: python-version 32 | attributes: 33 | label: Python Version 34 | description: Which version of Python are you using? 35 | placeholder: e.g., 3.10.5 36 | validations: 37 | required: true 38 | - type: input 39 | id: os 40 | attributes: 41 | label: Operating System 42 | description: Which operating system are you using? 
43 | placeholder: e.g., macOS 12.6 44 | validations: 45 | required: true 46 | - type: dropdown 47 | id: installation-method 48 | attributes: 49 | label: Installation Method 50 | description: How did you install Strands? 51 | options: 52 | - pip 53 | - git clone 54 | - binary 55 | - other 56 | validations: 57 | required: true 58 | - type: textarea 59 | id: steps-to-reproduce 60 | attributes: 61 | label: Steps to Reproduce 62 | description: Detailed steps to reproduce the behavior 63 | placeholder: | 64 | 1. Install Strands using... 65 | 2. Run the command... 66 | 3. See error... 67 | validations: 68 | required: true 69 | - type: textarea 70 | id: expected-behavior 71 | attributes: 72 | label: Expected Behavior 73 | description: A clear description of what you expected to happen 74 | validations: 75 | required: true 76 | - type: textarea 77 | id: actual-behavior 78 | attributes: 79 | label: Actual Behavior 80 | description: What actually happened 81 | validations: 82 | required: true 83 | - type: textarea 84 | id: additional-context 85 | attributes: 86 | label: Additional Context 87 | description: Any other relevant information, logs, screenshots, etc. 
88 | - type: textarea 89 | id: possible-solution 90 | attributes: 91 | label: Possible Solution 92 | description: Optional - If you have suggestions on how to fix the bug 93 | - type: input 94 | id: related-issues 95 | attributes: 96 | label: Related Issues 97 | description: Optional - Link to related issues if applicable 98 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false 2 | contact_links: 3 | - name: Strands Agents SDK Support 4 | url: https://github.com/strands-agents/sdk-python/discussions 5 | about: Please ask and answer questions here 6 | - name: Strands Agents SDK Documentation 7 | url: https://github.com/strands-agents/docs 8 | about: Visit our documentation for help 9 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.yml: -------------------------------------------------------------------------------- 1 | name: Feature Request 2 | description: Suggest a new feature or enhancement for Strands Agents SDK 3 | title: "[FEATURE] " 4 | labels: ["enhancement", "triage"] 5 | assignees: [] 6 | body: 7 | - type: markdown 8 | attributes: 9 | value: | 10 | Thanks for suggesting a new feature for Strands Agents SDK! 11 | - type: textarea 12 | id: problem-statement 13 | attributes: 14 | label: Problem Statement 15 | description: Describe the problem you're trying to solve. What is currently difficult or impossible to do? 16 | placeholder: I would like Strands to... 17 | validations: 18 | required: true 19 | - type: textarea 20 | id: proposed-solution 21 | attributes: 22 | label: Proposed Solution 23 | description: Optional - Describe your proposed solution in detail. How would this feature work? 
24 | - type: textarea 25 | id: use-case 26 | attributes: 27 | label: Use Case 28 | description: Provide specific use cases for the feature. How would people use it? 29 | placeholder: This would help with... 30 | validations: 31 | required: true 32 | - type: textarea 33 | id: alternatives-solutions 34 | attributes: 35 | label: Alternative Solutions 36 | description: Optional - Have you considered alternative approaches? What are their pros and cons? 37 | - type: textarea 38 | id: additional-context 39 | attributes: 40 | label: Additional Context 41 | description: Include any other context, screenshots, code examples, or references that might help understand the feature request. 42 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## Description 2 | <!-- Provide a detailed description of the changes in this PR --> 3 | 4 | ## Related Issues 5 | 6 | <!-- Link to related issues using #issue-number format --> 7 | 8 | ## Documentation PR 9 | 10 | <!-- Link to the associated PR in the agent-docs repo --> 11 | 12 | ## Type of Change 13 | 14 | <!-- Choose one of the following types of changes, delete the rest --> 15 | 16 | Bug fix 17 | New feature 18 | Breaking change 19 | Documentation update 20 | Other (please describe): 21 | 22 | ## Testing 23 | 24 | How have you tested the change? 
Verify that the changes do not break functionality or introduce warnings in consuming repositories: agents-docs, agents-tools, agents-cli 25 | 26 | - [ ] I ran `hatch run prepare` 27 | 28 | ## Checklist 29 | - [ ] I have read the CONTRIBUTING document 30 | - [ ] I have added any necessary tests that prove my fix is effective or my feature works 31 | - [ ] I have updated the documentation accordingly 32 | - [ ] I have added an appropriate example to the documentation to outline the feature, or no new docs are needed 33 | - [ ] My changes generate no new warnings 34 | - [ ] Any dependent changes have been merged and published 35 | 36 | ---- 37 | 38 | By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. 39 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "pip" 4 | directory: "/" 5 | schedule: 6 | interval: "daily" 7 | open-pull-requests-limit: 100 8 | commit-message: 9 | prefix: ci 10 | groups: 11 | dev-dependencies: 12 | patterns: 13 | - "pytest" 14 | - package-ecosystem: "github-actions" 15 | directory: "/" 16 | schedule: 17 | interval: "daily" 18 | open-pull-requests-limit: 100 19 | commit-message: 20 | prefix: ci 21 | -------------------------------------------------------------------------------- /.github/workflows/integration-test.yml: -------------------------------------------------------------------------------- 1 | name: Secure Integration test 2 | 3 | on: 4 | pull_request_target: 5 | branches: main 6 | 7 | jobs: 8 | authorization-check: 9 | permissions: read-all 10 | runs-on: ubuntu-latest 11 | outputs: 12 | approval-env: ${{ steps.collab-check.outputs.result }} 13 | steps: 14 | - name: Collaborator Check 15 | uses: actions/github-script@v7 16 | id: collab-check 17 | with: 18 | 
result-encoding: string 19 | script: | 20 | try { 21 | const permissionResponse = await github.rest.repos.getCollaboratorPermissionLevel({ 22 | owner: context.repo.owner, 23 | repo: context.repo.repo, 24 | username: context.payload.pull_request.user.login, 25 | }); 26 | const permission = permissionResponse.data.permission; 27 | const hasWriteAccess = ['write', 'admin'].includes(permission); 28 | if (!hasWriteAccess) { 29 | console.log(`User ${context.payload.pull_request.user.login} does not have write access to the repository (permission: ${permission})`); 30 | return "manual-approval" 31 | } else { 32 | console.log(`Verified ${context.payload.pull_request.user.login} has write access. Auto Approving PR Checks.`) 33 | return "auto-approve" 34 | } 35 | } catch (error) { 36 | console.log(`${context.payload.pull_request.user.login} does not have write access. Requiring Manual Approval to run PR Checks.`) 37 | return "manual-approval" 38 | } 39 | check-access-and-checkout: 40 | runs-on: ubuntu-latest 41 | needs: authorization-check 42 | environment: ${{ needs.authorization-check.outputs.approval-env }} 43 | permissions: 44 | id-token: write 45 | pull-requests: read 46 | contents: read 47 | steps: 48 | - name: Configure Credentials 49 | uses: aws-actions/configure-aws-credentials@v4 50 | with: 51 | role-to-assume: ${{ secrets.STRANDS_INTEG_TEST_ROLE }} 52 | aws-region: us-east-1 53 | mask-aws-account-id: true 54 | - name: Checkout head commit 55 | uses: actions/checkout@v4 56 | with: 57 | ref: ${{ github.event.pull_request.head.sha }} # Pull the commit from the forked repo 58 | persist-credentials: false # Don't persist credentials for subsequent actions 59 | - name: Set up Python 60 | uses: actions/setup-python@v5 61 | with: 62 | python-version: '3.10' 63 | - name: Install dependencies 64 | run: | 65 | pip install --no-cache-dir hatch 66 | - name: Run integration tests 67 | env: 68 | AWS_REGION: us-east-1 69 | AWS_REGION_NAME: us-east-1 # Needed for LiteLLM 70 | id: 
tests 71 | run: | 72 | hatch test tests_integ 73 | -------------------------------------------------------------------------------- /.github/workflows/pr-and-push.yml: -------------------------------------------------------------------------------- 1 | name: Pull Request and Push Action 2 | 3 | on: 4 | pull_request: # Safer than pull_request_target for untrusted code 5 | branches: [ main ] 6 | types: [opened, synchronize, reopened, ready_for_review, review_requested, review_request_removed] 7 | push: 8 | branches: [ main ] # Also run on direct pushes to main 9 | concurrency: 10 | group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} 11 | cancel-in-progress: true 12 | 13 | jobs: 14 | call-test-lint: 15 | uses: ./.github/workflows/test-lint.yml 16 | permissions: 17 | contents: read 18 | with: 19 | ref: ${{ github.event.pull_request.head.sha }} 20 | -------------------------------------------------------------------------------- /.github/workflows/pypi-publish-on-release.yml: -------------------------------------------------------------------------------- 1 | name: Publish Python Package 2 | 3 | on: 4 | release: 5 | types: 6 | - published 7 | 8 | jobs: 9 | call-test-lint: 10 | uses: ./.github/workflows/test-lint.yml 11 | permissions: 12 | contents: read 13 | with: 14 | ref: ${{ github.event.release.target_commitish }} 15 | 16 | build: 17 | name: Build distribution 📦 18 | permissions: 19 | contents: read 20 | needs: 21 | - call-test-lint 22 | runs-on: ubuntu-latest 23 | 24 | steps: 25 | - uses: actions/checkout@v4 26 | with: 27 | persist-credentials: false 28 | 29 | - name: Set up Python 30 | uses: actions/setup-python@v5 31 | with: 32 | python-version: '3.10' 33 | 34 | - name: Install dependencies 35 | run: | 36 | python -m pip install --upgrade pip 37 | pip install hatch twine 38 | 39 | - name: Validate version 40 | run: | 41 | version=$(hatch version) 42 | if [[ $version =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]]; then 43 | echo "Valid version 
format" 44 | exit 0 45 | else 46 | echo "Invalid version format" 47 | exit 1 48 | fi 49 | 50 | - name: Build 51 | run: | 52 | hatch build 53 | 54 | - name: Store the distribution packages 55 | uses: actions/upload-artifact@v4 56 | with: 57 | name: python-package-distributions 58 | path: dist/ 59 | 60 | deploy: 61 | name: Upload release to PyPI 62 | needs: 63 | - build 64 | runs-on: ubuntu-latest 65 | 66 | # environment is used by PyPI Trusted Publisher and is strongly encouraged 67 | # https://docs.pypi.org/trusted-publishers/adding-a-publisher/ 68 | environment: 69 | name: pypi 70 | url: https://pypi.org/p/strands-agents 71 | permissions: 72 | # IMPORTANT: this permission is mandatory for Trusted Publishing 73 | id-token: write 74 | 75 | steps: 76 | - name: Download all the dists 77 | uses: actions/download-artifact@v4 78 | with: 79 | name: python-package-distributions 80 | path: dist/ 81 | - name: Publish distribution 📦 to PyPI 82 | uses: pypa/gh-action-pypi-publish@release/v1 83 | -------------------------------------------------------------------------------- /.github/workflows/test-lint.yml: -------------------------------------------------------------------------------- 1 | name: Test and Lint 2 | 3 | on: 4 | workflow_call: 5 | inputs: 6 | ref: 7 | required: true 8 | type: string 9 | 10 | jobs: 11 | unit-test: 12 | name: Unit Tests - Python ${{ matrix.python-version }} - ${{ matrix.os-name }} 13 | permissions: 14 | contents: read 15 | strategy: 16 | matrix: 17 | include: 18 | # Linux 19 | - os: ubuntu-latest 20 | os-name: 'linux' 21 | python-version: "3.10" 22 | - os: ubuntu-latest 23 | os-name: 'linux' 24 | python-version: "3.11" 25 | - os: ubuntu-latest 26 | os-name: 'linux' 27 | python-version: "3.12" 28 | - os: ubuntu-latest 29 | os-name: 'linux' 30 | python-version: "3.13" 31 | # Windows 32 | - os: windows-latest 33 | os-name: 'windows' 34 | python-version: "3.10" 35 | - os: windows-latest 36 | os-name: 'windows' 37 | python-version: "3.11" 38 | - os: 
windows-latest 39 | os-name: 'windows' 40 | python-version: "3.12" 41 | - os: windows-latest 42 | os-name: 'windows' 43 | python-version: "3.13" 44 | # MacOS - latest only; not enough runners for macOS 45 | - os: macos-latest 46 | os-name: 'macOS' 47 | python-version: "3.13" 48 | fail-fast: true 49 | runs-on: ${{ matrix.os }} 50 | env: 51 | LOG_LEVEL: DEBUG 52 | steps: 53 | - name: Checkout code 54 | uses: actions/checkout@v4 55 | with: 56 | ref: ${{ inputs.ref }} # Explicitly define which commit to check out 57 | persist-credentials: false # Don't persist credentials for subsequent actions 58 | - name: Set up Python 59 | uses: actions/setup-python@v5 60 | with: 61 | python-version: ${{ matrix.python-version }} 62 | - name: Install dependencies 63 | run: | 64 | pip install --no-cache-dir hatch 65 | - name: Run Unit tests 66 | id: tests 67 | run: hatch test tests --cover 68 | continue-on-error: false 69 | lint: 70 | name: Lint 71 | runs-on: ubuntu-latest 72 | permissions: 73 | contents: read 74 | steps: 75 | - name: Checkout code 76 | uses: actions/checkout@v4 77 | with: 78 | ref: ${{ inputs.ref }} 79 | persist-credentials: false 80 | 81 | - name: Set up Python 82 | uses: actions/setup-python@v5 83 | with: 84 | python-version: '3.10' 85 | cache: 'pip' 86 | 87 | - name: Install dependencies 88 | run: | 89 | pip install --no-cache-dir hatch 90 | 91 | - name: Run lint 92 | id: lint 93 | run: hatch run test-lint 94 | continue-on-error: false 95 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | build 2 | __pycache__* 3 | .coverage* 4 | .env 5 | .venv 6 | .mypy_cache 7 | .pytest_cache 8 | .ruff_cache 9 | *.bak 10 | .vscode 11 | dist 12 | repl_state -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 
2 | - repo: local 3 | hooks: 4 | - id: hatch-format 5 | name: Format code 6 | entry: hatch fmt --formatter 7 | language: system 8 | pass_filenames: false 9 | types: [python] 10 | stages: [pre-commit] 11 | - id: hatch-lint 12 | name: Lint code 13 | entry: hatch run test-lint 14 | language: system 15 | pass_filenames: false 16 | types: [python] 17 | stages: [pre-commit] 18 | - id: hatch-test-lint 19 | name: Type linting 20 | entry: hatch run test-lint 21 | language: system 22 | pass_filenames: false 23 | types: [ python ] 24 | stages: [ pre-commit ] 25 | - id: hatch-test 26 | name: Unit tests 27 | entry: hatch test 28 | language: system 29 | pass_filenames: false 30 | types: [python] 31 | stages: [pre-commit] 32 | - id: commitizen-check 33 | name: Check commit message 34 | entry: hatch run cz check --commit-msg-file 35 | language: system 36 | stages: [commit-msg] -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the [Bug Reports](../../issues/new?template=bug_report.yml) file to report bugs or [Feature Requests](../../issues/new?template=feature_request.yml) to suggest features. 
13 | 14 | For a list of known bugs and feature requests: 15 | - Check [Bug Reports](../../issues?q=is%3Aissue%20state%3Aopen%20label%3Abug) for currently tracked issues 16 | - See [Feature Requests](../../issues?q=is%3Aissue%20state%3Aopen%20label%3Aenhancement) for requested enhancements 17 | 18 | When filing an issue, please first check whether it is already being tracked. 19 | 20 | Please try to include as much information as you can. Details like these are incredibly useful: 21 | 22 | * A reproducible test case or series of steps 23 | * The version of our code being used (commit ID) 24 | * Any modifications you've made relevant to the bug 25 | * Anything unusual about your environment or deployment 26 | 27 | 28 | ## Development Environment 29 | 30 | This project uses [hatchling](https://hatch.pypa.io/latest/build/#hatchling) as the build backend and [hatch](https://hatch.pypa.io/latest/) for development workflow management. 31 | 32 | ### Setting Up Your Development Environment 33 | 34 | 1. Enter a virtual environment using `hatch` (recommended), then launch your IDE in the new shell. 35 | ```bash 36 | hatch shell dev 37 | ``` 38 | 39 | Alternatively, install development dependencies in a manually created virtual environment: 40 | ```bash 41 | pip install -e ".[dev]" && pip install -e ".[litellm]" 42 | ``` 43 | 44 | 45 | 2. Set up pre-commit hooks: 46 | ```bash 47 | pre-commit install -t pre-commit -t commit-msg 48 | ``` 49 | This will automatically run formatters and conventional commit checks on your code before each commit. 50 | 51 | 3. Run code formatters manually: 52 | ```bash 53 | hatch fmt --formatter 54 | ``` 55 | 56 | 4. Run linters: 57 | ```bash 58 | hatch fmt --linter 59 | ``` 60 | 61 | 5. Run unit tests: 62 | ```bash 63 | hatch test 64 | ``` 65 | 66 | 6. Run integration tests: 67 | ```bash 68 | hatch run test-integ 69 | ``` 70 | 71 | ### Pre-commit Hooks 72 | 73 | We use [pre-commit](https://pre-commit.com/) to automatically run quality checks before each commit. 
The hooks will run `hatch fmt --formatter`, `hatch run test-lint`, `hatch test`, and `hatch run cz check` when you make a commit, ensuring code consistency. 74 | 75 | The pre-commit hook is installed with: 76 | 77 | ```bash 78 | pre-commit install 79 | ``` 80 | 81 | You can also run the hooks manually on all files: 82 | 83 | ```bash 84 | pre-commit run --all-files 85 | ``` 86 | 87 | ### Code Formatting and Style Guidelines 88 | 89 | We use the following tools to ensure code quality: 90 | 1. **ruff** - For formatting and linting 91 | 2. **mypy** - For static type checking 92 | 93 | These tools are configured in the [pyproject.toml](./pyproject.toml) file. Please ensure your code passes all linting and type checks before submitting a pull request: 94 | 95 | ```bash 96 | # Run all checks 97 | hatch fmt --formatter 98 | hatch fmt --linter 99 | ``` 100 | 101 | If you're using an IDE like VS Code or PyCharm, consider configuring it to use these tools automatically. 102 | 103 | For additional details on styling, please see our dedicated [Style Guide](./STYLE_GUIDE.md). 104 | 105 | 106 | ## Contributing via Pull Requests 107 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 108 | 109 | 1. You are working against the latest source on the *main* branch. 110 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 111 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 112 | 113 | To send us a pull request, please: 114 | 115 | 1. Create a branch. 116 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 117 | 3. Format your code using `hatch fmt --formatter`. 118 | 4. Run linting checks with `hatch fmt --linter`. 119 | 5. Ensure local tests pass with `hatch test` and `hatch run test-integ`. 
120 | 6. Commit to your branch using clear commit messages following the [Conventional Commits](https://www.conventionalcommits.org) specification. 121 | 7. Send us a pull request, answering any default questions in the pull request interface. 122 | 8. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 123 | 124 | 125 | ## Finding contributions to work on 126 | Looking at the existing issues is a great way to find something to contribute to. 127 | 128 | You can check: 129 | - Our known bugs list in [Bug Reports](../../issues?q=is%3Aissue%20state%3Aopen%20label%3Abug) for issues that need fixing 130 | - Feature requests in [Feature Requests](../../issues?q=is%3Aissue%20state%3Aopen%20label%3Aenhancement) for new functionality to implement 131 | 132 | 133 | ## Code of Conduct 134 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 135 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 136 | opensource-codeofconduct@amazon.com with any additional questions or comments. 137 | 138 | 139 | ## Security issue notifications 140 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue. 141 | 142 | 143 | ## Licensing 144 | 145 | See the [LICENSE](./LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 146 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
-------------------------------------------------------------------------------- /STYLE_GUIDE.md: -------------------------------------------------------------------------------- 1 | # Style Guide 2 | 3 | ## Overview 4 | 5 | The Strands Agents style guide aims to establish consistent formatting, naming conventions, and structure across all code in the repository. We strive to make our code clean, readable, and maintainable. 6 | 7 | Where possible, we will codify these style guidelines into our linting rules and pre-commit hooks to automate enforcement and reduce the manual review burden. 8 | 9 | ## Log Formatting 10 | 11 | The format for Strands Agents logs is as follows: 12 | 13 | ```python 14 | logger.debug("field1=<%s>, field2=<%s>, ... | human readable message", field1, field2, ...) 15 | ``` 16 | 17 | ### Guidelines 18 | 19 | 1. **Context**: 20 | - Add context as `<FIELD>=<VALUE>` pairs at the beginning of the log 21 | - Many log services (CloudWatch, Splunk, etc.) look for these patterns to extract fields for searching 22 | - Use commas (`,`) to separate pairs 23 | - Enclose values in `<>` for readability 24 | - This is particularly helpful in displaying empty values (`field=` vs `field=<>`) 25 | - Use `%s` for string interpolation as recommended by Python logging 26 | - This is an optimization to skip string interpolation when the log level is not enabled 27 | 28 | 2. **Messages**: 29 | - Add human-readable messages at the end of the log 30 | - Use lowercase for consistency 31 | - Avoid punctuation (periods, exclamation points, etc.) 
to reduce clutter 32 | - Keep messages concise and focused on a single statement 33 | - If multiple statements are needed, separate them with the pipe character (`|`) 34 | - Example: `"processing request | starting validation"` 35 | 36 | ### Examples 37 | 38 | #### Good 39 | 40 | ```python 41 | logger.debug("user_id=<%s>, action=<%s> | user performed action", user_id, action) 42 | logger.info("request_id=<%s>, duration_ms=<%d> | request completed", request_id, duration) 43 | logger.warning("attempt=<%d>, max_attempts=<%d> | retry limit approaching", attempt, max_attempts) 44 | ``` 45 | 46 | #### Poor 47 | 48 | ```python 49 | # Avoid: No structured fields, direct variable interpolation in message 50 | logger.debug(f"User {user_id} performed action {action}") 51 | 52 | # Avoid: Inconsistent formatting, punctuation 53 | logger.info("Request completed in %d ms.", duration) 54 | 55 | # Avoid: No separation between fields and message 56 | logger.warning("Retry limit approaching! attempt=%d max_attempts=%d", attempt, max_attempts) 57 | ``` 58 | 59 | By following these log formatting guidelines, we ensure that logs are both human-readable and machine-parseable, making debugging and monitoring more efficient. 60 | -------------------------------------------------------------------------------- /src/strands/__init__.py: -------------------------------------------------------------------------------- 1 | """A framework for building, deploying, and managing AI agents.""" 2 | 3 | from . 
import agent, models, telemetry, types 4 | from .agent.agent import Agent 5 | from .tools.decorator import tool 6 | 7 | __all__ = ["Agent", "agent", "models", "tool", "types", "telemetry"] 8 | -------------------------------------------------------------------------------- /src/strands/agent/__init__.py: -------------------------------------------------------------------------------- 1 | """This package provides the core Agent interface and supporting components for building AI agents with the SDK. 2 | 3 | It includes: 4 | 5 | - Agent: The main interface for interacting with AI models and tools 6 | - ConversationManager: Classes for managing conversation history and context windows 7 | """ 8 | 9 | from .agent import Agent 10 | from .agent_result import AgentResult 11 | from .conversation_manager import ( 12 | ConversationManager, 13 | NullConversationManager, 14 | SlidingWindowConversationManager, 15 | SummarizingConversationManager, 16 | ) 17 | 18 | __all__ = [ 19 | "Agent", 20 | "AgentResult", 21 | "ConversationManager", 22 | "NullConversationManager", 23 | "SlidingWindowConversationManager", 24 | "SummarizingConversationManager", 25 | ] 26 | -------------------------------------------------------------------------------- /src/strands/agent/agent_result.py: -------------------------------------------------------------------------------- 1 | """Agent result handling for SDK. 2 | 3 | This module defines the AgentResult class which encapsulates the complete response from an agent's processing cycle. 4 | """ 5 | 6 | from dataclasses import dataclass 7 | from typing import Any 8 | 9 | from ..telemetry.metrics import EventLoopMetrics 10 | from ..types.content import Message 11 | from ..types.streaming import StopReason 12 | 13 | 14 | @dataclass 15 | class AgentResult: 16 | """Represents the last result of invoking an agent with a prompt. 17 | 18 | Attributes: 19 | stop_reason: The reason why the agent's processing stopped. 
20 | message: The last message generated by the agent. 21 | metrics: Performance metrics collected during processing. 22 | state: Additional state information from the event loop. 23 | """ 24 | 25 | stop_reason: StopReason 26 | message: Message 27 | metrics: EventLoopMetrics 28 | state: Any 29 | 30 | def __str__(self) -> str: 31 | """Get the agent's last message as a string. 32 | 33 | This method extracts and concatenates all text content from the final message, ignoring any non-text content 34 | like images or structured data. 35 | 36 | Returns: 37 | The agent's last message as a string. 38 | """ 39 | content_array = self.message.get("content", []) 40 | 41 | result = "" 42 | for item in content_array: 43 | if isinstance(item, dict) and "text" in item: 44 | result += item.get("text", "") + "\n" 45 | 46 | return result 47 | -------------------------------------------------------------------------------- /src/strands/agent/conversation_manager/__init__.py: -------------------------------------------------------------------------------- 1 | """This package provides classes for managing conversation history during agent execution. 2 | 3 | It includes: 4 | 5 | - ConversationManager: Abstract base class defining the conversation management interface 6 | - NullConversationManager: A no-op implementation that does not modify conversation history 7 | - SlidingWindowConversationManager: An implementation that maintains a sliding window of messages to control context 8 | size while preserving conversation coherence 9 | - SummarizingConversationManager: An implementation that summarizes older context instead 10 | of simply trimming it 11 | 12 | Conversation managers help control memory usage and context length while maintaining relevant conversation state, which 13 | is critical for effective agent interactions. 
14 | """ 15 | 16 | from .conversation_manager import ConversationManager 17 | from .null_conversation_manager import NullConversationManager 18 | from .sliding_window_conversation_manager import SlidingWindowConversationManager 19 | from .summarizing_conversation_manager import SummarizingConversationManager 20 | 21 | __all__ = [ 22 | "ConversationManager", 23 | "NullConversationManager", 24 | "SlidingWindowConversationManager", 25 | "SummarizingConversationManager", 26 | ] 27 | -------------------------------------------------------------------------------- /src/strands/agent/conversation_manager/conversation_manager.py: -------------------------------------------------------------------------------- 1 | """Abstract interface for conversation history management.""" 2 | 3 | from abc import ABC, abstractmethod 4 | from typing import TYPE_CHECKING, Any, Optional 5 | 6 | from ...types.content import Message 7 | 8 | if TYPE_CHECKING: 9 | from ...agent.agent import Agent 10 | 11 | 12 | class ConversationManager(ABC): 13 | """Abstract base class for managing conversation history. 14 | 15 | This class provides an interface for implementing conversation management strategies to control the size of message 16 | arrays/conversation histories, helping to: 17 | 18 | - Manage memory usage 19 | - Control context length 20 | - Maintain relevant conversation state 21 | """ 22 | 23 | def __init__(self) -> None: 24 | """Initialize the ConversationManager. 25 | 26 | Attributes: 27 | removed_message_count: The number of messages that have been removed from the agent's messages array. 28 | These represent messages provided by the user or LLM that have been removed, not messages 29 | included by the conversation manager through something like summarization. 30 | """ 31 | self.removed_message_count = 0 32 | 33 | def restore_from_session(self, state: dict[str, Any]) -> Optional[list[Message]]: 34 | """Restore the Conversation Manager's state from a session. 
35 | 36 | Args: 37 | state: Previous state of the conversation manager 38 | Returns: 39 | Optional list of messages to prepend to the agent's messages. By default returns None. 40 | """ 41 | if state.get("__name__") != self.__class__.__name__: 42 | raise ValueError("Invalid conversation manager state.") 43 | self.removed_message_count = state["removed_message_count"] 44 | return None 45 | 46 | def get_state(self) -> dict[str, Any]: 47 | """Get the current state of a Conversation Manager as a Json serializable dictionary.""" 48 | return { 49 | "__name__": self.__class__.__name__, 50 | "removed_message_count": self.removed_message_count, 51 | } 52 | 53 | @abstractmethod 54 | def apply_management(self, agent: "Agent", **kwargs: Any) -> None: 55 | """Applies management strategy to the provided agent. 56 | 57 | Processes the conversation history to maintain appropriate size by modifying the messages list in-place. 58 | Implementations should handle message pruning, summarization, or other size management techniques to keep the 59 | conversation context within desired bounds. 60 | 61 | Args: 62 | agent: The agent whose conversation history will be managed. 63 | This list is modified in-place. 64 | **kwargs: Additional keyword arguments for future extensibility. 65 | """ 66 | pass 67 | 68 | @abstractmethod 69 | def reduce_context(self, agent: "Agent", e: Optional[Exception] = None, **kwargs: Any) -> None: 70 | """Called when the model's context window is exceeded. 71 | 72 | This method should implement the specific strategy for reducing the window size when a context overflow occurs. 73 | It is typically called after a ContextWindowOverflowException is caught. 74 | 75 | Implementations might use strategies such as: 76 | 77 | - Removing the N oldest messages 78 | - Summarizing older context 79 | - Applying importance-based filtering 80 | - Maintaining critical conversation markers 81 | 82 | Args: 83 | agent: The agent whose conversation history will be reduced. 
84 | This list is modified in-place. 85 | e: The exception that triggered the context reduction, if any. 86 | **kwargs: Additional keyword arguments for future extensibility. 87 | """ 88 | pass 89 | -------------------------------------------------------------------------------- /src/strands/agent/conversation_manager/null_conversation_manager.py: -------------------------------------------------------------------------------- 1 | """Null implementation of conversation management.""" 2 | 3 | from typing import TYPE_CHECKING, Any, Optional 4 | 5 | if TYPE_CHECKING: 6 | from ...agent.agent import Agent 7 | 8 | from ...types.exceptions import ContextWindowOverflowException 9 | from .conversation_manager import ConversationManager 10 | 11 | 12 | class NullConversationManager(ConversationManager): 13 | """A no-op conversation manager that does not modify the conversation history. 14 | 15 | Useful for: 16 | 17 | - Testing scenarios where conversation management should be disabled 18 | - Cases where conversation history is managed externally 19 | - Situations where the full conversation history should be preserved 20 | """ 21 | 22 | def apply_management(self, agent: "Agent", **kwargs: Any) -> None: 23 | """Does nothing to the conversation history. 24 | 25 | Args: 26 | agent: The agent whose conversation history will remain unmodified. 27 | **kwargs: Additional keyword arguments for future extensibility. 28 | """ 29 | pass 30 | 31 | def reduce_context(self, agent: "Agent", e: Optional[Exception] = None, **kwargs: Any) -> None: 32 | """Does not reduce context and raises an exception. 33 | 34 | Args: 35 | agent: The agent whose conversation history will remain unmodified. 36 | e: The exception that triggered the context reduction, if any. 37 | **kwargs: Additional keyword arguments for future extensibility. 38 | 39 | Raises: 40 | e: If provided. 41 | ContextWindowOverflowException: If e is None. 
42 | """ 43 | if e: 44 | raise e 45 | else: 46 | raise ContextWindowOverflowException("Context window overflowed!") 47 | -------------------------------------------------------------------------------- /src/strands/agent/state.py: -------------------------------------------------------------------------------- 1 | """Agent state management.""" 2 | 3 | import copy 4 | import json 5 | from typing import Any, Dict, Optional 6 | 7 | 8 | class AgentState: 9 | """Represents an Agent's stateful information outside of context provided to a model. 10 | 11 | Provides a key-value store for agent state with JSON serialization validation and persistence support. 12 | Key features: 13 | - JSON serialization validation on assignment 14 | - Get/set/delete operations 15 | """ 16 | 17 | def __init__(self, initial_state: Optional[Dict[str, Any]] = None): 18 | """Initialize AgentState.""" 19 | self._state: Dict[str, Dict[str, Any]] 20 | if initial_state: 21 | self._validate_json_serializable(initial_state) 22 | self._state = copy.deepcopy(initial_state) 23 | else: 24 | self._state = {} 25 | 26 | def set(self, key: str, value: Any) -> None: 27 | """Set a value in the state. 28 | 29 | Args: 30 | key: The key to store the value under 31 | value: The value to store (must be JSON serializable) 32 | 33 | Raises: 34 | ValueError: If key is invalid, or if value is not JSON serializable 35 | """ 36 | self._validate_key(key) 37 | self._validate_json_serializable(value) 38 | 39 | self._state[key] = copy.deepcopy(value) 40 | 41 | def get(self, key: Optional[str] = None) -> Any: 42 | """Get a value or entire state. 
43 | 44 | Args: 45 | key: The key to retrieve (if None, returns entire state object) 46 | 47 | Returns: 48 | The stored value, entire state dict, or None if not found 49 | """ 50 | if key is None: 51 | return copy.deepcopy(self._state) 52 | else: 53 | # Return specific key 54 | return copy.deepcopy(self._state.get(key)) 55 | 56 | def delete(self, key: str) -> None: 57 | """Delete a specific key from the state. 58 | 59 | Args: 60 | key: The key to delete 61 | """ 62 | self._validate_key(key) 63 | 64 | self._state.pop(key, None) 65 | 66 | def _validate_key(self, key: str) -> None: 67 | """Validate that a key is valid. 68 | 69 | Args: 70 | key: The key to validate 71 | 72 | Raises: 73 | ValueError: If key is invalid 74 | """ 75 | if key is None: 76 | raise ValueError("Key cannot be None") 77 | if not isinstance(key, str): 78 | raise ValueError("Key must be a string") 79 | if not key.strip(): 80 | raise ValueError("Key cannot be empty") 81 | 82 | def _validate_json_serializable(self, value: Any) -> None: 83 | """Validate that a value is JSON serializable. 84 | 85 | Args: 86 | value: The value to validate 87 | 88 | Raises: 89 | ValueError: If value is not JSON serializable 90 | """ 91 | try: 92 | json.dumps(value) 93 | except (TypeError, ValueError) as e: 94 | raise ValueError( 95 | f"Value is not JSON serializable: {type(value).__name__}. " 96 | f"Only JSON-compatible types (str, int, float, bool, list, dict, None) are allowed." 97 | ) from e 98 | -------------------------------------------------------------------------------- /src/strands/event_loop/__init__.py: -------------------------------------------------------------------------------- 1 | """This package provides the core event loop implementation for the agents SDK. 2 | 3 | The event loop enables conversational AI agents to process messages, execute tools, and handle errors in a controlled, 4 | iterative manner. 5 | """ 6 | 7 | from . 
import event_loop 8 | 9 | __all__ = ["event_loop"] 10 | -------------------------------------------------------------------------------- /src/strands/experimental/__init__.py: -------------------------------------------------------------------------------- 1 | """Experimental features. 2 | 3 | This module implements experimental features that are subject to change in future revisions without notice. 4 | """ 5 | -------------------------------------------------------------------------------- /src/strands/experimental/hooks/__init__.py: -------------------------------------------------------------------------------- 1 | """Experimental hook functionality that has not yet reached stability.""" 2 | 3 | from .events import ( 4 | AfterModelInvocationEvent, 5 | AfterToolInvocationEvent, 6 | BeforeModelInvocationEvent, 7 | BeforeToolInvocationEvent, 8 | ) 9 | 10 | __all__ = [ 11 | "BeforeToolInvocationEvent", 12 | "AfterToolInvocationEvent", 13 | "BeforeModelInvocationEvent", 14 | "AfterModelInvocationEvent", 15 | ] 16 | -------------------------------------------------------------------------------- /src/strands/experimental/hooks/events.py: -------------------------------------------------------------------------------- 1 | """Experimental hook events emitted as part of invoking Agents. 2 | 3 | This module defines the events that are emitted as Agents run through the lifecycle of a request. 4 | """ 5 | 6 | from dataclasses import dataclass 7 | from typing import Any, Optional 8 | 9 | from ...hooks import HookEvent 10 | from ...types.content import Message 11 | from ...types.streaming import StopReason 12 | from ...types.tools import AgentTool, ToolResult, ToolUse 13 | 14 | 15 | @dataclass 16 | class BeforeToolInvocationEvent(HookEvent): 17 | """Event triggered before a tool is invoked. 18 | 19 | This event is fired just before the agent executes a tool, allowing hook 20 | providers to inspect, modify, or replace the tool that will be executed. 
21 | The selected_tool can be modified by hook callbacks to change which tool 22 | gets executed. 23 | 24 | Attributes: 25 | selected_tool: The tool that will be invoked. Can be modified by hooks 26 | to change which tool gets executed. This may be None if tool lookup failed. 27 | tool_use: The tool parameters that will be passed to selected_tool. 28 | invocation_state: Keyword arguments that will be passed to the tool. 29 | """ 30 | 31 | selected_tool: Optional[AgentTool] 32 | tool_use: ToolUse 33 | invocation_state: dict[str, Any] 34 | 35 | def _can_write(self, name: str) -> bool: 36 | return name in ["selected_tool", "tool_use"] 37 | 38 | 39 | @dataclass 40 | class AfterToolInvocationEvent(HookEvent): 41 | """Event triggered after a tool invocation completes. 42 | 43 | This event is fired after the agent has finished executing a tool, 44 | regardless of whether the execution was successful or resulted in an error. 45 | Hook providers can use this event for cleanup, logging, or post-processing. 46 | 47 | Note: This event uses reverse callback ordering, meaning callbacks registered 48 | later will be invoked first during cleanup. 49 | 50 | Attributes: 51 | selected_tool: The tool that was invoked. It may be None if tool lookup failed. 52 | tool_use: The tool parameters that were passed to the tool invoked. 53 | invocation_state: Keyword arguments that were passed to the tool. 54 | result: The result of the tool invocation, as a ToolResult. 55 | exception: The exception raised during tool execution, if any; None by default. 
56 | """ 57 | 58 | selected_tool: Optional[AgentTool] 59 | tool_use: ToolUse 60 | invocation_state: dict[str, Any] 61 | result: ToolResult 62 | exception: Optional[Exception] = None 63 | 64 | def _can_write(self, name: str) -> bool: 65 | return name == "result" 66 | 67 | @property 68 | def should_reverse_callbacks(self) -> bool: 69 | """True to invoke callbacks in reverse order.""" 70 | return True 71 | 72 | 73 | @dataclass 74 | class BeforeModelInvocationEvent(HookEvent): 75 | """Event triggered before the model is invoked. 76 | 77 | This event is fired just before the agent calls the model for inference, 78 | allowing hook providers to inspect or modify the messages and configuration 79 | that will be sent to the model. 80 | 81 | Note: This event is not fired for invocations to structured_output. 82 | """ 83 | 84 | pass 85 | 86 | 87 | @dataclass 88 | class AfterModelInvocationEvent(HookEvent): 89 | """Event triggered after the model invocation completes. 90 | 91 | This event is fired after the agent has finished calling the model, 92 | regardless of whether the invocation was successful or resulted in an error. 93 | Hook providers can use this event for cleanup, logging, or post-processing. 94 | 95 | Note: This event uses reverse callback ordering, meaning callbacks registered 96 | later will be invoked first during cleanup. 97 | 98 | Note: This event is not fired for invocations to structured_output. 99 | 100 | Attributes: 101 | stop_response: The model response data if invocation was successful, None if failed. 102 | exception: Exception if the model invocation failed, None if successful. 103 | """ 104 | 105 | @dataclass 106 | class ModelStopResponse: 107 | """Model response data from successful invocation. 108 | 109 | Attributes: 110 | stop_reason: The reason the model stopped generating. 111 | message: The generated message from the model. 
112 | """ 113 | 114 | message: Message 115 | stop_reason: StopReason 116 | 117 | stop_response: Optional[ModelStopResponse] = None 118 | exception: Optional[Exception] = None 119 | 120 | @property 121 | def should_reverse_callbacks(self) -> bool: 122 | """True to invoke callbacks in reverse order.""" 123 | return True 124 | -------------------------------------------------------------------------------- /src/strands/handlers/__init__.py: -------------------------------------------------------------------------------- 1 | """Various handlers for performing custom actions on agent state. 2 | 3 | Examples include: 4 | 5 | - Displaying events from the event stream 6 | """ 7 | 8 | from .callback_handler import CompositeCallbackHandler, PrintingCallbackHandler, null_callback_handler 9 | 10 | __all__ = ["CompositeCallbackHandler", "null_callback_handler", "PrintingCallbackHandler"] 11 | -------------------------------------------------------------------------------- /src/strands/handlers/callback_handler.py: -------------------------------------------------------------------------------- 1 | """This module provides handlers for formatting and displaying events from the agent.""" 2 | 3 | from collections.abc import Callable 4 | from typing import Any 5 | 6 | 7 | class PrintingCallbackHandler: 8 | """Handler for streaming text output and tool invocations to stdout.""" 9 | 10 | def __init__(self) -> None: 11 | """Initialize handler.""" 12 | self.tool_count = 0 13 | self.previous_tool_use = None 14 | 15 | def __call__(self, **kwargs: Any) -> None: 16 | """Stream text output and tool invocations to stdout. 17 | 18 | Args: 19 | **kwargs: Callback event data including: 20 | - reasoningText (Optional[str]): Reasoning text to print if provided. 21 | - data (str): Text content to stream. 22 | - complete (bool): Whether this is the final chunk of a response. 23 | - current_tool_use (dict): Information about the current tool being used. 
24 | """ 25 | reasoningText = kwargs.get("reasoningText", False) 26 | data = kwargs.get("data", "") 27 | complete = kwargs.get("complete", False) 28 | current_tool_use = kwargs.get("current_tool_use", {}) 29 | 30 | if reasoningText: 31 | print(reasoningText, end="") 32 | 33 | if data: 34 | print(data, end="" if not complete else "\n") 35 | 36 | if current_tool_use and current_tool_use.get("name"): 37 | tool_name = current_tool_use.get("name", "Unknown tool") 38 | if self.previous_tool_use != current_tool_use: 39 | self.previous_tool_use = current_tool_use 40 | self.tool_count += 1 41 | print(f"\nTool #{self.tool_count}: {tool_name}") 42 | 43 | if complete and data: 44 | print("\n") 45 | 46 | 47 | class CompositeCallbackHandler: 48 | """Class-based callback handler that combines multiple callback handlers. 49 | 50 | This handler allows multiple callback handlers to be invoked for the same events, 51 | enabling different processing or output formats for the same stream data. 52 | """ 53 | 54 | def __init__(self, *handlers: Callable) -> None: 55 | """Initialize handler.""" 56 | self.handlers = handlers 57 | 58 | def __call__(self, **kwargs: Any) -> None: 59 | """Invoke all handlers in the chain.""" 60 | for handler in self.handlers: 61 | handler(**kwargs) 62 | 63 | 64 | def null_callback_handler(**_kwargs: Any) -> None: 65 | """Callback handler that discards all output. 66 | 67 | Args: 68 | **_kwargs: Event data (ignored). 69 | """ 70 | return None 71 | -------------------------------------------------------------------------------- /src/strands/hooks/__init__.py: -------------------------------------------------------------------------------- 1 | """Typed hook system for extending agent functionality. 2 | 3 | This module provides a composable mechanism for building objects that can hook 4 | into specific events during the agent lifecycle. 
The hook system enables both 5 | built-in SDK components and user code to react to or modify agent behavior 6 | through strongly-typed event callbacks. 7 | 8 | Example Usage: 9 | ```python 10 | from strands.hooks import HookProvider, HookRegistry 11 | from strands.hooks.events import BeforeInvocationEvent, AfterInvocationEvent 12 | 13 | class LoggingHooks(HookProvider): 14 | def register_hooks(self, registry: HookRegistry) -> None: 15 | registry.add_callback(BeforeInvocationEvent, self.log_start) 16 | registry.add_callback(AfterInvocationEvent, self.log_end) 17 | 18 | def log_start(self, event: BeforeInvocationEvent) -> None: 19 | print(f"Request started for {event.agent.name}") 20 | 21 | def log_end(self, event: AfterInvocationEvent) -> None: 22 | print(f"Request completed for {event.agent.name}") 23 | 24 | # Use with agent 25 | agent = Agent(hooks=[LoggingHooks()]) 26 | ``` 27 | 28 | This replaces the older callback_handler approach with a more composable, 29 | type-safe system that supports multiple subscribers per event type. 30 | """ 31 | 32 | from .events import ( 33 | AfterInvocationEvent, 34 | AgentInitializedEvent, 35 | BeforeInvocationEvent, 36 | MessageAddedEvent, 37 | ) 38 | from .registry import HookCallback, HookEvent, HookProvider, HookRegistry 39 | 40 | __all__ = [ 41 | "AgentInitializedEvent", 42 | "BeforeInvocationEvent", 43 | "AfterInvocationEvent", 44 | "MessageAddedEvent", 45 | "HookEvent", 46 | "HookProvider", 47 | "HookCallback", 48 | "HookRegistry", 49 | ] 50 | -------------------------------------------------------------------------------- /src/strands/hooks/events.py: -------------------------------------------------------------------------------- 1 | """Hook events emitted as part of invoking Agents. 2 | 3 | This module defines the events that are emitted as Agents run through the lifecycle of a request. 
4 | """ 5 | 6 | from dataclasses import dataclass 7 | 8 | from ..types.content import Message 9 | from .registry import HookEvent 10 | 11 | 12 | @dataclass 13 | class AgentInitializedEvent(HookEvent): 14 | """Event triggered when an agent has finished initialization. 15 | 16 | This event is fired after the agent has been fully constructed and all 17 | built-in components have been initialized. Hook providers can use this 18 | event to perform setup tasks that require a fully initialized agent. 19 | """ 20 | 21 | pass 22 | 23 | 24 | @dataclass 25 | class BeforeInvocationEvent(HookEvent): 26 | """Event triggered at the beginning of a new agent request. 27 | 28 | This event is fired before the agent begins processing a new user request, 29 | before any model inference or tool execution occurs. Hook providers can 30 | use this event to perform request-level setup, logging, or validation. 31 | 32 | This event is triggered at the beginning of the following api calls: 33 | - Agent.__call__ 34 | - Agent.stream_async 35 | - Agent.structured_output 36 | """ 37 | 38 | pass 39 | 40 | 41 | @dataclass 42 | class AfterInvocationEvent(HookEvent): 43 | """Event triggered at the end of an agent request. 44 | 45 | This event is fired after the agent has completed processing a request, 46 | regardless of whether it completed successfully or encountered an error. 47 | Hook providers can use this event for cleanup, logging, or state persistence. 48 | 49 | Note: This event uses reverse callback ordering, meaning callbacks registered 50 | later will be invoked first during cleanup. 
51 | 52 | This event is triggered at the end of the following api calls: 53 | - Agent.__call__ 54 | - Agent.stream_async 55 | - Agent.structured_output 56 | """ 57 | 58 | @property 59 | def should_reverse_callbacks(self) -> bool: 60 | """True to invoke callbacks in reverse order.""" 61 | return True 62 | 63 | 64 | @dataclass 65 | class MessageAddedEvent(HookEvent): 66 | """Event triggered when a message is added to the agent's conversation. 67 | 68 | This event is fired whenever the agent adds a new message to its internal 69 | message history, including user messages, assistant responses, and tool 70 | results. Hook providers can use this event for logging, monitoring, or 71 | implementing custom message processing logic. 72 | 73 | Note: This event is only triggered for messages added by the framework 74 | itself, not for messages manually added by tools or external code. 75 | 76 | Attributes: 77 | message: The message that was added to the conversation history. 78 | """ 79 | 80 | message: Message 81 | -------------------------------------------------------------------------------- /src/strands/hooks/rules.md: -------------------------------------------------------------------------------- 1 | # Hook System Rules 2 | 3 | ## Terminology 4 | 5 | - **Paired events**: Events that denote the beginning and end of an operation 6 | - **Hook callback**: A function that receives a strongly-typed event argument and performs some action in response 7 | 8 | ## Naming Conventions 9 | 10 | - All hook events have a suffix of `Event` 11 | - Paired events follow the naming convention of `Before{Item}Event` and `After{Item}Event` 12 | 13 | ## Paired Events 14 | 15 | - The final event in a pair returns `True` for `should_reverse_callbacks` 16 | - For every `Before` event there is a corresponding `After` event, even if an exception occurs 17 | 18 | ## Writable Properties 19 | 20 | For events with writable properties, those values are re-read after invoking the hook callbacks and used 
in subsequent processing. For example, `BeforeToolInvocationEvent.selected_tool` is writable - after invoking the callback for `BeforeToolInvocationEvent`, the `selected_tool` takes effect for the tool call. -------------------------------------------------------------------------------- /src/strands/models/__init__.py: -------------------------------------------------------------------------------- 1 | """SDK model providers. 2 | 3 | This package includes an abstract base Model class along with concrete implementations for specific providers. 4 | """ 5 | 6 | from . import bedrock, model 7 | from .bedrock import BedrockModel 8 | from .model import Model 9 | 10 | __all__ = ["bedrock", "model", "BedrockModel", "Model"] 11 | -------------------------------------------------------------------------------- /src/strands/models/model.py: -------------------------------------------------------------------------------- 1 | """Abstract base class for Agent model providers.""" 2 | 3 | import abc 4 | import logging 5 | from typing import Any, AsyncGenerator, AsyncIterable, Optional, Type, TypeVar, Union 6 | 7 | from pydantic import BaseModel 8 | 9 | from ..types.content import Messages 10 | from ..types.streaming import StreamEvent 11 | from ..types.tools import ToolSpec 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | T = TypeVar("T", bound=BaseModel) 16 | 17 | 18 | class Model(abc.ABC): 19 | """Abstract base class for Agent model providers. 20 | 21 | This class defines the interface for all model implementations in the Strands Agents SDK. It provides a 22 | standardized way to configure and process requests for different AI model providers. 23 | """ 24 | 25 | @abc.abstractmethod 26 | # pragma: no cover 27 | def update_config(self, **model_config: Any) -> None: 28 | """Update the model configuration with the provided arguments. 29 | 30 | Args: 31 | **model_config: Configuration overrides. 
32 | """ 33 | pass 34 | 35 | @abc.abstractmethod 36 | # pragma: no cover 37 | def get_config(self) -> Any: 38 | """Return the model configuration. 39 | 40 | Returns: 41 | The model's configuration. 42 | """ 43 | pass 44 | 45 | @abc.abstractmethod 46 | # pragma: no cover 47 | def structured_output( 48 | self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any 49 | ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: 50 | """Get structured output from the model. 51 | 52 | Args: 53 | output_model: The output model to use for the agent. 54 | prompt: The prompt messages to use for the agent. 55 | system_prompt: System prompt to provide context to the model. 56 | **kwargs: Additional keyword arguments for future extensibility. 57 | 58 | Yields: 59 | Model events with the last being the structured output. 60 | 61 | Raises: 62 | ValidationException: The response format from the model does not match the output_model 63 | """ 64 | pass 65 | 66 | @abc.abstractmethod 67 | # pragma: no cover 68 | def stream( 69 | self, 70 | messages: Messages, 71 | tool_specs: Optional[list[ToolSpec]] = None, 72 | system_prompt: Optional[str] = None, 73 | **kwargs: Any, 74 | ) -> AsyncIterable[StreamEvent]: 75 | """Stream conversation with the model. 76 | 77 | This method handles the full lifecycle of conversing with the model: 78 | 79 | 1. Format the messages, tool specs, and configuration into a streaming request 80 | 2. Send the request to the model 81 | 3. Yield the formatted message chunks 82 | 83 | Args: 84 | messages: List of message objects to be processed by the model. 85 | tool_specs: List of tool specifications to make available to the model. 86 | system_prompt: System prompt to provide context to the model. 87 | **kwargs: Additional keyword arguments for future extensibility. 88 | 89 | Yields: 90 | Formatted message chunks from the model. 
91 | 92 | Raises: 93 | ModelThrottledException: When the model service is throttling requests from the client. 94 | """ 95 | pass 96 | -------------------------------------------------------------------------------- /src/strands/multiagent/__init__.py: -------------------------------------------------------------------------------- 1 | """Multiagent capabilities for Strands Agents. 2 | 3 | This module provides support for multiagent systems, including agent-to-agent (A2A) 4 | communication protocols and coordination mechanisms. 5 | 6 | Submodules: 7 | a2a: Implementation of the Agent-to-Agent (A2A) protocol, which enables 8 | standardized communication between agents. 9 | """ 10 | 11 | from .base import MultiAgentBase, MultiAgentResult 12 | from .graph import GraphBuilder, GraphResult 13 | from .swarm import Swarm, SwarmResult 14 | 15 | __all__ = [ 16 | "GraphBuilder", 17 | "GraphResult", 18 | "MultiAgentBase", 19 | "MultiAgentResult", 20 | "Swarm", 21 | "SwarmResult", 22 | ] 23 | -------------------------------------------------------------------------------- /src/strands/multiagent/a2a/__init__.py: -------------------------------------------------------------------------------- 1 | """Agent-to-Agent (A2A) communication protocol implementation for Strands Agents. 2 | 3 | This module provides classes and utilities for enabling Strands Agents to communicate 4 | with other agents using the Agent-to-Agent (A2A) protocol. 5 | 6 | Docs: https://google-a2a.github.io/A2A/latest/ 7 | 8 | Classes: 9 | A2AServer: A wrapper that adapts a Strands Agent to be A2A-compatible. 10 | """ 11 | 12 | from .executor import StrandsA2AExecutor 13 | from .server import A2AServer 14 | 15 | __all__ = ["A2AServer", "StrandsA2AExecutor"] 16 | -------------------------------------------------------------------------------- /src/strands/multiagent/base.py: -------------------------------------------------------------------------------- 1 | """Multi-Agent Base Class. 
class Status(Enum):
    """Lifecycle states shared by whole multi-agent runs and by their individual nodes."""

    PENDING = "pending"
    EXECUTING = "executing"
    COMPLETED = "completed"
    FAILED = "failed"


def _zero_usage() -> Usage:
    """Return a usage record with every token counter set to zero."""
    return Usage(inputTokens=0, outputTokens=0, totalTokens=0)


def _zero_metrics() -> Metrics:
    """Return a metrics record reporting zero latency."""
    return Metrics(latencyMs=0)


@dataclass
class NodeResult:
    """Outcome of running one node, whether it wraps an Agent or a nested MultiAgentBase.

    ``status`` expresses how the node's work ended:
    - COMPLETED: the node accomplished its task
    - FAILED: the node's task failed or produced an error
    """

    # What the node produced: one AgentResult, a nested MultiAgentResult, or the Exception it failed with
    result: Union[AgentResult, "MultiAgentResult", Exception]

    # Bookkeeping about the run itself
    execution_time: int = 0
    status: Status = Status.PENDING

    # Totals gathered from this node together with all of its children
    accumulated_usage: Usage = field(default_factory=_zero_usage)
    accumulated_metrics: Metrics = field(default_factory=_zero_metrics)
    execution_count: int = 0

    def get_agent_results(self) -> list[AgentResult]:
        """Collect every AgentResult reachable from this node, descending into nested results."""
        outcome = self.result
        if isinstance(outcome, Exception):
            # A node that holds an exception has no agent output to report
            return []
        if isinstance(outcome, AgentResult):
            return [outcome]

        # Remaining case is a MultiAgentResult: gather the results of each child node in order
        return [
            agent_result
            for child in outcome.results.values()
            for agent_result in child.get_agent_results()
        ]


@dataclass
class MultiAgentResult:
    """Aggregate outcome of a multi-agent execution, including accumulated metrics.

    ``status`` expresses how the MultiAgentBase execution ended:
    - COMPLETED: the execution was successfully accomplished
    - FAILED: the execution failed or produced an error
    """

    status: Status = Status.PENDING
    # Per-node results keyed by a string identifier
    results: dict[str, NodeResult] = field(default_factory=dict)
    accumulated_usage: Usage = field(default_factory=_zero_usage)
    accumulated_metrics: Metrics = field(default_factory=_zero_metrics)
    execution_count: int = 0
    execution_time: int = 0


class MultiAgentBase(ABC):
    """Abstract foundation for multi-agent helpers.

    Concrete subclasses build on existing Strands Agent instances to provide
    multi-agent orchestration, exposing both an async and a sync entry point.
    """

    @abstractmethod
    async def invoke_async(self, task: str | list[ContentBlock], **kwargs: Any) -> MultiAgentResult:
        """Run the multi-agent system asynchronously for the given task."""
        raise NotImplementedError("invoke_async not implemented")

    @abstractmethod
    def __call__(self, task: str | list[ContentBlock], **kwargs: Any) -> MultiAgentResult:
        """Run the multi-agent system synchronously for the given task."""
        raise NotImplementedError("__call__ not implemented")
class SessionManager(HookProvider, ABC):
    """Abstract contract for persisting an agent's session.

    A session manager is responsible for storing the conversation and state of an agent throughout its
    interaction. Any change to the agent's conversation, state, or other attributes should be persisted right
    after it happens. The methods declared here are invoked at key points of the agent lifecycle and are where
    that persistence takes place.
    """

    def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None:
        """Wire the session persistence callbacks into the given hook registry."""

        def restore_agent(event: AgentInitializedEvent) -> Any:
            # Runs after the regular Agent initialization so the session can restore the agent
            return self.initialize(event.agent)

        def persist_message(event: MessageAddedEvent) -> Any:
            # Every message appended to the agent's messages is stored in the session
            return self.append_message(event.message, event.agent)

        def persist_agent(event: Any) -> Any:
            # Push the agent itself into the session in case its state changed
            return self.sync_agent(event.agent)

        registry.add_callback(AgentInitializedEvent, restore_agent)
        registry.add_callback(MessageAddedEvent, persist_message)
        # Sync on every added message, and once more after an invocation to capture
        # conversation manager state updates
        registry.add_callback(MessageAddedEvent, persist_agent)
        registry.add_callback(AfterInvocationEvent, persist_agent)

    @abstractmethod
    def redact_latest_message(self, redact_message: Message, agent: "Agent", **kwargs: Any) -> None:
        """Replace the most recently appended message of the agent in the session with a redacted one.

        Args:
            redact_message: Message carrying the redacted content to store instead
            agent: Agent whose latest message is being redacted
            **kwargs: Additional keyword arguments for future extensibility.
        """

    @abstractmethod
    def append_message(self, message: Message, agent: "Agent", **kwargs: Any) -> None:
        """Add a message to the agent's session.

        Args:
            message: Message to store for the agent in the session
            agent: Agent the message belongs to
            **kwargs: Additional keyword arguments for future extensibility.
        """

    @abstractmethod
    def sync_agent(self, agent: "Agent", **kwargs: Any) -> None:
        """Serialize the agent and synchronize it with the session storage.

        Args:
            agent: Agent to synchronize with the session storage
            **kwargs: Additional keyword arguments for future extensibility.
        """

    @abstractmethod
    def initialize(self, agent: "Agent", **kwargs: Any) -> None:
        """Set up an agent with a session.

        Args:
            agent: Agent to initialize
            **kwargs: Additional keyword arguments for future extensibility.
        """
read_message(self, session_id: str, agent_id: str, message_id: int, **kwargs: Any) -> Optional[SessionMessage]: 38 | """Read a Message.""" 39 | 40 | @abstractmethod 41 | def update_message(self, session_id: str, agent_id: str, session_message: SessionMessage, **kwargs: Any) -> None: 42 | """Update a Message. 43 | 44 | A message is usually only updated when some content is redacted due to a guardrail. 45 | """ 46 | 47 | @abstractmethod 48 | def list_messages( 49 | self, session_id: str, agent_id: str, limit: Optional[int] = None, offset: int = 0, **kwargs: Any 50 | ) -> list[SessionMessage]: 51 | """List Messages from an Agent with pagination.""" 52 | -------------------------------------------------------------------------------- /src/strands/telemetry/__init__.py: -------------------------------------------------------------------------------- 1 | """Telemetry module. 2 | 3 | This module provides metrics and tracing functionality. 4 | """ 5 | 6 | from .config import StrandsTelemetry 7 | from .metrics import EventLoopMetrics, MetricsClient, Trace, metrics_to_string 8 | from .tracer import Tracer, get_tracer 9 | 10 | __all__ = [ 11 | # Metrics 12 | "EventLoopMetrics", 13 | "Trace", 14 | "metrics_to_string", 15 | "MetricsClient", 16 | # Tracer 17 | "Tracer", 18 | "get_tracer", 19 | # Telemetry Setup 20 | "StrandsTelemetry", 21 | ] 22 | -------------------------------------------------------------------------------- /src/strands/telemetry/metrics_constants.py: -------------------------------------------------------------------------------- 1 | """Metrics that are emitted in Strands-Agents.""" 2 | 3 | STRANDS_EVENT_LOOP_CYCLE_COUNT = "strands.event_loop.cycle_count" 4 | STRANDS_EVENT_LOOP_START_CYCLE = "strands.event_loop.start_cycle" 5 | STRANDS_EVENT_LOOP_END_CYCLE = "strands.event_loop.end_cycle" 6 | STRANDS_TOOL_CALL_COUNT = "strands.tool.call_count" 7 | STRANDS_TOOL_SUCCESS_COUNT = "strands.tool.success_count" 8 | STRANDS_TOOL_ERROR_COUNT = 
async def run_tools(
    handler: RunToolHandler,
    tool_uses: list[ToolUse],
    event_loop_metrics: EventLoopMetrics,
    invalid_tool_use_ids: list[str],
    tool_results: list[ToolResult],
    cycle_trace: Trace,
    parent_span: Optional[trace.Span] = None,
) -> ToolGenerator:
    """Execute tools concurrently.

    Each valid tool use runs in its own asyncio task. Events produced by the tasks are funnelled through a
    shared queue and yielded one at a time; a task does not advance to its next event until the previous one
    has been yielded to the consumer.

    If the generator is closed before all tools have finished (or an error escapes the loop), any worker task
    that is still running is cancelled so that no task is left blocked waiting for the consumer.

    Args:
        handler: Tool handler processing function.
        tool_uses: List of tool uses to execute.
        event_loop_metrics: Metrics collection object.
        invalid_tool_use_ids: List of invalid tool use IDs.
        tool_results: List to populate with tool results.
        cycle_trace: Parent trace for the current cycle.
        parent_span: Parent span for the current cycle.

    Yields:
        Events of the tool stream. Tool results are appended to `tool_results`.
    """

    async def work(
        tool_use: ToolUse,
        worker_id: int,
        worker_queue: asyncio.Queue,
        worker_event: asyncio.Event,
        stop_event: object,
    ) -> ToolResult:
        tracer = get_tracer()
        tool_call_span = tracer.start_tool_call_span(tool_use, parent_span)

        tool_name = tool_use["name"]
        tool_trace = Trace(f"Tool: {tool_name}", parent_id=cycle_trace.id, raw_name=tool_name)
        tool_start_time = time.time()

        try:
            async for event in handler(tool_use):
                # Hand the event to the consumer and wait until it has been yielded downstream
                worker_queue.put_nowait((worker_id, event))
                await worker_event.wait()
                worker_event.clear()

            # The last event produced by the handler is taken as the tool result
            result = cast(ToolResult, event)
        finally:
            # Always announce completion, even on error or cancellation, so the consumer loop can finish
            worker_queue.put_nowait((worker_id, stop_event))

        tool_success = result.get("status") == "success"
        tool_duration = time.time() - tool_start_time
        message = Message(role="user", content=[{"toolResult": result}])
        event_loop_metrics.add_tool_usage(tool_use, tool_duration, tool_trace, tool_success, message)
        cycle_trace.add_child(tool_trace)

        if tool_call_span:
            tracer.end_tool_call_span(tool_call_span, result)

        return result

    # Tool uses already flagged as invalid are not executed
    tool_uses = [tool_use for tool_use in tool_uses if tool_use.get("toolUseId") not in invalid_tool_use_ids]
    worker_queue: asyncio.Queue[tuple[int, Any]] = asyncio.Queue()
    worker_events = [asyncio.Event() for _ in tool_uses]
    # Unique sentinel a worker enqueues once it has no more events
    stop_event = object()

    workers = [
        asyncio.create_task(work(tool_use, worker_id, worker_queue, worker_events[worker_id], stop_event))
        for worker_id, tool_use in enumerate(tool_uses)
    ]

    try:
        worker_count = len(workers)
        while worker_count:
            worker_id, event = await worker_queue.get()
            if event is stop_event:
                worker_count -= 1
                continue

            yield event
            # Let the worker that produced this event move on to its next one
            worker_events[worker_id].set()

        # Re-raises here if any worker failed
        tool_results.extend([worker.result() for worker in workers])
    finally:
        # On early close or error, workers may still be parked on their event; cancel them so they do not
        # linger as pending tasks. After a normal run every worker is already done and this is a no-op.
        for worker in workers:
            if not worker.done():
                worker.cancel()


def validate_and_prepare_tools(
    message: Message,
    tool_uses: list[ToolUse],
    tool_results: list[ToolResult],
    invalid_tool_use_ids: list[str],
) -> None:
    """Validate tool uses and prepare them for execution.

    Tool uses found in the message are appended to `tool_uses`. A tool use that fails name validation is
    renamed to "INVALID_TOOL_NAME", moved to the end of `tool_uses`, recorded in `invalid_tool_use_ids`, and
    given an error entry in `tool_results`.

    Args:
        message: Current message.
        tool_uses: List to populate with tool uses.
        tool_results: List to populate with tool results for invalid tools.
        invalid_tool_use_ids: List to populate with invalid tool use IDs.
    """
    # Extract tool uses from message
    for content in message["content"]:
        if isinstance(content, dict) and "toolUse" in content:
            tool_uses.append(content["toolUse"])

    # Validate tool uses
    # Iterate over a copy because `tool_uses` is modified inside the loop
    tool_uses_copy = tool_uses.copy()
    for tool in tool_uses_copy:
        try:
            validate_tool_use(tool)
        except InvalidToolUseNameException as e:
            # Replace the invalid toolUse name and return invalid name error as ToolResult to the LLM as context
            tool_uses.remove(tool)
            tool["name"] = "INVALID_TOOL_NAME"
            invalid_tool_use_ids.append(tool["toolUseId"])
            tool_uses.append(tool)
            tool_results.append(
                {
                    "toolUseId": tool["toolUseId"],
                    "status": "error",
                    "content": [{"text": f"Error: {str(e)}"}],
                }
            )
class MCPAgentTool(AgentTool):
    """AgentTool adapter around a tool exposed by an MCP server.

    It translates between the MCP representation of a tool and the tool interface of the
    agent framework so that MCP tools can be used like any other agent tool.
    """

    def __init__(self, mcp_tool: MCPTool, mcp_client: "MCPClient") -> None:
        """Create the adapter.

        Args:
            mcp_tool: The MCP tool to adapt
            mcp_client: The MCP server connection to use for tool invocation
        """
        super().__init__()
        logger.debug("tool_name=<%s> | creating mcp agent tool", mcp_tool.name)
        self.mcp_tool = mcp_tool
        self.mcp_client = mcp_client

    @property
    def tool_name(self) -> str:
        """Name of the wrapped MCP tool.

        Returns:
            str: The name of the MCP tool
        """
        return self.mcp_tool.name

    @property
    def tool_spec(self) -> ToolSpec:
        """Tool specification expressed in the agent framework's ToolSpec format.

        The input schema and name are taken from the MCP tool; when the MCP tool has no
        description, a generic one is generated from its name.

        Returns:
            ToolSpec: The tool specification in the agent framework format
        """
        wrapped = self.mcp_tool
        description: str = wrapped.description or f"Tool which performs {wrapped.name}"
        spec: ToolSpec = {
            "inputSchema": {"json": wrapped.inputSchema},
            "name": wrapped.name,
            "description": description,
        }
        return spec

    @property
    def tool_type(self) -> str:
        """Type of the tool.

        Returns:
            str: The type of the tool, always "python" for MCP tools
        """
        return "python"

    @override
    async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kwargs: Any) -> ToolGenerator:
        """Run the MCP tool and yield its result.

        The call is delegated to the MCP client, which receives the tool use ID, the tool name, and
        the input arguments.

        Args:
            tool_use: The tool use request containing tool ID and parameters.
            invocation_state: Context for the tool invocation, including agent state.
            **kwargs: Additional keyword arguments for future extensibility.

        Yields:
            Tool events with the last being the tool result.
        """
        logger.debug("tool_name=<%s>, tool_use_id=<%s> | streaming", self.tool_name, tool_use["toolUseId"])

        # A single event is produced: the result returned by the MCP client
        yield await self.mcp_client.call_tool_async(
            tool_use_id=tool_use["toolUseId"],
            name=self.tool_name,
            arguments=tool_use["input"],
        )
20 | 21 | Example implementation (simplified): 22 | ```python 23 | @contextlib.asynccontextmanager 24 | async def my_transport_implementation(): 25 | # Set up connection 26 | read_stream_writer, read_stream = anyio.create_memory_object_stream(0) 27 | write_stream, write_stream_reader = anyio.create_memory_object_stream(0) 28 | 29 | # Start background tasks to handle actual I/O 30 | async with anyio.create_task_group() as tg: 31 | tg.start_soon(reader_task, read_stream_writer) 32 | tg.start_soon(writer_task, write_stream_reader) 33 | 34 | # Yield the streams to the caller 35 | yield (read_stream, write_stream) 36 | ``` 37 | """ 38 | # GetSessionIdCallback was added for HTTP Streaming but was not applied to the MessageStream type 39 | # https://github.com/modelcontextprotocol/python-sdk/blob/ed25167fa5d715733437996682e20c24470e8177/src/mcp/client/streamable_http.py#L418 40 | _MessageStreamWithGetSessionIdCallback = tuple[ 41 | MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage], GetSessionIdCallback 42 | ] 43 | MCPTransport = AbstractAsyncContextManager[MessageStream | _MessageStreamWithGetSessionIdCallback] 44 | -------------------------------------------------------------------------------- /src/strands/tools/watcher.py: -------------------------------------------------------------------------------- 1 | """Tool watcher for hot reloading tools during development. 2 | 3 | This module provides functionality to watch tool directories for changes and automatically reload tools when they are 4 | modified. 
class ToolWatcher:
    """Watches tool directories for changes and reloads tools when they are modified."""

    # This class uses class variables for the observer and handlers because watchdog allows only one Observer instance
    # per directory. Using class variables ensures that all ToolWatcher instances share a single Observer, with the
    # MasterChangeHandler routing file system events to the appropriate individual handlers for each registry. This
    # design pattern avoids conflicts when multiple tool registries are watching the same directories.

    _shared_observer = None
    _watched_dirs: Set[str] = set()
    _observer_started = False
    # directory path -> {id(tool_registry) -> handler for that registry}
    _registry_handlers: Dict[str, Dict[int, "ToolWatcher.ToolChangeHandler"]] = {}

    def __init__(self, tool_registry: ToolRegistry) -> None:
        """Initialize a tool watcher for the given tool registry.

        Watching starts immediately as part of construction.

        Args:
            tool_registry: The tool registry to report changes.
        """
        self.tool_registry = tool_registry
        self.start()

    class ToolChangeHandler(FileSystemEventHandler):
        """Handler for tool file changes."""

        def __init__(self, tool_registry: ToolRegistry) -> None:
            """Initialize a tool change handler.

            Args:
                tool_registry: The tool registry to update when tools change.
            """
            self.tool_registry = tool_registry

        def on_modified(self, event: Any) -> None:
            """Reload tool if file modification detected.

            Only ``.py`` files other than ``__init__.py`` trigger a reload. Reload failures are logged and
            not propagated.

            Args:
                event: The file system event that triggered this handler.
            """
            if not event.src_path.endswith(".py"):
                return

            # The tool is identified by the file name without its extension
            tool_name = Path(event.src_path).stem
            if tool_name == "__init__":
                return

            logger.debug("tool_name=<%s> | tool change detected", tool_name)
            try:
                self.tool_registry.reload_tool(tool_name)
            except Exception as e:
                logger.error("tool_name=<%s>, exception=<%s> | failed to reload tool", tool_name, str(e))

    class MasterChangeHandler(FileSystemEventHandler):
        """Master handler that delegates to all registered handlers."""

        def __init__(self, dir_path: str) -> None:
            """Initialize a master change handler for a specific directory.

            Args:
                dir_path: The directory path to watch.
            """
            self.dir_path = dir_path

        def on_modified(self, event: Any) -> None:
            """Delegate file modification events to all registered handlers.

            Args:
                event: The file system event that triggered this handler.
            """
            if not event.src_path.endswith(".py"):
                return

            if Path(event.src_path).stem == "__init__":
                return

            # Iterate over a snapshot: ToolWatcher.start() may register another handler for this directory
            # while events are being dispatched, and iterating the live dict would then raise
            # "dictionary changed size during iteration" outside the per-handler try below.
            handlers = list(ToolWatcher._registry_handlers.get(self.dir_path, {}).values())

            # Delegate to all registered handlers for this directory
            for handler in handlers:
                try:
                    handler.on_modified(event)
                except Exception as e:
                    logger.error("exception=<%s> | handler error", str(e))

    def start(self) -> None:
        """Start watching all tools directories for changes."""
        # Initialize shared observer if not already done
        if ToolWatcher._shared_observer is None:
            ToolWatcher._shared_observer = Observer()

        # Create handler for this instance
        self.tool_change_handler = self.ToolChangeHandler(self.tool_registry)
        registry_id = id(self.tool_registry)

        # Get tools directories to watch
        tools_dirs = self.tool_registry.get_tools_dirs()

        for tools_dir in tools_dirs:
            dir_str = str(tools_dir)

            # Store this handler with its registry id, creating the per-directory dict if needed
            ToolWatcher._registry_handlers.setdefault(dir_str, {})[registry_id] = self.tool_change_handler

            # Schedule or update the master handler for this directory
            if dir_str not in ToolWatcher._watched_dirs:
                # First time seeing this directory, create a master handler
                master_handler = self.MasterChangeHandler(dir_str)
                ToolWatcher._shared_observer.schedule(master_handler, dir_str, recursive=False)
                ToolWatcher._watched_dirs.add(dir_str)
                logger.debug("tools_dir=<%s> | started watching tools directory", tools_dir)
            else:
                # Directory already being watched, just log it
                logger.debug("tools_dir=<%s> | directory already being watched", tools_dir)

        # Start the observer if not already started
        if not ToolWatcher._observer_started:
            ToolWatcher._shared_observer.start()
            ToolWatcher._observer_started = True
            logger.debug("tool directory watching initialized")
class GuardContentText(TypedDict):
    """Text submitted to a guardrail for evaluation.

    Attributes:
        qualifiers: Labels describing the role of the text block.
        text: The input text the guardrail should evaluate.
    """

    qualifiers: List[Literal["grounding_source", "query", "guard_content"]]
    text: str


class GuardContent(TypedDict):
    """Content block submitted to a guardrail for evaluation.

    Attributes:
        text: The text portion of the content block to evaluate.
    """

    text: GuardContentText


class ReasoningTextBlock(TypedDict, total=False):
    """Reasoning text the model produced on the way to its output.

    Attributes:
        signature: Token verifying that the reasoning text was generated by the model.
        text: The reasoning itself.
    """

    signature: Optional[str]
    text: str


class ReasoningContentBlock(TypedDict, total=False):
    """Content describing the reasoning carried out by the model.

    Attributes:
        reasoningText: The reasoning the model used to arrive at its output.
        redactedContent: Reasoning content that the model provider encrypted for safety reasons.
    """

    reasoningText: ReasoningTextBlock
    redactedContent: bytes


class CachePoint(TypedDict):
    """Cache point configuration used to optimize conversation history.

    Attributes:
        type: The kind of cache point, typically "default".
    """

    type: str


class ContentBlock(TypedDict, total=False):
    """One block of message content sent to, or received from, a model.

    Attributes:
        cachePoint: Cache point configuration to optimize conversation history.
        document: Document to include in the message.
        guardContent: Content to be assessed by the guardrail.
        image: Image to include in the message.
        reasoningContent: Content describing the reasoning carried out by the model.
        text: Text to include in the message.
        toolResult: Result of a tool request made by the model.
        toolUse: Details of a tool use request from the model.
        video: Video to include in the message.
    """

    cachePoint: CachePoint
    document: DocumentContent
    guardContent: GuardContent
    image: ImageContent
    reasoningContent: ReasoningContentBlock
    text: str
    toolResult: ToolResult
    toolUse: ToolUse
    video: VideoContent


class SystemContentBlock(TypedDict, total=False):
    """Instruction content telling the model how to handle input.

    Attributes:
        guardContent: Content block to be assessed by the guardrail.
        text: System prompt for the model.
    """

    guardContent: GuardContent
    text: str


class DeltaContent(TypedDict, total=False):
    """Incremental piece of content within a streaming response.

    Attributes:
        text: The content text.
        toolUse: Details of a tool the model is requesting to use.
    """

    text: str
    toolUse: Dict[Literal["input"], str]


class ContentBlockStartToolUse(TypedDict):
    """Opening of a tool use block.

    Attributes:
        name: Name of the tool the model is requesting to use.
        toolUseId: Identifier of the tool request.
    """

    name: str
    toolUseId: str


class ContentBlockStart(TypedDict, total=False):
    """Information carried at the start of a content block.

    Attributes:
        toolUse: Details of a tool the model is requesting to use.
    """

    toolUse: Optional[ContentBlockStartToolUse]


class ContentBlockDelta(TypedDict):
    """Delta event for a content block.

    Attributes:
        contentBlockIndex: Index of the content block this delta belongs to.
        delta: The delta payload of the event.
    """

    contentBlockIndex: int
    delta: DeltaContent
159 | 160 | Attributes: 161 | contentBlockIndex: The index for a content block. 162 | """ 163 | 164 | contentBlockIndex: int 165 | 166 | 167 | Role = Literal["user", "assistant"] 168 | """Role of a message sender. 169 | 170 | - "user": Messages from the user to the assistant 171 | - "assistant": Messages from the assistant to the user 172 | """ 173 | 174 | 175 | class Message(TypedDict): 176 | """A message in a conversation with the agent. 177 | 178 | Attributes: 179 | content: The message content. 180 | role: The role of the message sender. 181 | """ 182 | 183 | content: List[ContentBlock] 184 | role: Role 185 | 186 | 187 | Messages = List[Message] 188 | """A list of messages representing a conversation.""" 189 | -------------------------------------------------------------------------------- /src/strands/types/event_loop.py: -------------------------------------------------------------------------------- 1 | """Event loop-related type definitions for the SDK.""" 2 | 3 | from typing import Literal 4 | 5 | from typing_extensions import TypedDict 6 | 7 | 8 | class Usage(TypedDict): 9 | """Token usage information for model interactions. 10 | 11 | Attributes: 12 | inputTokens: Number of tokens sent in the request to the model. 13 | outputTokens: Number of tokens that the model generated for the request. 14 | totalTokens: Total number of tokens (input + output). 15 | """ 16 | 17 | inputTokens: int 18 | outputTokens: int 19 | totalTokens: int 20 | 21 | 22 | class Metrics(TypedDict): 23 | """Performance metrics for model interactions. 24 | 25 | Attributes: 26 | latencyMs (int): Latency of the model request in milliseconds. 27 | """ 28 | 29 | latencyMs: int 30 | 31 | 32 | StopReason = Literal[ 33 | "content_filtered", 34 | "end_turn", 35 | "guardrail_intervened", 36 | "max_tokens", 37 | "stop_sequence", 38 | "tool_use", 39 | ] 40 | """Reason for the model ending its response generation.
41 | 42 | - "content_filtered": Content was filtered due to policy violation 43 | - "end_turn": Normal completion of the response 44 | - "guardrail_intervened": Guardrail system intervened 45 | - "max_tokens": Maximum token limit reached 46 | - "stop_sequence": Stop sequence encountered 47 | - "tool_use": Model requested to use a tool 48 | """ 49 | -------------------------------------------------------------------------------- /src/strands/types/exceptions.py: -------------------------------------------------------------------------------- 1 | """Exception-related type definitions for the SDK.""" 2 | 3 | from typing import Any 4 | 5 | 6 | class EventLoopException(Exception): 7 | """Exception raised by the event loop.""" 8 | 9 | def __init__(self, original_exception: Exception, request_state: Any = None) -> None: 10 | """Initialize exception. 11 | 12 | Args: 13 | original_exception: The original exception that was raised. 14 | request_state: The state of the request at the time of the exception. 15 | """ 16 | self.original_exception = original_exception 17 | self.request_state = request_state if request_state is not None else {} 18 | super().__init__(str(original_exception)) 19 | 20 | 21 | class ContextWindowOverflowException(Exception): 22 | """Exception raised when the context window is exceeded. 23 | 24 | This exception is raised when the input to a model exceeds the maximum context window size that the model can 25 | handle. This typically occurs when the combined length of the conversation history, system prompt, and current 26 | message is too large for the model to process. 27 | """ 28 | 29 | pass 30 | 31 | 32 | class MCPClientInitializationError(Exception): 33 | """Raised when the MCP server fails to initialize properly.""" 34 | 35 | pass 36 | 37 | 38 | class ModelThrottledException(Exception): 39 | """Exception raised when the model is throttled. 40 | 41 | This exception is raised when the model is throttled by the service. 
This typically occurs when the service is 42 | throttling the requests from the client. 43 | """ 44 | 45 | def __init__(self, message: str) -> None: 46 | """Initialize exception. 47 | 48 | Args: 49 | message: The message from the service that describes the throttling. 50 | """ 51 | self.message = message 52 | super().__init__(message) 53 | 54 | pass 55 | 56 | 57 | class SessionException(Exception): 58 | """Exception raised when session operations fail.""" 59 | 60 | pass 61 | -------------------------------------------------------------------------------- /src/strands/types/media.py: -------------------------------------------------------------------------------- 1 | """Media-related type definitions for the SDK. 2 | 3 | These types are modeled after the Bedrock API. 4 | 5 | - Bedrock docs: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_Types_Amazon_Bedrock_Runtime.html 6 | """ 7 | 8 | from typing import Literal 9 | 10 | from typing_extensions import TypedDict 11 | 12 | DocumentFormat = Literal["pdf", "csv", "doc", "docx", "xls", "xlsx", "html", "txt", "md"] 13 | """Supported document formats.""" 14 | 15 | 16 | class DocumentSource(TypedDict): 17 | """Contains the content of a document. 18 | 19 | Attributes: 20 | bytes: The binary content of the document. 21 | """ 22 | 23 | bytes: bytes 24 | 25 | 26 | class DocumentContent(TypedDict): 27 | """A document to include in a message. 28 | 29 | Attributes: 30 | format: The format of the document (e.g., "pdf", "txt"). 31 | name: The name of the document. 32 | source: The source containing the document's binary content. 33 | """ 34 | 35 | format: Literal["pdf", "csv", "doc", "docx", "xls", "xlsx", "html", "txt", "md"] 36 | name: str 37 | source: DocumentSource 38 | 39 | 40 | ImageFormat = Literal["png", "jpeg", "gif", "webp"] 41 | """Supported image formats.""" 42 | 43 | 44 | class ImageSource(TypedDict): 45 | """Contains the content of an image. 46 | 47 | Attributes: 48 | bytes: The binary content of the image. 
49 | """ 50 | 51 | bytes: bytes 52 | 53 | 54 | class ImageContent(TypedDict): 55 | """An image to include in a message. 56 | 57 | Attributes: 58 | format: The format of the image (e.g., "png", "jpeg"). 59 | source: The source containing the image's binary content. 60 | """ 61 | 62 | format: ImageFormat 63 | source: ImageSource 64 | 65 | 66 | VideoFormat = Literal["flv", "mkv", "mov", "mpeg", "mpg", "mp4", "three_gp", "webm", "wmv"] 67 | """Supported video formats.""" 68 | 69 | 70 | class VideoSource(TypedDict): 71 | """Contains the content of a video. 72 | 73 | Attributes: 74 | bytes: The binary content of the video. 75 | """ 76 | 77 | bytes: bytes 78 | 79 | 80 | class VideoContent(TypedDict): 81 | """A video to include in a message. 82 | 83 | Attributes: 84 | format: The format of the video (e.g., "mp4", "webm"). 85 | source: The source containing the video's binary content. 86 | """ 87 | 88 | format: VideoFormat 89 | source: VideoSource 90 | -------------------------------------------------------------------------------- /src/strands/types/session.py: -------------------------------------------------------------------------------- 1 | """Data models for session management.""" 2 | 3 | import base64 4 | import inspect 5 | from dataclasses import asdict, dataclass, field 6 | from datetime import datetime, timezone 7 | from enum import Enum 8 | from typing import TYPE_CHECKING, Any, Dict, Optional 9 | 10 | from .content import Message 11 | 12 | if TYPE_CHECKING: 13 | from ..agent.agent import Agent 14 | 15 | 16 | class SessionType(str, Enum): 17 | """Enumeration of session types. 18 | 19 | As sessions are expanded to support new use cases like multi-agent patterns, 20 | new types will be added here. 21 | """ 22 | 23 | AGENT = "AGENT" 24 | 25 | 26 | def encode_bytes_values(obj: Any) -> Any: 27 | """Recursively encode any bytes values in an object to base64. 28 | 29 | Handles dictionaries, lists, and nested structures.
30 | """ 31 | if isinstance(obj, bytes): 32 | return {"__bytes_encoded__": True, "data": base64.b64encode(obj).decode()} 33 | elif isinstance(obj, dict): 34 | return {k: encode_bytes_values(v) for k, v in obj.items()} 35 | elif isinstance(obj, list): 36 | return [encode_bytes_values(item) for item in obj] 37 | else: 38 | return obj 39 | 40 | 41 | def decode_bytes_values(obj: Any) -> Any: 42 | """Recursively decode any base64-encoded bytes values in an object. 43 | 44 | Handles dictionaries, lists, and nested structures. 45 | """ 46 | if isinstance(obj, dict): 47 | if obj.get("__bytes_encoded__") is True and "data" in obj: 48 | return base64.b64decode(obj["data"]) 49 | return {k: decode_bytes_values(v) for k, v in obj.items()} 50 | elif isinstance(obj, list): 51 | return [decode_bytes_values(item) for item in obj] 52 | else: 53 | return obj 54 | 55 | 56 | @dataclass 57 | class SessionMessage: 58 | """Message within a SessionAgent. 59 | 60 | Attributes: 61 | message: Message content 62 | message_id: Index of the message in the conversation history 63 | redact_message: If the original message is redacted, this is the new content to use 64 | created_at: ISO format timestamp for when this message was created 65 | updated_at: ISO format timestamp for when this message was last updated 66 | """ 67 | 68 | message: Message 69 | message_id: int 70 | redact_message: Optional[Message] = None 71 | created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) 72 | updated_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) 73 | 74 | @classmethod 75 | def from_message(cls, message: Message, index: int) -> "SessionMessage": 76 | """Create a SessionMessage from a Message, using index as its message_id.""" 77 | return cls( 78 | message=message, 79 | message_id=index, 80 | created_at=datetime.now(timezone.utc).isoformat(), 81 | updated_at=datetime.now(timezone.utc).isoformat(), 82 | ) 83 | 84 | def to_message(self) -> Message: 85 | """Convert
SessionMessage back to a Message, decoding any bytes values. 86 | 87 | If the message was redacted, return the redact content instead. 88 | """ 89 | if self.redact_message is not None: 90 | return self.redact_message 91 | else: 92 | return self.message 93 | 94 | @classmethod 95 | def from_dict(cls, env: dict[str, Any]) -> "SessionMessage": 96 | """Initialize a SessionMessage from a dictionary, ignoring keys that are not class parameters.""" 97 | extracted_relevant_parameters = {k: v for k, v in env.items() if k in inspect.signature(cls).parameters} 98 | return cls(**decode_bytes_values(extracted_relevant_parameters)) 99 | 100 | def to_dict(self) -> dict[str, Any]: 101 | """Convert the SessionMessage to a dictionary representation.""" 102 | return encode_bytes_values(asdict(self)) # type: ignore 103 | 104 | 105 | @dataclass 106 | class SessionAgent: 107 | """Agent that belongs to a Session.""" 108 | 109 | agent_id: str 110 | state: Dict[str, Any] 111 | conversation_manager_state: Dict[str, Any] 112 | created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) 113 | updated_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) 114 | 115 | @classmethod 116 | def from_agent(cls, agent: "Agent") -> "SessionAgent": 117 | """Convert an Agent to a SessionAgent.""" 118 | if agent.agent_id is None: 119 | raise ValueError("agent_id needs to be defined.") 120 | return cls( 121 | agent_id=agent.agent_id, 122 | conversation_manager_state=agent.conversation_manager.get_state(), 123 | state=agent.state.get(), 124 | ) 125 | 126 | @classmethod 127 | def from_dict(cls, env: dict[str, Any]) -> "SessionAgent": 128 | """Initialize a SessionAgent from a dictionary, ignoring keys that are not class parameters.""" 129 | return cls(**{k: v for k, v in env.items() if k in inspect.signature(cls).parameters}) 130 | 131 | def to_dict(self) -> dict[str, Any]: 132 | """Convert the SessionAgent to a dictionary representation.""" 133 | return
asdict(self) 134 | 135 | 136 | @dataclass 137 | class Session: 138 | """Session data model.""" 139 | 140 | session_id: str 141 | session_type: SessionType 142 | created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) 143 | updated_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) 144 | 145 | @classmethod 146 | def from_dict(cls, env: dict[str, Any]) -> "Session": 147 | """Initialize a Session from a dictionary, ignoring keys that are not class parameters.""" 148 | return cls(**{k: v for k, v in env.items() if k in inspect.signature(cls).parameters}) 149 | 150 | def to_dict(self) -> dict[str, Any]: 151 | """Convert the Session to a dictionary representation.""" 152 | return asdict(self) 153 | -------------------------------------------------------------------------------- /src/strands/types/traces.py: -------------------------------------------------------------------------------- 1 | """Tracing type definitions for the SDK.""" 2 | 3 | from typing import List, Union 4 | 5 | AttributeValue = Union[str, bool, float, int, List[str], List[bool], List[float], List[int]] 6 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/strands-agents/sdk-python/f5e24d402bb22dc1a54c94dd030b4be0f7e73261/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import configparser 2 | import logging 3 | import os 4 | import sys 5 | 6 | import boto3 7 | import moto 8 | import pytest 9 | 10 | ## Moto 11 | 12 | # Get the log level from the environment variable 13 | log_level = os.environ.get("LOG_LEVEL", "INFO").upper() 14 | 15 | logging.getLogger("strands").setLevel(log_level) 16 | logging.basicConfig( 17 |
format="%(levelname)s | %(name)s | %(message)s", handlers=[logging.StreamHandler(stream=sys.stdout)] 18 | ) 19 | 20 | 21 | @pytest.fixture 22 | def moto_env(monkeypatch): 23 | monkeypatch.setenv("AWS_ACCESS_KEY_ID", "test") 24 | monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "test") 25 | monkeypatch.setenv("AWS_SECURITY_TOKEN", "test") 26 | monkeypatch.setenv("AWS_DEFAULT_REGION", "us-west-2") 27 | monkeypatch.delenv("OTEL_EXPORTER_OTLP_ENDPOINT", raising=False) 28 | monkeypatch.delenv("OTEL_EXPORTER_OTLP_HEADERS", raising=False) 29 | 30 | 31 | @pytest.fixture 32 | def moto_mock_aws(): 33 | with moto.mock_aws(): 34 | yield 35 | 36 | 37 | @pytest.fixture 38 | def moto_cloudwatch_client(): 39 | return boto3.client("cloudwatch") 40 | 41 | 42 | ## Boto3 43 | 44 | 45 | @pytest.fixture 46 | def boto3_profile_name(): 47 | return "test-profile" 48 | 49 | 50 | @pytest.fixture 51 | def boto3_profile(boto3_profile_name): 52 | config = configparser.ConfigParser() 53 | config[boto3_profile_name] = { 54 | "aws_access_key_id": "test", 55 | "aws_secret_access_key": "test", 56 | } 57 | 58 | return config 59 | 60 | 61 | @pytest.fixture 62 | def boto3_profile_path(boto3_profile, tmp_path, monkeypatch): 63 | path = tmp_path / ".aws/credentials" 64 | path.parent.mkdir(exist_ok=True) 65 | with path.open("w") as fp: 66 | boto3_profile.write(fp) 67 | 68 | monkeypatch.setenv("AWS_SHARED_CREDENTIALS_FILE", str(path)) 69 | 70 | return path 71 | 72 | 73 | ## Async 74 | 75 | 76 | @pytest.fixture(scope="session") 77 | def agenerator(): 78 | async def agenerator(items): 79 | for item in items: 80 | yield item 81 | 82 | return agenerator 83 | 84 | 85 | @pytest.fixture(scope="session") 86 | def alist(): 87 | async def alist(items): 88 | return [item async for item in items] 89 | 90 | return alist 91 | 92 | 93 | ## Itertools 94 | 95 | 96 | @pytest.fixture(scope="session") 97 | def generate(): 98 | def generate(generator): 99 | events = [] 100 | 101 | try: 102 | while True: 103 | event = 
next(generator) 104 | events.append(event) 105 | 106 | except StopIteration as stop: 107 | return events, stop.value 108 | 109 | return generate 110 | -------------------------------------------------------------------------------- /tests/fixtures/mock_hook_provider.py: -------------------------------------------------------------------------------- 1 | from typing import Iterator, Tuple, Type 2 | 3 | from strands.hooks import HookEvent, HookProvider, HookRegistry 4 | 5 | 6 | class MockHookProvider(HookProvider): 7 | def __init__(self, event_types: list[Type]): 8 | self.events_received = [] 9 | self.events_types = event_types 10 | 11 | def get_events(self) -> Tuple[int, Iterator[HookEvent]]: 12 | return len(self.events_received), iter(self.events_received) 13 | 14 | def register_hooks(self, registry: HookRegistry) -> None: 15 | for event_type in self.events_types: 16 | registry.add_callback(event_type, self.add_event) 17 | 18 | def add_event(self, event: HookEvent) -> None: 19 | self.events_received.append(event) 20 | -------------------------------------------------------------------------------- /tests/fixtures/mock_session_repository.py: -------------------------------------------------------------------------------- 1 | from strands.session.session_repository import SessionRepository 2 | from strands.types.exceptions import SessionException 3 | from strands.types.session import SessionAgent, SessionMessage 4 | 5 | 6 | class MockedSessionRepository(SessionRepository): 7 | """Mock repository for testing.""" 8 | 9 | def __init__(self): 10 | """Initialize with empty storage.""" 11 | self.sessions = {} 12 | self.agents = {} 13 | self.messages = {} 14 | 15 | def create_session(self, session) -> None: 16 | """Create a session.""" 17 | session_id = session.session_id 18 | if session_id in self.sessions: 19 | raise SessionException(f"Session {session_id} already exists") 20 | self.sessions[session_id] = session 21 | self.agents[session_id] = {} 22 | 
self.messages[session_id] = {} 23 | 24 | def read_session(self, session_id) -> SessionAgent: 25 | """Read a session.""" 26 | return self.sessions.get(session_id) 27 | 28 | def create_agent(self, session_id, session_agent) -> None: 29 | """Create an agent.""" 30 | agent_id = session_agent.agent_id 31 | if session_id not in self.sessions: 32 | raise SessionException(f"Session {session_id} does not exist") 33 | if agent_id in self.agents.get(session_id, {}): 34 | raise SessionException(f"Agent {agent_id} already exists in session {session_id}") 35 | self.agents.setdefault(session_id, {})[agent_id] = session_agent 36 | self.messages.setdefault(session_id, {}).setdefault(agent_id, {}) 37 | return session_agent 38 | 39 | def read_agent(self, session_id, agent_id) -> SessionAgent: 40 | """Read an agent.""" 41 | if session_id not in self.sessions: 42 | return None 43 | return self.agents.get(session_id, {}).get(agent_id) 44 | 45 | def update_agent(self, session_id, session_agent) -> None: 46 | """Update an agent.""" 47 | agent_id = session_agent.agent_id 48 | if session_id not in self.sessions: 49 | raise SessionException(f"Session {session_id} does not exist") 50 | if agent_id not in self.agents.get(session_id, {}): 51 | raise SessionException(f"Agent {agent_id} does not exist in session {session_id}") 52 | self.agents[session_id][agent_id] = session_agent 53 | 54 | def create_message(self, session_id, agent_id, session_message) -> None: 55 | """Create a message.""" 56 | message_id = session_message.message_id 57 | if session_id not in self.sessions: 58 | raise SessionException(f"Session {session_id} does not exist") 59 | if agent_id not in self.agents.get(session_id, {}): 60 | raise SessionException(f"Agent {agent_id} does not exists in session {session_id}") 61 | if message_id in self.messages.get(session_id, {}).get(agent_id, {}): 62 | raise SessionException(f"Message {message_id} already exists in agent {agent_id} in session {session_id}") 63 | 
self.messages.setdefault(session_id, {}).setdefault(agent_id, {})[message_id] = session_message 64 | 65 | def read_message(self, session_id, agent_id, message_id) -> SessionMessage: 66 | """Read a message.""" 67 | if session_id not in self.sessions: 68 | return None 69 | if agent_id not in self.agents.get(session_id, {}): 70 | return None 71 | return self.messages.get(session_id, {}).get(agent_id, {}).get(message_id) 72 | 73 | def update_message(self, session_id, agent_id, session_message) -> None: 74 | """Update a message.""" 75 | 76 | message_id = session_message.message_id 77 | if session_id not in self.sessions: 78 | raise SessionException(f"Session {session_id} does not exist") 79 | if agent_id not in self.agents.get(session_id, {}): 80 | raise SessionException(f"Agent {agent_id} does not exist in session {session_id}") 81 | if message_id not in self.messages.get(session_id, {}).get(agent_id, {}): 82 | raise SessionException(f"Message {message_id} does not exist in session {session_id}") 83 | self.messages[session_id][agent_id][message_id] = session_message 84 | 85 | def list_messages(self, session_id, agent_id, limit=None, offset=0) -> list[SessionMessage]: 86 | """List messages.""" 87 | if session_id not in self.sessions: 88 | return [] 89 | if agent_id not in self.agents.get(session_id, {}): 90 | return [] 91 | 92 | messages = self.messages.get(session_id, {}).get(agent_id, {}) 93 | sorted_messages = [messages[key] for key in sorted(messages.keys())] 94 | 95 | if limit is not None: 96 | return sorted_messages[offset : offset + limit] 97 | return sorted_messages[offset:] 98 | -------------------------------------------------------------------------------- /tests/fixtures/mocked_model_provider.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Any, AsyncGenerator, Iterable, Optional, Type, TypedDict, TypeVar, Union 3 | 4 | from pydantic import BaseModel 5 | 6 | from strands.models import 
Model 7 | from strands.types.content import Message, Messages 8 | from strands.types.event_loop import StopReason 9 | from strands.types.streaming import StreamEvent 10 | from strands.types.tools import ToolSpec 11 | 12 | T = TypeVar("T", bound=BaseModel) 13 | 14 | 15 | class RedactionMessage(TypedDict): 16 | redactedUserContent: str 17 | redactedAssistantContent: str 18 | 19 | 20 | class MockedModelProvider(Model): 21 | """A mock implementation of the Model interface for testing purposes. 22 | 23 | This class simulates a model provider by returning pre-defined agent responses 24 | in sequence. It implements the Model interface methods and provides functionality 25 | to stream mock responses as events. 26 | """ 27 | 28 | def __init__(self, agent_responses: list[Union[Message, RedactionMessage]]): 29 | self.agent_responses = agent_responses 30 | self.index = 0 31 | 32 | def format_chunk(self, event: Any) -> StreamEvent: 33 | return event 34 | 35 | def format_request( 36 | self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None 37 | ) -> Any: 38 | return None 39 | 40 | def get_config(self) -> Any: 41 | pass 42 | 43 | def update_config(self, **model_config: Any) -> None: 44 | pass 45 | 46 | async def structured_output( 47 | self, 48 | output_model: Type[T], 49 | prompt: Messages, 50 | system_prompt: Optional[str] = None, 51 | **kwargs: Any, 52 | ) -> AsyncGenerator[Any, None]: 53 | pass 54 | 55 | async def stream( 56 | self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None 57 | ) -> AsyncGenerator[Any, None]: 58 | events = self.map_agent_message_to_events(self.agent_responses[self.index]) 59 | for event in events: 60 | yield event 61 | 62 | self.index += 1 63 | 64 | def map_agent_message_to_events(self, agent_message: Union[Message, RedactionMessage]) -> Iterable[dict[str, Any]]: 65 | stop_reason: StopReason = "end_turn" 66 | yield {"messageStart": {"role": 
"assistant"}} 67 | if agent_message.get("redactedAssistantContent"): 68 | yield {"redactContent": {"redactUserContentMessage": agent_message["redactedUserContent"]}} 69 | yield {"contentBlockStart": {"start": {}}} 70 | yield {"contentBlockDelta": {"delta": {"text": agent_message["redactedAssistantContent"]}}} 71 | yield {"contentBlockStop": {}} 72 | stop_reason = "guardrail_intervened" 73 | else: 74 | for content in agent_message["content"]: 75 | if "text" in content: 76 | yield {"contentBlockStart": {"start": {}}} 77 | yield {"contentBlockDelta": {"delta": {"text": content["text"]}}} 78 | yield {"contentBlockStop": {}} 79 | if "toolUse" in content: 80 | stop_reason = "tool_use" 81 | yield { 82 | "contentBlockStart": { 83 | "start": { 84 | "toolUse": { 85 | "name": content["toolUse"]["name"], 86 | "toolUseId": content["toolUse"]["toolUseId"], 87 | } 88 | } 89 | } 90 | } 91 | yield { 92 | "contentBlockDelta": {"delta": {"toolUse": {"input": json.dumps(content["toolUse"]["input"])}}} 93 | } 94 | yield {"contentBlockStop": {}} 95 | 96 | yield {"messageStop": {"stopReason": stop_reason}} 97 | -------------------------------------------------------------------------------- /tests/strands/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/strands-agents/sdk-python/f5e24d402bb22dc1a54c94dd030b4be0f7e73261/tests/strands/__init__.py -------------------------------------------------------------------------------- /tests/strands/agent/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/strands-agents/sdk-python/f5e24d402bb22dc1a54c94dd030b4be0f7e73261/tests/strands/agent/__init__.py -------------------------------------------------------------------------------- /tests/strands/agent/test_agent_result.py: -------------------------------------------------------------------------------- 1 | import unittest.mock 2 | from typing 
import cast 3 | 4 | import pytest 5 | 6 | from strands.agent.agent_result import AgentResult 7 | from strands.telemetry.metrics import EventLoopMetrics 8 | from strands.types.content import Message 9 | from strands.types.streaming import StopReason 10 | 11 | 12 | @pytest.fixture 13 | def mock_metrics(): 14 | return unittest.mock.Mock(spec=EventLoopMetrics) 15 | 16 | 17 | @pytest.fixture 18 | def simple_message(): 19 | return {"role": "assistant", "content": [{"text": "Hello world!"}]} 20 | 21 | 22 | @pytest.fixture 23 | def complex_message(): 24 | return { 25 | "role": "assistant", 26 | "content": [ 27 | {"text": "First paragraph"}, 28 | {"text": "Second paragraph"}, 29 | {"non_text_content": "This should be ignored"}, 30 | {"text": "Third paragraph"}, 31 | ], 32 | } 33 | 34 | 35 | @pytest.fixture 36 | def empty_message(): 37 | return {"role": "assistant", "content": []} 38 | 39 | 40 | def test__init__(mock_metrics, simple_message: Message): 41 | """Test that AgentResult can be properly initialized with all required fields.""" 42 | stop_reason: StopReason = "end_turn" 43 | state = {"key": "value"} 44 | 45 | result = AgentResult(stop_reason=stop_reason, message=simple_message, metrics=mock_metrics, state=state) 46 | 47 | assert result.stop_reason == stop_reason 48 | assert result.message == simple_message 49 | assert result.metrics == mock_metrics 50 | assert result.state == state 51 | 52 | 53 | def test__str__simple(mock_metrics, simple_message: Message): 54 | """Test that str() works with a simple message.""" 55 | result = AgentResult(stop_reason="end_turn", message=simple_message, metrics=mock_metrics, state={}) 56 | 57 | message_string = str(result) 58 | assert message_string == "Hello world!\n" 59 | 60 | 61 | def test__str__complex(mock_metrics, complex_message: Message): 62 | """Test that str() works with a complex message with multiple text blocks.""" 63 | result = AgentResult(stop_reason="end_turn", message=complex_message, metrics=mock_metrics, state={}) 64 
| 65 | message_string = str(result) 66 | assert message_string == "First paragraph\nSecond paragraph\nThird paragraph\n" 67 | 68 | 69 | def test__str__empty(mock_metrics, empty_message: Message): 70 | """Test that str() works with an empty message.""" 71 | result = AgentResult(stop_reason="end_turn", message=empty_message, metrics=mock_metrics, state={}) 72 | 73 | message_string = str(result) 74 | assert message_string == "" 75 | 76 | 77 | def test__str__no_content(mock_metrics): 78 | """Test that str() works with a message that has no content field.""" 79 | message_without_content = cast(Message, {"role": "assistant"}) 80 | 81 | result = AgentResult(stop_reason="end_turn", message=message_without_content, metrics=mock_metrics, state={}) 82 | 83 | message_string = str(result) 84 | assert message_string == "" 85 | 86 | 87 | def test__str__non_dict_content(mock_metrics): 88 | """Test that str() handles non-dictionary content items gracefully.""" 89 | message_with_non_dict = cast( 90 | Message, 91 | {"role": "assistant", "content": [{"text": "Valid text"}, "Not a dictionary", {"text": "More valid text"}]}, 92 | ) 93 | 94 | result = AgentResult(stop_reason="end_turn", message=message_with_non_dict, metrics=mock_metrics, state={}) 95 | 96 | message_string = str(result) 97 | assert message_string == "Valid text\nMore valid text\n" 98 | -------------------------------------------------------------------------------- /tests/strands/agent/test_agent_state.py: -------------------------------------------------------------------------------- 1 | """Tests for AgentState class.""" 2 | 3 | import pytest 4 | 5 | from strands import Agent, tool 6 | from strands.agent.state import AgentState 7 | from strands.types.content import Messages 8 | 9 | from ...fixtures.mocked_model_provider import MockedModelProvider 10 | 11 | 12 | def test_set_and_get(): 13 | """Test basic set and get operations.""" 14 | state = AgentState() 15 | state.set("key", "value") 16 | assert state.get("key") == 
"value" 17 | 18 | 19 | def test_get_nonexistent_key(): 20 | """Test getting nonexistent key returns None.""" 21 | state = AgentState() 22 | assert state.get("nonexistent") is None 23 | 24 | 25 | def test_get_entire_state(): 26 | """Test getting entire state when no key specified.""" 27 | state = AgentState() 28 | state.set("key1", "value1") 29 | state.set("key2", "value2") 30 | 31 | result = state.get() 32 | assert result == {"key1": "value1", "key2": "value2"} 33 | 34 | 35 | def test_initialize_and_get_entire_state(): 36 | """Test getting entire state when no key specified.""" 37 | state = AgentState({"key1": "value1", "key2": "value2"}) 38 | 39 | result = state.get() 40 | assert result == {"key1": "value1", "key2": "value2"} 41 | 42 | 43 | def test_initialize_with_error(): 44 | with pytest.raises(ValueError, match="not JSON serializable"): 45 | AgentState({"object", object()}) 46 | 47 | 48 | def test_delete(): 49 | """Test deleting keys.""" 50 | state = AgentState() 51 | state.set("key1", "value1") 52 | state.set("key2", "value2") 53 | 54 | state.delete("key1") 55 | 56 | assert state.get("key1") is None 57 | assert state.get("key2") == "value2" 58 | 59 | 60 | def test_delete_nonexistent_key(): 61 | """Test deleting nonexistent key doesn't raise error.""" 62 | state = AgentState() 63 | state.delete("nonexistent") # Should not raise 64 | 65 | 66 | def test_json_serializable_values(): 67 | """Test that only JSON-serializable values are accepted.""" 68 | state = AgentState() 69 | 70 | # Valid JSON types 71 | state.set("string", "test") 72 | state.set("int", 42) 73 | state.set("bool", True) 74 | state.set("list", [1, 2, 3]) 75 | state.set("dict", {"nested": "value"}) 76 | state.set("null", None) 77 | 78 | # Invalid JSON types should raise ValueError 79 | with pytest.raises(ValueError, match="not JSON serializable"): 80 | state.set("function", lambda x: x) 81 | 82 | with pytest.raises(ValueError, match="not JSON serializable"): 83 | state.set("object", object()) 84 | 
85 | 86 | def test_key_validation(): 87 | """Test key validation for set and delete operations.""" 88 | state = AgentState() 89 | 90 | # Invalid keys for set 91 | with pytest.raises(ValueError, match="Key cannot be None"): 92 | state.set(None, "value") 93 | 94 | with pytest.raises(ValueError, match="Key cannot be empty"): 95 | state.set("", "value") 96 | 97 | with pytest.raises(ValueError, match="Key must be a string"): 98 | state.set(123, "value") 99 | 100 | # Invalid keys for delete 101 | with pytest.raises(ValueError, match="Key cannot be None"): 102 | state.delete(None) 103 | 104 | with pytest.raises(ValueError, match="Key cannot be empty"): 105 | state.delete("") 106 | 107 | 108 | def test_initial_state(): 109 | """Test initialization with initial state.""" 110 | initial = {"key1": "value1", "key2": "value2"} 111 | state = AgentState(initial_state=initial) 112 | 113 | assert state.get("key1") == "value1" 114 | assert state.get("key2") == "value2" 115 | assert state.get() == initial 116 | 117 | 118 | def test_agent_state_update_from_tool(): 119 | @tool 120 | def update_state(agent: Agent): 121 | agent.state.set("hello", "world") 122 | agent.state.set("foo", "baz") 123 | 124 | agent_messages: Messages = [ 125 | { 126 | "role": "assistant", 127 | "content": [{"toolUse": {"name": "update_state", "toolUseId": "123", "input": {}}}], 128 | }, 129 | {"role": "assistant", "content": [{"text": "I invoked a tool!"}]}, 130 | ] 131 | mocked_model_provider = MockedModelProvider(agent_messages) 132 | 133 | agent = Agent( 134 | model=mocked_model_provider, 135 | tools=[update_state], 136 | state={"foo": "bar"}, 137 | ) 138 | 139 | assert agent.state.get("hello") is None 140 | assert agent.state.get("foo") == "bar" 141 | 142 | agent("Invoke Mocked!") 143 | 144 | assert agent.state.get("hello") == "world" 145 | assert agent.state.get("foo") == "baz" 146 | -------------------------------------------------------------------------------- /tests/strands/event_loop/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/strands-agents/sdk-python/f5e24d402bb22dc1a54c94dd030b4be0f7e73261/tests/strands/event_loop/__init__.py -------------------------------------------------------------------------------- /tests/strands/experimental/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/strands-agents/sdk-python/f5e24d402bb22dc1a54c94dd030b4be0f7e73261/tests/strands/experimental/__init__.py -------------------------------------------------------------------------------- /tests/strands/experimental/hooks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/strands-agents/sdk-python/f5e24d402bb22dc1a54c94dd030b4be0f7e73261/tests/strands/experimental/hooks/__init__.py -------------------------------------------------------------------------------- /tests/strands/experimental/hooks/test_events.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import Mock 2 | 3 | import pytest 4 | 5 | from strands.experimental.hooks import AfterToolInvocationEvent, BeforeToolInvocationEvent 6 | from strands.hooks import ( 7 | AfterInvocationEvent, 8 | AgentInitializedEvent, 9 | BeforeInvocationEvent, 10 | MessageAddedEvent, 11 | ) 12 | from strands.types.tools import ToolResult, ToolUse 13 | 14 | 15 | @pytest.fixture 16 | def agent(): 17 | return Mock() 18 | 19 | 20 | @pytest.fixture 21 | def tool(): 22 | tool = Mock() 23 | tool.tool_name = "test_tool" 24 | return tool 25 | 26 | 27 | @pytest.fixture 28 | def tool_use(): 29 | return ToolUse(name="test_tool", toolUseId="123", input={"param": "value"}) 30 | 31 | 32 | @pytest.fixture 33 | def tool_invocation_state(): 34 | return {"param": "value"} 35 | 36 | 37 | @pytest.fixture 38 | def tool_result(): 39 | return ToolResult(content=[{"text": 
"result"}], status="success", toolUseId="123") 40 | 41 | 42 | @pytest.fixture 43 | def initialized_event(agent): 44 | return AgentInitializedEvent(agent=agent) 45 | 46 | 47 | @pytest.fixture 48 | def start_request_event(agent): 49 | return BeforeInvocationEvent(agent=agent) 50 | 51 | 52 | @pytest.fixture 53 | def messaged_added_event(agent): 54 | return MessageAddedEvent(agent=agent, message=Mock()) 55 | 56 | 57 | @pytest.fixture 58 | def end_request_event(agent): 59 | return AfterInvocationEvent(agent=agent) 60 | 61 | 62 | @pytest.fixture 63 | def before_tool_event(agent, tool, tool_use, tool_invocation_state): 64 | return BeforeToolInvocationEvent( 65 | agent=agent, 66 | selected_tool=tool, 67 | tool_use=tool_use, 68 | invocation_state=tool_invocation_state, 69 | ) 70 | 71 | 72 | @pytest.fixture 73 | def after_tool_event(agent, tool, tool_use, tool_invocation_state, tool_result): 74 | return AfterToolInvocationEvent( 75 | agent=agent, 76 | selected_tool=tool, 77 | tool_use=tool_use, 78 | invocation_state=tool_invocation_state, 79 | result=tool_result, 80 | ) 81 | 82 | 83 | def test_event_should_reverse_callbacks( 84 | initialized_event, 85 | start_request_event, 86 | messaged_added_event, 87 | end_request_event, 88 | before_tool_event, 89 | after_tool_event, 90 | ): 91 | # note that we ignore E712 (explicit booleans) for consistency/readability purposes 92 | 93 | assert initialized_event.should_reverse_callbacks == False # noqa: E712 94 | 95 | assert messaged_added_event.should_reverse_callbacks == False # noqa: E712 96 | 97 | assert start_request_event.should_reverse_callbacks == False # noqa: E712 98 | assert end_request_event.should_reverse_callbacks == True # noqa: E712 99 | 100 | assert before_tool_event.should_reverse_callbacks == False # noqa: E712 101 | assert after_tool_event.should_reverse_callbacks == True # noqa: E712 102 | 103 | 104 | def test_message_added_event_cannot_write_properties(messaged_added_event): 105 | with pytest.raises(AttributeError, 
match="Property agent is not writable"): 106 | messaged_added_event.agent = Mock() 107 | with pytest.raises(AttributeError, match="Property message is not writable"): 108 | messaged_added_event.message = {} 109 | 110 | 111 | def test_before_tool_invocation_event_can_write_properties(before_tool_event): 112 | new_tool_use = ToolUse(name="new_tool", toolUseId="456", input={}) 113 | before_tool_event.selected_tool = None # Should not raise 114 | before_tool_event.tool_use = new_tool_use # Should not raise 115 | 116 | 117 | def test_before_tool_invocation_event_cannot_write_properties(before_tool_event): 118 | with pytest.raises(AttributeError, match="Property agent is not writable"): 119 | before_tool_event.agent = Mock() 120 | with pytest.raises(AttributeError, match="Property invocation_state is not writable"): 121 | before_tool_event.invocation_state = {} 122 | 123 | 124 | def test_after_tool_invocation_event_can_write_properties(after_tool_event): 125 | new_result = ToolResult(content=[{"text": "new result"}], status="success", toolUseId="456") 126 | after_tool_event.result = new_result # Should not raise 127 | 128 | 129 | def test_after_tool_invocation_event_cannot_write_properties(after_tool_event): 130 | with pytest.raises(AttributeError, match="Property agent is not writable"): 131 | after_tool_event.agent = Mock() 132 | with pytest.raises(AttributeError, match="Property selected_tool is not writable"): 133 | after_tool_event.selected_tool = None 134 | with pytest.raises(AttributeError, match="Property tool_use is not writable"): 135 | after_tool_event.tool_use = ToolUse(name="new", toolUseId="456", input={}) 136 | with pytest.raises(AttributeError, match="Property invocation_state is not writable"): 137 | after_tool_event.invocation_state = {} 138 | with pytest.raises(AttributeError, match="Property exception is not writable"): 139 | after_tool_event.exception = Exception("test") 140 | 
-------------------------------------------------------------------------------- /tests/strands/experimental/hooks/test_hook_registry.py: -------------------------------------------------------------------------------- 1 | import unittest.mock 2 | from dataclasses import dataclass 3 | from typing import List 4 | from unittest.mock import MagicMock, Mock 5 | 6 | import pytest 7 | 8 | from strands.hooks import HookEvent, HookProvider, HookRegistry 9 | 10 | 11 | @dataclass 12 | class TestEvent(HookEvent): 13 | @property 14 | def should_reverse_callbacks(self) -> bool: 15 | return False 16 | 17 | 18 | @dataclass 19 | class TestAfterEvent(HookEvent): 20 | @property 21 | def should_reverse_callbacks(self) -> bool: 22 | return True 23 | 24 | 25 | class TestHookProvider(HookProvider): 26 | """Test hook provider for testing hook registry.""" 27 | 28 | def __init__(self): 29 | self.registered = False 30 | 31 | def register_hooks(self, registry: HookRegistry) -> None: 32 | self.registered = True 33 | 34 | 35 | @pytest.fixture 36 | def hook_registry(): 37 | return HookRegistry() 38 | 39 | 40 | @pytest.fixture 41 | def test_event(): 42 | return TestEvent(agent=Mock()) 43 | 44 | 45 | @pytest.fixture 46 | def test_after_event(): 47 | return TestAfterEvent(agent=Mock()) 48 | 49 | 50 | def test_hook_registry_init(): 51 | """Test that HookRegistry initializes with an empty callbacks dictionary.""" 52 | registry = HookRegistry() 53 | assert registry._registered_callbacks == {} 54 | 55 | 56 | def test_add_callback(hook_registry, test_event): 57 | """Test that callbacks can be added to the registry.""" 58 | callback = unittest.mock.Mock() 59 | hook_registry.add_callback(TestEvent, callback) 60 | 61 | assert TestEvent in hook_registry._registered_callbacks 62 | assert callback in hook_registry._registered_callbacks[TestEvent] 63 | 64 | 65 | def test_add_multiple_callbacks_same_event(hook_registry, test_event): 66 | """Test that multiple callbacks can be added for the same event type.""" 
67 | callback1 = unittest.mock.Mock() 68 | callback2 = unittest.mock.Mock() 69 | 70 | hook_registry.add_callback(TestEvent, callback1) 71 | hook_registry.add_callback(TestEvent, callback2) 72 | 73 | assert len(hook_registry._registered_callbacks[TestEvent]) == 2 74 | assert callback1 in hook_registry._registered_callbacks[TestEvent] 75 | assert callback2 in hook_registry._registered_callbacks[TestEvent] 76 | 77 | 78 | def test_add_hook(hook_registry): 79 | """Test that hooks can be added to the registry.""" 80 | hook_provider = MagicMock() 81 | hook_registry.add_hook(hook_provider) 82 | 83 | assert hook_provider.register_hooks.call_count == 1 84 | 85 | 86 | def test_get_callbacks_for_normal_event(hook_registry, test_event): 87 | """Test that get_callbacks_for returns callbacks in the correct order for normal events.""" 88 | callback1 = unittest.mock.Mock() 89 | callback2 = unittest.mock.Mock() 90 | 91 | hook_registry.add_callback(TestEvent, callback1) 92 | hook_registry.add_callback(TestEvent, callback2) 93 | 94 | callbacks = list(hook_registry.get_callbacks_for(test_event)) 95 | 96 | assert len(callbacks) == 2 97 | assert callbacks[0] == callback1 98 | assert callbacks[1] == callback2 99 | 100 | 101 | def test_get_callbacks_for_after_event(hook_registry, test_after_event): 102 | """Test that get_callbacks_for returns callbacks in reverse order for after events.""" 103 | callback1 = Mock() 104 | callback2 = Mock() 105 | 106 | hook_registry.add_callback(TestAfterEvent, callback1) 107 | hook_registry.add_callback(TestAfterEvent, callback2) 108 | 109 | callbacks = list(hook_registry.get_callbacks_for(test_after_event)) 110 | 111 | assert len(callbacks) == 2 112 | assert callbacks[0] == callback2 # Reverse order 113 | assert callbacks[1] == callback1 # Reverse order 114 | 115 | 116 | def test_invoke_callbacks(hook_registry, test_event): 117 | """Test that invoke_callbacks calls all registered callbacks for an event.""" 118 | callback1 = Mock() 119 | callback2 = Mock() 
120 | 121 | hook_registry.add_callback(TestEvent, callback1) 122 | hook_registry.add_callback(TestEvent, callback2) 123 | 124 | hook_registry.invoke_callbacks(test_event) 125 | 126 | callback1.assert_called_once_with(test_event) 127 | callback2.assert_called_once_with(test_event) 128 | 129 | 130 | def test_invoke_callbacks_no_registered_callbacks(hook_registry, test_event): 131 | """Test that invoke_callbacks doesn't fail when there are no registered callbacks.""" 132 | # No callbacks registered 133 | hook_registry.invoke_callbacks(test_event) 134 | # Test passes if no exception is raised 135 | 136 | 137 | def test_invoke_callbacks_after_event(hook_registry, test_after_event): 138 | """Test that invoke_callbacks calls callbacks in reverse order for after events.""" 139 | call_order: List[str] = [] 140 | 141 | def callback1(_event): 142 | call_order.append("callback1") 143 | 144 | def callback2(_event): 145 | call_order.append("callback2") 146 | 147 | hook_registry.add_callback(TestAfterEvent, callback1) 148 | hook_registry.add_callback(TestAfterEvent, callback2) 149 | 150 | hook_registry.invoke_callbacks(test_after_event) 151 | 152 | assert call_order == ["callback2", "callback1"] # Reverse order 153 | 154 | 155 | def test_has_callbacks(hook_registry, test_event): 156 | """Test that has_callbacks returns correct boolean values.""" 157 | # Empty registry should return False 158 | assert not hook_registry.has_callbacks() 159 | 160 | # Registry with callbacks should return True 161 | callback = Mock() 162 | hook_registry.add_callback(TestEvent, callback) 163 | assert hook_registry.has_callbacks() 164 | 165 | # Test with multiple event types 166 | hook_registry.add_callback(TestAfterEvent, Mock()) 167 | assert hook_registry.has_callbacks() 168 | -------------------------------------------------------------------------------- /tests/strands/handlers/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/strands-agents/sdk-python/f5e24d402bb22dc1a54c94dd030b4be0f7e73261/tests/strands/handlers/__init__.py -------------------------------------------------------------------------------- /tests/strands/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/strands-agents/sdk-python/f5e24d402bb22dc1a54c94dd030b4be0f7e73261/tests/strands/models/__init__.py -------------------------------------------------------------------------------- /tests/strands/models/test_model.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pydantic import BaseModel 3 | 4 | from strands.models import Model as SAModel 5 | 6 | 7 | class Person(BaseModel): 8 | name: str 9 | age: int 10 | 11 | 12 | class TestModel(SAModel): 13 | def update_config(self, **model_config): 14 | return model_config 15 | 16 | def get_config(self): 17 | return 18 | 19 | async def structured_output(self, output_model, prompt=None, system_prompt=None, **kwargs): 20 | yield {"output": output_model(name="test", age=20)} 21 | 22 | async def stream(self, messages, tool_specs=None, system_prompt=None): 23 | yield {"messageStart": {"role": "assistant"}} 24 | yield {"contentBlockStart": {"start": {}}} 25 | yield {"contentBlockDelta": {"delta": {"text": f"Processed {len(messages)} messages"}}} 26 | yield {"contentBlockStop": {}} 27 | yield {"messageStop": {"stopReason": "end_turn"}} 28 | yield { 29 | "metadata": { 30 | "usage": {"inputTokens": 10, "outputTokens": 15, "totalTokens": 25}, 31 | "metrics": {"latencyMs": 100}, 32 | } 33 | } 34 | 35 | 36 | @pytest.fixture 37 | def model(): 38 | return TestModel() 39 | 40 | 41 | @pytest.fixture 42 | def messages(): 43 | return [ 44 | { 45 | "role": "user", 46 | "content": [{"text": "hello"}], 47 | }, 48 | ] 49 | 50 | 51 | @pytest.fixture 52 | def tool_specs(): 53 | return [ 54 | { 55 | "name": "test_tool", 56 | 
"description": "A test tool", 57 | "inputSchema": { 58 | "json": { 59 | "type": "object", 60 | "properties": { 61 | "input": {"type": "string"}, 62 | }, 63 | "required": ["input"], 64 | }, 65 | }, 66 | }, 67 | ] 68 | 69 | 70 | @pytest.fixture 71 | def system_prompt(): 72 | return "s1" 73 | 74 | 75 | @pytest.mark.asyncio 76 | async def test_stream(model, messages, tool_specs, system_prompt, alist): 77 | response = model.stream(messages, tool_specs, system_prompt) 78 | 79 | tru_events = await alist(response) 80 | exp_events = [ 81 | {"messageStart": {"role": "assistant"}}, 82 | {"contentBlockStart": {"start": {}}}, 83 | {"contentBlockDelta": {"delta": {"text": "Processed 1 messages"}}}, 84 | {"contentBlockStop": {}}, 85 | {"messageStop": {"stopReason": "end_turn"}}, 86 | { 87 | "metadata": { 88 | "usage": {"inputTokens": 10, "outputTokens": 15, "totalTokens": 25}, 89 | "metrics": {"latencyMs": 100}, 90 | } 91 | }, 92 | ] 93 | assert tru_events == exp_events 94 | 95 | 96 | @pytest.mark.asyncio 97 | async def test_structured_output(model, alist): 98 | response = model.structured_output(Person, prompt=messages, system_prompt=system_prompt) 99 | events = await alist(response) 100 | 101 | tru_output = events[-1]["output"] 102 | exp_output = Person(name="test", age=20) 103 | assert tru_output == exp_output 104 | -------------------------------------------------------------------------------- /tests/strands/multiagent/__init__.py: -------------------------------------------------------------------------------- 1 | """Tests for the multiagent module.""" 2 | -------------------------------------------------------------------------------- /tests/strands/multiagent/a2a/__init__.py: -------------------------------------------------------------------------------- 1 | """Tests for the A2A module.""" 2 | -------------------------------------------------------------------------------- /tests/strands/multiagent/a2a/conftest.py: 
-------------------------------------------------------------------------------- 1 | """Common fixtures for A2A module tests.""" 2 | 3 | from unittest.mock import AsyncMock, MagicMock 4 | 5 | import pytest 6 | from a2a.server.agent_execution import RequestContext 7 | from a2a.server.events import EventQueue 8 | 9 | from strands.agent.agent import Agent as SAAgent 10 | from strands.agent.agent_result import AgentResult as SAAgentResult 11 | 12 | 13 | @pytest.fixture 14 | def mock_strands_agent(): 15 | """Create a mock Strands Agent for testing.""" 16 | agent = MagicMock(spec=SAAgent) 17 | agent.name = "Test Agent" 18 | agent.description = "A test agent for unit testing" 19 | 20 | # Setup default response 21 | mock_result = MagicMock(spec=SAAgentResult) 22 | mock_result.message = {"content": [{"text": "Test response"}]} 23 | agent.return_value = mock_result 24 | 25 | # Setup async methods 26 | agent.invoke_async = AsyncMock(return_value=mock_result) 27 | agent.stream_async = AsyncMock(return_value=iter([])) 28 | 29 | # Setup mock tool registry 30 | mock_tool_registry = MagicMock() 31 | mock_tool_registry.get_all_tools_config.return_value = {} 32 | agent.tool_registry = mock_tool_registry 33 | 34 | return agent 35 | 36 | 37 | @pytest.fixture 38 | def mock_request_context(): 39 | """Create a mock RequestContext for testing.""" 40 | context = MagicMock(spec=RequestContext) 41 | context.get_user_input.return_value = "Test input" 42 | return context 43 | 44 | 45 | @pytest.fixture 46 | def mock_event_queue(): 47 | """Create a mock EventQueue for testing.""" 48 | queue = MagicMock(spec=EventQueue) 49 | queue.enqueue_event = AsyncMock() 50 | return queue 51 | -------------------------------------------------------------------------------- /tests/strands/session/__init__.py: -------------------------------------------------------------------------------- 1 | """Tests for session management.""" 2 | 
-------------------------------------------------------------------------------- /tests/strands/tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/strands-agents/sdk-python/f5e24d402bb22dc1a54c94dd030b4be0f7e73261/tests/strands/tools/__init__.py -------------------------------------------------------------------------------- /tests/strands/tools/mcp/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/strands-agents/sdk-python/f5e24d402bb22dc1a54c94dd030b4be0f7e73261/tests/strands/tools/mcp/__init__.py -------------------------------------------------------------------------------- /tests/strands/tools/mcp/test_mcp_agent_tool.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import MagicMock 2 | 3 | import pytest 4 | from mcp.types import Tool as MCPTool 5 | 6 | from strands.tools.mcp import MCPAgentTool, MCPClient 7 | 8 | 9 | @pytest.fixture 10 | def mock_mcp_tool(): 11 | mock_tool = MagicMock(spec=MCPTool) 12 | mock_tool.name = "test_tool" 13 | mock_tool.description = "A test tool" 14 | mock_tool.inputSchema = {"type": "object", "properties": {}} 15 | return mock_tool 16 | 17 | 18 | @pytest.fixture 19 | def mock_mcp_client(): 20 | mock_server = MagicMock(spec=MCPClient) 21 | mock_server.call_tool_sync.return_value = { 22 | "status": "success", 23 | "toolUseId": "test-123", 24 | "content": [{"text": "Success result"}], 25 | } 26 | return mock_server 27 | 28 | 29 | @pytest.fixture 30 | def mcp_agent_tool(mock_mcp_tool, mock_mcp_client): 31 | return MCPAgentTool(mock_mcp_tool, mock_mcp_client) 32 | 33 | 34 | def test_tool_name(mcp_agent_tool, mock_mcp_tool): 35 | assert mcp_agent_tool.tool_name == "test_tool" 36 | assert mcp_agent_tool.tool_name == mock_mcp_tool.name 37 | 38 | 39 | def test_tool_type(mcp_agent_tool): 40 | assert mcp_agent_tool.tool_type 
== "python" 41 | 42 | 43 | def test_tool_spec_with_description(mcp_agent_tool, mock_mcp_tool): 44 | tool_spec = mcp_agent_tool.tool_spec 45 | 46 | assert tool_spec["name"] == "test_tool" 47 | assert tool_spec["description"] == "A test tool" 48 | assert tool_spec["inputSchema"]["json"] == {"type": "object", "properties": {}} 49 | 50 | 51 | def test_tool_spec_without_description(mock_mcp_tool, mock_mcp_client): 52 | mock_mcp_tool.description = None 53 | 54 | agent_tool = MCPAgentTool(mock_mcp_tool, mock_mcp_client) 55 | tool_spec = agent_tool.tool_spec 56 | 57 | assert tool_spec["description"] == "Tool which performs test_tool" 58 | 59 | 60 | @pytest.mark.asyncio 61 | async def test_stream(mcp_agent_tool, mock_mcp_client, alist): 62 | tool_use = {"toolUseId": "test-123", "name": "test_tool", "input": {"param": "value"}} 63 | 64 | tru_events = await alist(mcp_agent_tool.stream(tool_use, {})) 65 | exp_events = [mock_mcp_client.call_tool_async.return_value] 66 | 67 | assert tru_events == exp_events 68 | mock_mcp_client.call_tool_async.assert_called_once_with( 69 | tool_use_id="test-123", name="test_tool", arguments={"param": "value"} 70 | ) 71 | -------------------------------------------------------------------------------- /tests/strands/tools/test_registry.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for the SDK tool registry module. 
3 | """ 4 | 5 | from unittest.mock import MagicMock 6 | 7 | import pytest 8 | 9 | import strands 10 | from strands.tools import PythonAgentTool 11 | from strands.tools.decorator import DecoratedFunctionTool, tool 12 | from strands.tools.registry import ToolRegistry 13 | 14 | 15 | def test_load_tool_from_filepath_failure(): 16 | """Test error handling when load_tool fails.""" 17 | tool_registry = ToolRegistry() 18 | error_message = "Failed to load tool failing_tool: Tool file not found: /path/to/failing_tool.py" 19 | 20 | with pytest.raises(ValueError, match=error_message): 21 | tool_registry.load_tool_from_filepath("failing_tool", "/path/to/failing_tool.py") 22 | 23 | 24 | def test_process_tools_with_invalid_path(): 25 | """Test that process_tools raises an exception when a non-path string is passed.""" 26 | tool_registry = ToolRegistry() 27 | invalid_path = "not a filepath" 28 | 29 | with pytest.raises(ValueError, match=f"Failed to load tool {invalid_path.split('.')[0]}: Tool file not found:.*"): 30 | tool_registry.process_tools([invalid_path]) 31 | 32 | 33 | def test_register_tool_with_similar_name_raises(): 34 | tool_1 = PythonAgentTool(tool_name="tool-like-this", tool_spec=MagicMock(), tool_func=lambda: None) 35 | tool_2 = PythonAgentTool(tool_name="tool_like_this", tool_spec=MagicMock(), tool_func=lambda: None) 36 | 37 | tool_registry = ToolRegistry() 38 | 39 | tool_registry.register_tool(tool_1) 40 | 41 | with pytest.raises(ValueError) as err: 42 | tool_registry.register_tool(tool_2) 43 | 44 | assert ( 45 | str(err.value) == "Tool name 'tool_like_this' already exists as 'tool-like-this'. 
" 46 | "Cannot add a duplicate tool which differs by a '-' or '_'" 47 | ) 48 | 49 | 50 | def test_get_all_tool_specs_returns_right_tool_specs(): 51 | tool_1 = strands.tool(lambda a: a, name="tool_1") 52 | tool_2 = strands.tool(lambda b: b, name="tool_2") 53 | 54 | tool_registry = ToolRegistry() 55 | 56 | tool_registry.register_tool(tool_1) 57 | tool_registry.register_tool(tool_2) 58 | 59 | tool_specs = tool_registry.get_all_tool_specs() 60 | 61 | assert tool_specs == [ 62 | tool_1.tool_spec, 63 | tool_2.tool_spec, 64 | ] 65 | 66 | 67 | def test_scan_module_for_tools(): 68 | @tool 69 | def tool_function_1(a): 70 | return a 71 | 72 | @tool 73 | def tool_function_2(b): 74 | return b 75 | 76 | def tool_function_3(c): 77 | return c 78 | 79 | def tool_function_4(d): 80 | return d 81 | 82 | tool_function_4.tool_spec = "invalid" 83 | 84 | mock_module = MagicMock() 85 | mock_module.tool_function_1 = tool_function_1 86 | mock_module.tool_function_2 = tool_function_2 87 | mock_module.tool_function_3 = tool_function_3 88 | mock_module.tool_function_4 = tool_function_4 89 | 90 | tool_registry = ToolRegistry() 91 | 92 | tools = tool_registry._scan_module_for_tools(mock_module) 93 | 94 | assert len(tools) == 2 95 | assert all(isinstance(tool, DecoratedFunctionTool) for tool in tools) 96 | -------------------------------------------------------------------------------- /tests/strands/tools/test_watcher.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for the SDK tool watcher module. 
3 | """ 4 | 5 | from unittest.mock import MagicMock, patch 6 | 7 | import pytest 8 | 9 | from strands.tools.registry import ToolRegistry 10 | from strands.tools.watcher import ToolWatcher 11 | 12 | 13 | def test_tool_watcher_initialization(): 14 | """Test that the handler initializes with the correct tool registry.""" 15 | tool_registry = ToolRegistry() 16 | watcher = ToolWatcher(tool_registry) 17 | assert watcher.tool_registry == tool_registry 18 | 19 | 20 | @pytest.mark.parametrize( 21 | "test_case", 22 | [ 23 | # Regular Python file - should reload 24 | { 25 | "description": "Python file", 26 | "src_path": "/path/to/test_tool.py", 27 | "is_directory": False, 28 | "should_reload": True, 29 | "expected_tool_name": "test_tool", 30 | }, 31 | # Non-Python file - should not reload 32 | { 33 | "description": "Non-Python file", 34 | "src_path": "/path/to/test_tool.txt", 35 | "is_directory": False, 36 | "should_reload": False, 37 | }, 38 | # __init__.py file - should not reload 39 | { 40 | "description": "Init file", 41 | "src_path": "/path/to/__init__.py", 42 | "is_directory": False, 43 | "should_reload": False, 44 | }, 45 | # Directory path - should not reload 46 | { 47 | "description": "Directory path", 48 | "src_path": "/path/to/tools_directory", 49 | "is_directory": True, 50 | "should_reload": False, 51 | }, 52 | # Python file marked as directory - should still reload 53 | { 54 | "description": "Python file marked as directory", 55 | "src_path": "/path/to/test_tool2.py", 56 | "is_directory": True, 57 | "should_reload": True, 58 | "expected_tool_name": "test_tool2", 59 | }, 60 | ], 61 | ) 62 | @patch.object(ToolRegistry, "reload_tool") 63 | def test_on_modified_cases(mock_reload_tool, test_case): 64 | """Test various cases for the on_modified method.""" 65 | tool_registry = ToolRegistry() 66 | watcher = ToolWatcher(tool_registry) 67 | 68 | # Create a mock event with the specified properties 69 | event = MagicMock() 70 | event.src_path = test_case["src_path"] 71 | if 
"is_directory" in test_case: 72 | event.is_directory = test_case["is_directory"] 73 | 74 | # Call the on_modified method 75 | watcher.tool_change_handler.on_modified(event) 76 | 77 | # Verify the expected behavior 78 | if test_case["should_reload"]: 79 | mock_reload_tool.assert_called_once_with(test_case["expected_tool_name"]) 80 | else: 81 | mock_reload_tool.assert_not_called() 82 | 83 | 84 | @patch.object(ToolRegistry, "reload_tool", side_effect=Exception("Test error")) 85 | def test_on_modified_error_handling(mock_reload_tool): 86 | """Test that on_modified handles errors during tool reloading.""" 87 | tool_registry = ToolRegistry() 88 | watcher = ToolWatcher(tool_registry) 89 | 90 | # Create a mock event with a Python file path 91 | event = MagicMock() 92 | event.src_path = "/path/to/test_tool.py" 93 | 94 | # Call the on_modified method - should not raise an exception 95 | watcher.tool_change_handler.on_modified(event) 96 | 97 | # Verify that reload_tool was called 98 | mock_reload_tool.assert_called_once_with("test_tool") 99 | -------------------------------------------------------------------------------- /tests/strands/types/test_session.py: -------------------------------------------------------------------------------- 1 | import json 2 | from uuid import uuid4 3 | 4 | from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager 5 | from strands.types.session import ( 6 | Session, 7 | SessionAgent, 8 | SessionMessage, 9 | SessionType, 10 | decode_bytes_values, 11 | encode_bytes_values, 12 | ) 13 | 14 | 15 | def test_session_json_serializable(): 16 | session = Session(session_id=str(uuid4()), session_type=SessionType.AGENT) 17 | # json dumps will fail if its not json serializable 18 | session_json_string = json.dumps(session.to_dict()) 19 | loaded_session = Session.from_dict(json.loads(session_json_string)) 20 | assert loaded_session is not None 21 | 22 | 23 | def test_agent_json_serializable(): 24 | agent = SessionAgent( 
25 | agent_id=str(uuid4()), state={"foo": "bar"}, conversation_manager_state=NullConversationManager().get_state() 26 | ) 27 | # json dumps will fail if its not json serializable 28 | agent_json_string = json.dumps(agent.to_dict()) 29 | loaded_agent = SessionAgent.from_dict(json.loads(agent_json_string)) 30 | assert loaded_agent is not None 31 | 32 | 33 | def test_message_json_serializable(): 34 | message = SessionMessage(message={"role": "user", "content": [{"text": "Hello!"}]}, message_id=0) 35 | # json dumps will fail if its not json serializable 36 | message_json_string = json.dumps(message.to_dict()) 37 | loaded_message = SessionMessage.from_dict(json.loads(message_json_string)) 38 | assert loaded_message is not None 39 | 40 | 41 | def test_bytes_encoding_decoding(): 42 | # Test simple bytes 43 | test_bytes = b"Hello, world!" 44 | encoded = encode_bytes_values(test_bytes) 45 | assert isinstance(encoded, dict) 46 | assert encoded["__bytes_encoded__"] is True 47 | decoded = decode_bytes_values(encoded) 48 | assert decoded == test_bytes 49 | 50 | # Test nested structure with bytes 51 | test_data = { 52 | "text": "Hello", 53 | "binary": b"Binary data", 54 | "nested": {"more_binary": b"More binary data", "list_with_binary": [b"Item 1", "Text item", b"Item 3"]}, 55 | } 56 | 57 | encoded = encode_bytes_values(test_data) 58 | # Verify it's JSON serializable 59 | json_str = json.dumps(encoded) 60 | # Deserialize and decode 61 | decoded = decode_bytes_values(json.loads(json_str)) 62 | 63 | # Verify the decoded data matches the original 64 | assert decoded["text"] == test_data["text"] 65 | assert decoded["binary"] == test_data["binary"] 66 | assert decoded["nested"]["more_binary"] == test_data["nested"]["more_binary"] 67 | assert decoded["nested"]["list_with_binary"][0] == test_data["nested"]["list_with_binary"][0] 68 | assert decoded["nested"]["list_with_binary"][1] == test_data["nested"]["list_with_binary"][1] 69 | assert decoded["nested"]["list_with_binary"][2] == 
test_data["nested"]["list_with_binary"][2] 70 | 71 | 72 | def test_session_message_with_bytes(): 73 | # Create a message with bytes content 74 | message = { 75 | "role": "user", 76 | "content": [{"text": "Here is some binary data"}, {"binary_data": b"This is binary data"}], 77 | } 78 | 79 | # Create a SessionMessage 80 | session_message = SessionMessage.from_message(message, 0) 81 | 82 | # Verify it's JSON serializable 83 | message_json_string = json.dumps(session_message.to_dict()) 84 | 85 | # Load it back 86 | loaded_message = SessionMessage.from_dict(json.loads(message_json_string)) 87 | 88 | # Convert back to original message and verify 89 | original_message = loaded_message.to_message() 90 | 91 | assert original_message["role"] == message["role"] 92 | assert original_message["content"][0]["text"] == message["content"][0]["text"] 93 | assert original_message["content"][1]["binary_data"] == message["content"][1]["binary_data"] 94 | -------------------------------------------------------------------------------- /tests_integ/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/strands-agents/sdk-python/f5e24d402bb22dc1a54c94dd030b4be0f7e73261/tests_integ/__init__.py -------------------------------------------------------------------------------- /tests_integ/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | ## Data 4 | 5 | 6 | @pytest.fixture 7 | def yellow_img(pytestconfig): 8 | path = pytestconfig.rootdir / "tests_integ/yellow.png" 9 | with open(path, "rb") as fp: 10 | return fp.read() 11 | 12 | 13 | ## Async 14 | 15 | 16 | @pytest.fixture(scope="session") 17 | def agenerator(): 18 | async def agenerator(items): 19 | for item in items: 20 | yield item 21 | 22 | return agenerator 23 | 24 | 25 | @pytest.fixture(scope="session") 26 | def alist(): 27 | async def alist(items): 28 | return [item async for item in items] 29 | 
30 | return alist 31 | -------------------------------------------------------------------------------- /tests_integ/echo_server.py: -------------------------------------------------------------------------------- 1 | """ 2 | Echo Server for MCP Integration Testing 3 | 4 | This module implements a simple echo server using the Model Context Protocol (MCP). 5 | It provides a basic tool that echoes back any input string, which is useful for 6 | testing the MCP communication flow and validating that messages are properly 7 | transmitted between the client and server. 8 | 9 | The server runs with stdio transport, making it suitable for integration tests 10 | where the client can spawn this process and communicate with it through standard 11 | input/output streams. 12 | 13 | Usage: 14 | Run this file directly to start the echo server: 15 | $ python echo_server.py 16 | """ 17 | 18 | from mcp.server import FastMCP 19 | 20 | 21 | def start_echo_server(): 22 | """ 23 | Initialize and start the MCP echo server. 24 | 25 | Creates a FastMCP server instance with a single 'echo' tool that returns 26 | any input string back to the caller. The server uses stdio transport 27 | for communication. 
28 | """ 29 | mcp = FastMCP("Echo Server") 30 | 31 | @mcp.tool(description="Echos response back to the user") 32 | def echo(to_echo: str) -> str: 33 | return to_echo 34 | 35 | mcp.run(transport="stdio") 36 | 37 | 38 | if __name__ == "__main__": 39 | start_echo_server() 40 | -------------------------------------------------------------------------------- /tests_integ/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/strands-agents/sdk-python/f5e24d402bb22dc1a54c94dd030b4be0f7e73261/tests_integ/models/__init__.py -------------------------------------------------------------------------------- /tests_integ/models/conformance.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from strands.types.models import Model 4 | from tests_integ.models.providers import ProviderInfo, all_providers 5 | 6 | 7 | def get_models(): 8 | return [ 9 | pytest.param( 10 | provider_info, 11 | id=provider_info.id, # Adds the provider name to the test name 12 | marks=[provider_info.mark], # ignores tests that don't have the requirements 13 | ) 14 | for provider_info in all_providers 15 | ] 16 | 17 | 18 | @pytest.fixture(params=get_models()) 19 | def provider_info(request) -> ProviderInfo: 20 | return request.param 21 | 22 | 23 | @pytest.fixture() 24 | def model(provider_info): 25 | return provider_info.create_model() 26 | 27 | 28 | def test_model_can_be_constructed(model: Model): 29 | assert model is not None 30 | pass 31 | -------------------------------------------------------------------------------- /tests_integ/models/providers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Aggregates all providers for testing all providers in one go. 
3 | """ 4 | 5 | import os 6 | from typing import Callable, Optional 7 | 8 | import requests 9 | from pytest import mark 10 | 11 | from strands.models import BedrockModel, Model 12 | from strands.models.anthropic import AnthropicModel 13 | from strands.models.litellm import LiteLLMModel 14 | from strands.models.llamaapi import LlamaAPIModel 15 | from strands.models.mistral import MistralModel 16 | from strands.models.ollama import OllamaModel 17 | from strands.models.openai import OpenAIModel 18 | from strands.models.writer import WriterModel 19 | 20 | 21 | class ProviderInfo: 22 | """Provider-based info for providers that require an APIKey via environment variables.""" 23 | 24 | def __init__( 25 | self, 26 | id: str, 27 | factory: Callable[[], Model], 28 | environment_variable: Optional[str] = None, 29 | ) -> None: 30 | self.id = id 31 | self.model_factory = factory 32 | self.mark = mark.skipif( 33 | environment_variable is not None and environment_variable not in os.environ, 34 | reason=f"{environment_variable} environment variable missing", 35 | ) 36 | 37 | def create_model(self) -> Model: 38 | return self.model_factory() 39 | 40 | 41 | class OllamaProviderInfo(ProviderInfo): 42 | """Special case ollama as it's dependent on the server being available.""" 43 | 44 | def __init__(self): 45 | super().__init__( 46 | id="ollama", factory=lambda: OllamaModel(host="http://localhost:11434", model_id="llama3.3:70b") 47 | ) 48 | 49 | is_server_available = False 50 | try: 51 | is_server_available = requests.get("http://localhost:11434").ok 52 | except requests.exceptions.ConnectionError: 53 | pass 54 | 55 | self.mark = mark.skipif( 56 | not is_server_available, 57 | reason="Local Ollama endpoint not available at localhost:11434", 58 | ) 59 | 60 | 61 | anthropic = ProviderInfo( 62 | id="anthropic", 63 | environment_variable="ANTHROPIC_API_KEY", 64 | factory=lambda: AnthropicModel( 65 | client_args={ 66 | "api_key": os.getenv("ANTHROPIC_API_KEY"), 67 | }, 68 | 
model_id="claude-3-7-sonnet-20250219", 69 | max_tokens=512, 70 | ), 71 | ) 72 | bedrock = ProviderInfo(id="bedrock", factory=lambda: BedrockModel()) 73 | cohere = ProviderInfo( 74 | id="cohere", 75 | environment_variable="CO_API_KEY", 76 | factory=lambda: OpenAIModel( 77 | client_args={ 78 | "base_url": "https://api.cohere.com/compatibility/v1", 79 | "api_key": os.getenv("CO_API_KEY"), 80 | }, 81 | model_id="command-a-03-2025", 82 | params={"stream_options": None}, 83 | ), 84 | ) 85 | litellm = ProviderInfo( 86 | id="litellm", factory=lambda: LiteLLMModel(model_id="bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0") 87 | ) 88 | llama = ProviderInfo( 89 | id="llama", 90 | environment_variable="LLAMA_API_KEY", 91 | factory=lambda: LlamaAPIModel( 92 | model_id="Llama-4-Maverick-17B-128E-Instruct-FP8", 93 | client_args={ 94 | "api_key": os.getenv("LLAMA_API_KEY"), 95 | }, 96 | ), 97 | ) 98 | mistral = ProviderInfo( 99 | id="mistral", 100 | environment_variable="MISTRAL_API_KEY", 101 | factory=lambda: MistralModel( 102 | model_id="mistral-medium-latest", 103 | api_key=os.getenv("MISTRAL_API_KEY"), 104 | stream=True, 105 | temperature=0.7, 106 | max_tokens=1000, 107 | top_p=0.9, 108 | ), 109 | ) 110 | openai = ProviderInfo( 111 | id="openai", 112 | environment_variable="OPENAI_API_KEY", 113 | factory=lambda: OpenAIModel( 114 | model_id="gpt-4o", 115 | client_args={ 116 | "api_key": os.getenv("OPENAI_API_KEY"), 117 | }, 118 | ), 119 | ) 120 | writer = ProviderInfo( 121 | id="writer", 122 | environment_variable="WRITER_API_KEY", 123 | factory=lambda: WriterModel( 124 | model_id="palmyra-x4", 125 | client_args={"api_key": os.getenv("WRITER_API_KEY", "")}, 126 | stream_options={"include_usage": True}, 127 | ), 128 | ) 129 | 130 | ollama = OllamaProviderInfo() 131 | 132 | 133 | all_providers = [ 134 | bedrock, 135 | anthropic, 136 | cohere, 137 | llama, 138 | litellm, 139 | mistral, 140 | openai, 141 | writer, 142 | ] 143 | 
-------------------------------------------------------------------------------- /tests_integ/models/test_model_anthropic.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pydantic 4 | import pytest 5 | 6 | import strands 7 | from strands import Agent 8 | from strands.models.anthropic import AnthropicModel 9 | from tests_integ.models import providers 10 | 11 | # these tests only run if we have the anthropic api key 12 | pytestmark = providers.anthropic.mark 13 | 14 | 15 | @pytest.fixture 16 | def model(): 17 | return AnthropicModel( 18 | client_args={ 19 | "api_key": os.getenv("ANTHROPIC_API_KEY"), 20 | }, 21 | model_id="claude-3-7-sonnet-20250219", 22 | max_tokens=512, 23 | ) 24 | 25 | 26 | @pytest.fixture 27 | def tools(): 28 | @strands.tool 29 | def tool_time() -> str: 30 | return "12:00" 31 | 32 | @strands.tool 33 | def tool_weather() -> str: 34 | return "sunny" 35 | 36 | return [tool_time, tool_weather] 37 | 38 | 39 | @pytest.fixture 40 | def system_prompt(): 41 | return "You are an AI assistant." 
42 | 43 | 44 | @pytest.fixture 45 | def agent(model, tools, system_prompt): 46 | return Agent(model=model, tools=tools, system_prompt=system_prompt) 47 | 48 | 49 | @pytest.fixture 50 | def weather(): 51 | class Weather(pydantic.BaseModel): 52 | """Extracts the time and weather from the user's message with the exact strings.""" 53 | 54 | time: str 55 | weather: str 56 | 57 | return Weather(time="12:00", weather="sunny") 58 | 59 | 60 | @pytest.fixture 61 | def yellow_color(): 62 | class Color(pydantic.BaseModel): 63 | """Describes a color.""" 64 | 65 | name: str 66 | 67 | @pydantic.field_validator("name", mode="after") 68 | @classmethod 69 | def lower(_, value): 70 | return value.lower() 71 | 72 | return Color(name="yellow") 73 | 74 | 75 | def test_agent_invoke(agent): 76 | result = agent("What is the time and weather in New York?") 77 | text = result.message["content"][0]["text"].lower() 78 | 79 | assert all(string in text for string in ["12:00", "sunny"]) 80 | 81 | 82 | @pytest.mark.asyncio 83 | async def test_agent_invoke_async(agent): 84 | result = await agent.invoke_async("What is the time and weather in New York?") 85 | text = result.message["content"][0]["text"].lower() 86 | 87 | assert all(string in text for string in ["12:00", "sunny"]) 88 | 89 | 90 | @pytest.mark.asyncio 91 | async def test_agent_stream_async(agent): 92 | stream = agent.stream_async("What is the time and weather in New York?") 93 | async for event in stream: 94 | _ = event 95 | 96 | result = event["result"] 97 | text = result.message["content"][0]["text"].lower() 98 | 99 | assert all(string in text for string in ["12:00", "sunny"]) 100 | 101 | 102 | def test_structured_output(agent, weather): 103 | tru_weather = agent.structured_output(type(weather), "The time is 12:00 and the weather is sunny") 104 | exp_weather = weather 105 | assert tru_weather == exp_weather 106 | 107 | 108 | @pytest.mark.asyncio 109 | async def test_agent_structured_output_async(agent, weather): 110 | tru_weather = 
await agent.structured_output_async(type(weather), "The time is 12:00 and the weather is sunny") 111 | exp_weather = weather 112 | assert tru_weather == exp_weather 113 | 114 | 115 | def test_invoke_multi_modal_input(agent, yellow_img): 116 | content = [ 117 | {"text": "what is in this image"}, 118 | { 119 | "image": { 120 | "format": "png", 121 | "source": { 122 | "bytes": yellow_img, 123 | }, 124 | }, 125 | }, 126 | ] 127 | result = agent(content) 128 | text = result.message["content"][0]["text"].lower() 129 | 130 | assert "yellow" in text 131 | 132 | 133 | def test_structured_output_multi_modal_input(agent, yellow_img, yellow_color): 134 | content = [ 135 | {"text": "Is this image red, blue, or yellow?"}, 136 | { 137 | "image": { 138 | "format": "png", 139 | "source": { 140 | "bytes": yellow_img, 141 | }, 142 | }, 143 | }, 144 | ] 145 | tru_color = agent.structured_output(type(yellow_color), content) 146 | exp_color = yellow_color 147 | assert tru_color == exp_color 148 | -------------------------------------------------------------------------------- /tests_integ/models/test_model_bedrock.py: -------------------------------------------------------------------------------- 1 | import pydantic 2 | import pytest 3 | 4 | import strands 5 | from strands import Agent 6 | from strands.models import BedrockModel 7 | 8 | 9 | @pytest.fixture 10 | def system_prompt(): 11 | return "You are an AI assistant that uses & instead of ." 
12 | 13 | 14 | @pytest.fixture 15 | def streaming_model(): 16 | return BedrockModel( 17 | streaming=True, 18 | ) 19 | 20 | 21 | @pytest.fixture 22 | def non_streaming_model(): 23 | return BedrockModel( 24 | streaming=False, 25 | ) 26 | 27 | 28 | @pytest.fixture 29 | def streaming_agent(streaming_model, system_prompt): 30 | return Agent(model=streaming_model, system_prompt=system_prompt, load_tools_from_directory=False) 31 | 32 | 33 | @pytest.fixture 34 | def non_streaming_agent(non_streaming_model, system_prompt): 35 | return Agent(model=non_streaming_model, system_prompt=system_prompt, load_tools_from_directory=False) 36 | 37 | 38 | @pytest.fixture 39 | def yellow_color(): 40 | class Color(pydantic.BaseModel): 41 | """Describes a color.""" 42 | 43 | name: str 44 | 45 | @pydantic.field_validator("name", mode="after") 46 | @classmethod 47 | def lower(_, value): 48 | return value.lower() 49 | 50 | return Color(name="yellow") 51 | 52 | 53 | def test_streaming_agent(streaming_agent): 54 | """Test agent with streaming model.""" 55 | result = streaming_agent("Hello!") 56 | 57 | assert len(str(result)) > 0 58 | 59 | 60 | def test_non_streaming_agent(non_streaming_agent): 61 | """Test agent with non-streaming model.""" 62 | result = non_streaming_agent("Hello!") 63 | 64 | assert len(str(result)) > 0 65 | 66 | 67 | @pytest.mark.asyncio 68 | async def test_streaming_model_events(streaming_model, alist): 69 | """Test streaming model events.""" 70 | messages = [{"role": "user", "content": [{"text": "Hello"}]}] 71 | 72 | # Call stream and collect events 73 | events = await alist(streaming_model.stream(messages)) 74 | 75 | # Verify basic structure of events 76 | assert any("messageStart" in event for event in events) 77 | assert any("contentBlockDelta" in event for event in events) 78 | assert any("messageStop" in event for event in events) 79 | 80 | 81 | @pytest.mark.asyncio 82 | async def test_non_streaming_model_events(non_streaming_model, alist): 83 | """Test non-streaming 
model events.""" 84 | messages = [{"role": "user", "content": [{"text": "Hello"}]}] 85 | 86 | # Call stream and collect events 87 | events = await alist(non_streaming_model.stream(messages)) 88 | 89 | # Verify basic structure of events 90 | assert any("messageStart" in event for event in events) 91 | assert any("contentBlockDelta" in event for event in events) 92 | assert any("messageStop" in event for event in events) 93 | 94 | 95 | def test_tool_use_streaming(streaming_model): 96 | """Test tool use with streaming model.""" 97 | 98 | tool_was_called = False 99 | 100 | @strands.tool 101 | def calculator(expression: str) -> float: 102 | """Calculate the result of a mathematical expression.""" 103 | 104 | nonlocal tool_was_called 105 | tool_was_called = True 106 | return eval(expression) # NOTE(review): eval() on tool input that is presumably model-generated; tolerable only in a trusted test run - confirm, or swap in a restricted arithmetic parser 107 | 108 | agent = Agent(model=streaming_model, tools=[calculator], load_tools_from_directory=False) 109 | result = agent("What is 123 + 456?") 110 | 111 | # Print the full message content for debugging 112 | print("\nFull message content:") 113 | import json 114 | 115 | print(json.dumps(result.message["content"], indent=2)) 116 | 117 | assert tool_was_called 118 | 119 | 120 | def test_tool_use_non_streaming(non_streaming_model): 121 | """Test tool use with non-streaming model.""" 122 | 123 | tool_was_called = False 124 | 125 | @strands.tool 126 | def calculator(expression: str) -> float: 127 | """Calculate the result of a mathematical expression.""" 128 | 129 | nonlocal tool_was_called 130 | tool_was_called = True 131 | return eval(expression) # NOTE(review): eval() on tool input that is presumably model-generated; tolerable only in a trusted test run - confirm, or swap in a restricted arithmetic parser 132 | 133 | agent = Agent(model=non_streaming_model, tools=[calculator], load_tools_from_directory=False) 134 | agent("What is 123 + 456?") 135 | 136 | assert tool_was_called 137 | 138 | 139 | def test_structured_output_streaming(streaming_model): 140 | """Test structured output with streaming model.""" 141 | 142 | class Weather(pydantic.BaseModel): 143 | time: str 144 | weather: str 145 | 146 | agent = Agent(model=streaming_model) 147 | 148
| result = agent.structured_output(Weather, "The time is 12:00 and the weather is sunny") 149 | assert isinstance(result, Weather) 150 | assert result.time == "12:00" 151 | assert result.weather == "sunny" 152 | 153 | 154 | def test_structured_output_non_streaming(non_streaming_model): 155 | """Test structured output with non-streaming model.""" 156 | 157 | class Weather(pydantic.BaseModel): 158 | time: str 159 | weather: str 160 | 161 | agent = Agent(model=non_streaming_model) 162 | 163 | result = agent.structured_output(Weather, "The time is 12:00 and the weather is sunny") 164 | assert isinstance(result, Weather) 165 | assert result.time == "12:00" 166 | assert result.weather == "sunny" 167 | 168 | 169 | def test_invoke_multi_modal_input(streaming_agent, yellow_img): 170 | content = [ 171 | {"text": "what is in this image"}, 172 | { 173 | "image": { 174 | "format": "png", 175 | "source": { 176 | "bytes": yellow_img, 177 | }, 178 | }, 179 | }, 180 | ] 181 | result = streaming_agent(content) 182 | text = result.message["content"][0]["text"].lower() 183 | 184 | assert "yellow" in text 185 | 186 | 187 | def test_structured_output_multi_modal_input(streaming_agent, yellow_img, yellow_color): 188 | content = [ 189 | {"text": "Is this image red, blue, or yellow?"}, 190 | { 191 | "image": { 192 | "format": "png", 193 | "source": { 194 | "bytes": yellow_img, 195 | }, 196 | }, 197 | }, 198 | ] 199 | tru_color = streaming_agent.structured_output(type(yellow_color), content) 200 | exp_color = yellow_color 201 | assert tru_color == exp_color 202 | -------------------------------------------------------------------------------- /tests_integ/models/test_model_cohere.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | 5 | import strands 6 | from strands import Agent 7 | from strands.models.openai import OpenAIModel 8 | from tests_integ.models import providers 9 | 10 | # these tests only run if we have the 
cohere api key 11 | pytestmark = providers.cohere.mark 12 | 13 | 14 | @pytest.fixture 15 | def model(): 16 | return OpenAIModel( 17 | client_args={ 18 | "base_url": "https://api.cohere.com/compatibility/v1", 19 | "api_key": os.getenv("CO_API_KEY"), 20 | }, 21 | model_id="command-a-03-2025", 22 | params={"stream_options": None}, 23 | ) 24 | 25 | 26 | @pytest.fixture 27 | def tools(): 28 | @strands.tool 29 | def tool_time() -> str: 30 | return "12:00" 31 | 32 | @strands.tool 33 | def tool_weather() -> str: 34 | return "sunny" 35 | 36 | return [tool_time, tool_weather] 37 | 38 | 39 | @pytest.fixture 40 | def agent(model, tools): 41 | return Agent(model=model, tools=tools) 42 | 43 | 44 | def test_agent(agent): 45 | result = agent("What is the time and weather in New York?") 46 | text = result.message["content"][0]["text"].lower() 47 | assert all(string in text for string in ["12:00", "sunny"]) 48 | -------------------------------------------------------------------------------- /tests_integ/models/test_model_litellm.py: -------------------------------------------------------------------------------- 1 | import pydantic 2 | import pytest 3 | 4 | import strands 5 | from strands import Agent 6 | from strands.models.litellm import LiteLLMModel 7 | 8 | 9 | @pytest.fixture 10 | def model(): 11 | return LiteLLMModel(model_id="bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0") 12 | 13 | 14 | @pytest.fixture 15 | def tools(): 16 | @strands.tool 17 | def tool_time() -> str: 18 | return "12:00" 19 | 20 | @strands.tool 21 | def tool_weather() -> str: 22 | return "sunny" 23 | 24 | return [tool_time, tool_weather] 25 | 26 | 27 | @pytest.fixture 28 | def agent(model, tools): 29 | return Agent(model=model, tools=tools) 30 | 31 | 32 | @pytest.fixture 33 | def weather(): 34 | class Weather(pydantic.BaseModel): 35 | """Extracts the time and weather from the user's message with the exact strings.""" 36 | 37 | time: str 38 | weather: str 39 | 40 | return Weather(time="12:00", 
weather="sunny") 41 | 42 | 43 | @pytest.fixture 44 | def yellow_color(): 45 | class Color(pydantic.BaseModel): 46 | """Describes a color.""" 47 | 48 | name: str 49 | 50 | @pydantic.field_validator("name", mode="after") 51 | @classmethod 52 | def lower(_, value): 53 | return value.lower() 54 | 55 | return Color(name="yellow") 56 | 57 | 58 | def test_agent_invoke(agent): 59 | result = agent("What is the time and weather in New York?") 60 | text = result.message["content"][0]["text"].lower() 61 | 62 | assert all(string in text for string in ["12:00", "sunny"]) 63 | 64 | 65 | @pytest.mark.asyncio 66 | async def test_agent_invoke_async(agent): 67 | result = await agent.invoke_async("What is the time and weather in New York?") 68 | text = result.message["content"][0]["text"].lower() 69 | 70 | assert all(string in text for string in ["12:00", "sunny"]) 71 | 72 | 73 | @pytest.mark.asyncio 74 | async def test_agent_stream_async(agent): 75 | stream = agent.stream_async("What is the time and weather in New York?") 76 | async for event in stream: 77 | _ = event 78 | 79 | result = event["result"] 80 | text = result.message["content"][0]["text"].lower() 81 | 82 | assert all(string in text for string in ["12:00", "sunny"]) 83 | 84 | 85 | def test_structured_output(agent, weather): 86 | tru_weather = agent.structured_output(type(weather), "The time is 12:00 and the weather is sunny") 87 | exp_weather = weather 88 | assert tru_weather == exp_weather 89 | 90 | 91 | @pytest.mark.asyncio 92 | async def test_agent_structured_output_async(agent, weather): 93 | tru_weather = await agent.structured_output_async(type(weather), "The time is 12:00 and the weather is sunny") 94 | exp_weather = weather 95 | assert tru_weather == exp_weather 96 | 97 | 98 | def test_invoke_multi_modal_input(agent, yellow_img): 99 | content = [ 100 | {"text": "Is this image red, blue, or yellow?"}, 101 | { 102 | "image": { 103 | "format": "png", 104 | "source": { 105 | "bytes": yellow_img, 106 | }, 107 | }, 108 | 
}, 109 | ] 110 | result = agent(content) 111 | text = result.message["content"][0]["text"].lower() 112 | 113 | assert "yellow" in text 114 | 115 | 116 | def test_structured_output_multi_modal_input(agent, yellow_img, yellow_color): 117 | content = [ 118 | {"text": "what is in this image"}, 119 | { 120 | "image": { 121 | "format": "png", 122 | "source": { 123 | "bytes": yellow_img, 124 | }, 125 | }, 126 | }, 127 | ] 128 | tru_color = agent.structured_output(type(yellow_color), content) 129 | exp_color = yellow_color 130 | assert tru_color == exp_color 131 | -------------------------------------------------------------------------------- /tests_integ/models/test_model_llamaapi.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates 2 | import os 3 | 4 | import pytest 5 | 6 | import strands 7 | from strands import Agent 8 | from strands.models.llamaapi import LlamaAPIModel 9 | from tests_integ.models import providers 10 | 11 | # these tests only run if we have the llama api key 12 | pytestmark = providers.llama.mark 13 | 14 | 15 | @pytest.fixture 16 | def model(): 17 | return LlamaAPIModel( 18 | model_id="Llama-4-Maverick-17B-128E-Instruct-FP8", 19 | client_args={ 20 | "api_key": os.getenv("LLAMA_API_KEY"), 21 | }, 22 | ) 23 | 24 | 25 | @pytest.fixture 26 | def tools(): 27 | @strands.tool 28 | def tool_time() -> str: 29 | return "12:00" 30 | 31 | @strands.tool 32 | def tool_weather() -> str: 33 | return "sunny" 34 | 35 | return [tool_time, tool_weather] 36 | 37 | 38 | @pytest.fixture 39 | def agent(model, tools): 40 | return Agent(model=model, tools=tools) 41 | 42 | 43 | def test_agent(agent): 44 | result = agent("What is the time and weather in New York?") 45 | text = result.message["content"][0]["text"].lower() 46 | 47 | assert all(string in text for string in ["12:00", "sunny"]) 48 | -------------------------------------------------------------------------------- 
/tests_integ/models/test_model_mistral.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | from pydantic import BaseModel 5 | 6 | import strands 7 | from strands import Agent 8 | from strands.models.mistral import MistralModel 9 | from tests_integ.models import providers 10 | 11 | # these tests only run if we have the mistral api key 12 | pytestmark = providers.mistral.mark 13 | 14 | 15 | @pytest.fixture() 16 | def streaming_model(): 17 | return MistralModel( 18 | model_id="mistral-medium-latest", 19 | api_key=os.getenv("MISTRAL_API_KEY"), 20 | stream=True, 21 | temperature=0.7, 22 | max_tokens=1000, 23 | top_p=0.9, 24 | ) 25 | 26 | 27 | @pytest.fixture() 28 | def non_streaming_model(): 29 | return MistralModel( 30 | model_id="mistral-medium-latest", 31 | api_key=os.getenv("MISTRAL_API_KEY"), 32 | stream=False, 33 | temperature=0.7, 34 | max_tokens=1000, 35 | top_p=0.9, 36 | ) 37 | 38 | 39 | @pytest.fixture() 40 | def system_prompt(): 41 | return "You are an AI assistant that provides helpful and accurate information." 
42 | 43 | 44 | @pytest.fixture() 45 | def tools(): 46 | @strands.tool 47 | def tool_time() -> str: 48 | return "12:00" 49 | 50 | @strands.tool 51 | def tool_weather() -> str: 52 | return "sunny" 53 | 54 | return [tool_time, tool_weather] 55 | 56 | 57 | @pytest.fixture() 58 | def streaming_agent(streaming_model, tools): 59 | return Agent(model=streaming_model, tools=tools) 60 | 61 | 62 | @pytest.fixture() 63 | def non_streaming_agent(non_streaming_model, tools): 64 | return Agent(model=non_streaming_model, tools=tools) 65 | 66 | 67 | @pytest.fixture(params=["streaming_agent", "non_streaming_agent"]) 68 | def agent(request): 69 | return request.getfixturevalue(request.param) 70 | 71 | 72 | @pytest.fixture() 73 | def weather(): 74 | class Weather(BaseModel): 75 | """Extracts the time and weather from the user's message with the exact strings.""" 76 | 77 | time: str 78 | weather: str 79 | 80 | return Weather(time="12:00", weather="sunny") 81 | 82 | 83 | def test_agent_invoke(agent): 84 | result = agent("What is the time and weather in New York?") 85 | text = result.message["content"][0]["text"].lower() 86 | 87 | assert all(string in text for string in ["12:00", "sunny"]) 88 | 89 | 90 | @pytest.mark.asyncio 91 | async def test_agent_invoke_async(agent): 92 | result = await agent.invoke_async("What is the time and weather in New York?") 93 | text = result.message["content"][0]["text"].lower() 94 | 95 | assert all(string in text for string in ["12:00", "sunny"]) 96 | 97 | 98 | @pytest.mark.asyncio 99 | async def test_agent_stream_async(agent): 100 | stream = agent.stream_async("What is the time and weather in New York?") 101 | async for event in stream: 102 | _ = event 103 | 104 | result = event["result"] 105 | text = result.message["content"][0]["text"].lower() 106 | 107 | assert all(string in text for string in ["12:00", "sunny"]) 108 | 109 | 110 | def test_agent_structured_output(non_streaming_agent, weather): 111 | tru_weather = 
non_streaming_agent.structured_output(type(weather), "The time is 12:00 and the weather is sunny") 112 | exp_weather = weather 113 | assert tru_weather == exp_weather 114 | 115 | 116 | @pytest.mark.asyncio 117 | async def test_agent_structured_output_async(non_streaming_agent, weather): 118 | tru_weather = await non_streaming_agent.structured_output_async( 119 | type(weather), "The time is 12:00 and the weather is sunny" 120 | ) 121 | exp_weather = weather 122 | assert tru_weather == exp_weather 123 | -------------------------------------------------------------------------------- /tests_integ/models/test_model_ollama.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pydantic import BaseModel 3 | 4 | import strands 5 | from strands import Agent 6 | from strands.models.ollama import OllamaModel 7 | from tests_integ.models import providers 8 | 9 | # these tests only run if the ollama server is running 10 | pytestmark = providers.ollama.mark 11 | 12 | 13 | @pytest.fixture 14 | def model(): 15 | return OllamaModel(host="http://localhost:11434", model_id="llama3.3:70b") 16 | 17 | 18 | @pytest.fixture 19 | def tools(): 20 | @strands.tool 21 | def tool_time() -> str: 22 | return "12:00" 23 | 24 | @strands.tool 25 | def tool_weather() -> str: 26 | return "sunny" 27 | 28 | return [tool_time, tool_weather] 29 | 30 | 31 | @pytest.fixture 32 | def agent(model, tools): 33 | return Agent(model=model, tools=tools) 34 | 35 | 36 | @pytest.fixture 37 | def weather(): 38 | class Weather(BaseModel): 39 | """Extracts the time and weather from the user's message with the exact strings.""" 40 | 41 | time: str 42 | weather: str 43 | 44 | return Weather(time="12:00", weather="sunny") 45 | 46 | 47 | def test_agent_invoke(agent): 48 | result = agent("What is the time and weather in New York?") 49 | text = result.message["content"][0]["text"].lower() 50 | 51 | assert all(string in text for string in ["12:00", "sunny"]) 52 | 53 | 
54 | @pytest.mark.asyncio 55 | async def test_agent_invoke_async(agent): 56 | result = await agent.invoke_async("What is the time and weather in New York?") 57 | text = result.message["content"][0]["text"].lower() 58 | 59 | assert all(string in text for string in ["12:00", "sunny"]) 60 | 61 | 62 | @pytest.mark.asyncio 63 | async def test_agent_stream_async(agent): 64 | stream = agent.stream_async("What is the time and weather in New York?") 65 | async for event in stream: 66 | _ = event 67 | 68 | result = event["result"] 69 | text = result.message["content"][0]["text"].lower() 70 | 71 | assert all(string in text for string in ["12:00", "sunny"]) 72 | 73 | 74 | def test_agent_structured_output(agent, weather): 75 | tru_weather = agent.structured_output(type(weather), "The time is 12:00 and the weather is sunny") 76 | exp_weather = weather 77 | assert tru_weather == exp_weather 78 | 79 | 80 | @pytest.mark.asyncio 81 | async def test_agent_structured_output_async(agent, weather): 82 | tru_weather = await agent.structured_output_async(type(weather), "The time is 12:00 and the weather is sunny") 83 | exp_weather = weather 84 | assert tru_weather == exp_weather 85 | -------------------------------------------------------------------------------- /tests_integ/models/test_model_openai.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pydantic 4 | import pytest 5 | 6 | import strands 7 | from strands import Agent, tool 8 | from strands.models.openai import OpenAIModel 9 | from tests_integ.models import providers 10 | 11 | # these tests only run if we have the openai api key 12 | pytestmark = providers.openai.mark 13 | 14 | 15 | @pytest.fixture 16 | def model(): 17 | return OpenAIModel( 18 | model_id="gpt-4o", 19 | client_args={ 20 | "api_key": os.getenv("OPENAI_API_KEY"), 21 | }, 22 | ) 23 | 24 | 25 | @pytest.fixture 26 | def tools(): 27 | @strands.tool 28 | def tool_time() -> str: 29 | return "12:00" 30 | 31 | 
@strands.tool 32 | def tool_weather() -> str: 33 | return "sunny" 34 | 35 | return [tool_time, tool_weather] 36 | 37 | 38 | @pytest.fixture 39 | def agent(model, tools): 40 | return Agent(model=model, tools=tools) 41 | 42 | 43 | @pytest.fixture 44 | def weather(): 45 | class Weather(pydantic.BaseModel): 46 | """Extracts the time and weather from the user's message with the exact strings.""" 47 | 48 | time: str 49 | weather: str 50 | 51 | return Weather(time="12:00", weather="sunny") 52 | 53 | 54 | @pytest.fixture 55 | def yellow_color(): 56 | class Color(pydantic.BaseModel): 57 | """Describes a color.""" 58 | 59 | name: str 60 | 61 | @pydantic.field_validator("name", mode="after") 62 | @classmethod 63 | def lower(_, value): 64 | return value.lower() 65 | 66 | return Color(name="yellow") 67 | 68 | 69 | @pytest.fixture(scope="module") 70 | def test_image_path(request): 71 | return request.config.rootpath / "tests_integ" / "test_image.png" 72 | 73 | 74 | def test_agent_invoke(agent): 75 | result = agent("What is the time and weather in New York?") 76 | text = result.message["content"][0]["text"].lower() 77 | 78 | assert all(string in text for string in ["12:00", "sunny"]) 79 | 80 | 81 | @pytest.mark.asyncio 82 | async def test_agent_invoke_async(agent): 83 | result = await agent.invoke_async("What is the time and weather in New York?") 84 | text = result.message["content"][0]["text"].lower() 85 | 86 | assert all(string in text for string in ["12:00", "sunny"]) 87 | 88 | 89 | @pytest.mark.asyncio 90 | async def test_agent_stream_async(agent): 91 | stream = agent.stream_async("What is the time and weather in New York?") 92 | async for event in stream: 93 | _ = event 94 | 95 | result = event["result"] 96 | text = result.message["content"][0]["text"].lower() 97 | 98 | assert all(string in text for string in ["12:00", "sunny"]) 99 | 100 | 101 | def test_agent_structured_output(agent, weather): 102 | tru_weather = agent.structured_output(type(weather), "The time is 12:00 
and the weather is sunny") 103 | exp_weather = weather 104 | assert tru_weather == exp_weather 105 | 106 | 107 | @pytest.mark.asyncio 108 | async def test_agent_structured_output_async(agent, weather): 109 | tru_weather = await agent.structured_output_async(type(weather), "The time is 12:00 and the weather is sunny") 110 | exp_weather = weather 111 | assert tru_weather == exp_weather 112 | 113 | 114 | def test_invoke_multi_modal_input(agent, yellow_img): 115 | content = [ 116 | {"text": "what is in this image"}, 117 | { 118 | "image": { 119 | "format": "png", 120 | "source": { 121 | "bytes": yellow_img, 122 | }, 123 | }, 124 | }, 125 | ] 126 | result = agent(content) 127 | text = result.message["content"][0]["text"].lower() 128 | 129 | assert "yellow" in text 130 | 131 | 132 | def test_structured_output_multi_modal_input(agent, yellow_img, yellow_color): 133 | content = [ 134 | {"text": "Is this image red, blue, or yellow?"}, 135 | { 136 | "image": { 137 | "format": "png", 138 | "source": { 139 | "bytes": yellow_img, 140 | }, 141 | }, 142 | }, 143 | ] 144 | tru_color = agent.structured_output(type(yellow_color), content) 145 | exp_color = yellow_color 146 | assert tru_color == exp_color 147 | 148 | 149 | @pytest.mark.skip("https://github.com/strands-agents/sdk-python/issues/320") 150 | def test_tool_returning_images(model, yellow_img): 151 | @tool 152 | def tool_with_image_return(): 153 | return { 154 | "status": "success", 155 | "content": [ 156 | { 157 | "image": { 158 | "format": "png", 159 | "source": {"bytes": yellow_img}, 160 | } 161 | }, 162 | ], 163 | } 164 | 165 | agent = Agent(model, tools=[tool_with_image_return]) 166 | # NOTE - this currently fails with: "Invalid 'messages[3]'. Image URLs are only allowed for messages with role 167 | # 'user', but this message with role 'tool' contains an image URL." 
168 | # See https://github.com/strands-agents/sdk-python/issues/320 for additional details 169 | agent("Run the the tool and analyze the image") 170 | -------------------------------------------------------------------------------- /tests_integ/models/test_model_writer.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | from pydantic import BaseModel 5 | 6 | import strands 7 | from strands import Agent 8 | from strands.models.writer import WriterModel 9 | from tests_integ.models import providers 10 | 11 | # these tests only run if we have the writer api key 12 | pytestmark = providers.writer.mark 13 | 14 | 15 | @pytest.fixture 16 | def model(): 17 | return WriterModel( 18 | model_id="palmyra-x4", 19 | client_args={"api_key": os.getenv("WRITER_API_KEY", "")}, 20 | stream_options={"include_usage": True}, 21 | ) 22 | 23 | 24 | @pytest.fixture 25 | def system_prompt(): 26 | return "You are a smart assistant, that uses @ instead of all punctuation marks" 27 | 28 | 29 | @pytest.fixture 30 | def tools(): 31 | @strands.tool 32 | def tool_time() -> str: 33 | return "12:00" 34 | 35 | @strands.tool 36 | def tool_weather() -> str: 37 | return "sunny" 38 | 39 | return [tool_time, tool_weather] 40 | 41 | 42 | @pytest.fixture 43 | def agent(model, tools, system_prompt): 44 | return Agent(model=model, tools=tools, system_prompt=system_prompt, load_tools_from_directory=False) 45 | 46 | 47 | def test_agent(agent): 48 | result = agent("What is the time and weather in New York?") 49 | text = result.message["content"][0]["text"].lower() 50 | 51 | assert all(string in text for string in ["12:00", "sunny"]) 52 | 53 | 54 | @pytest.mark.asyncio 55 | async def test_agent_async(agent): 56 | result = await agent.invoke_async("What is the time and weather in New York?") 57 | text = result.message["content"][0]["text"].lower() 58 | 59 | assert all(string in text for string in ["12:00", "sunny"]) 60 | 61 | 62 | 
@pytest.mark.asyncio 63 | async def test_agent_stream_async(agent): 64 | stream = agent.stream_async("What is the time and weather in New York?") 65 | async for event in stream: 66 | _ = event 67 | 68 | result = event["result"] 69 | text = result.message["content"][0]["text"].lower() 70 | 71 | assert all(string in text for string in ["12:00", "sunny"]) 72 | 73 | 74 | def test_structured_output(agent): 75 | class Weather(BaseModel): 76 | time: str 77 | weather: str 78 | 79 | result = agent.structured_output(Weather, "The time is 12:00 and the weather is sunny") 80 | 81 | assert isinstance(result, Weather) 82 | assert result.time == "12:00" 83 | assert result.weather == "sunny" 84 | 85 | 86 | @pytest.mark.asyncio 87 | async def test_structured_output_async(agent): 88 | class Weather(BaseModel): 89 | time: str 90 | weather: str 91 | 92 | result = await agent.structured_output_async(Weather, "The time is 12:00 and the weather is sunny") 93 | 94 | assert isinstance(result, Weather) 95 | assert result.time == "12:00" 96 | assert result.weather == "sunny" 97 | -------------------------------------------------------------------------------- /tests_integ/test_agent_async.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import strands 4 | 5 | 6 | @pytest.fixture 7 | def agent(): 8 | return strands.Agent() 9 | 10 | 11 | @pytest.mark.asyncio 12 | async def test_stream_async(agent): 13 | stream = agent.stream_async("hello") 14 | 15 | exp_message = "" 16 | async for event in stream: 17 | if "event" in event and "contentBlockDelta" in event["event"]: 18 | exp_message += event["event"]["contentBlockDelta"]["delta"]["text"] 19 | 20 | tru_message = agent.messages[-1]["content"][0]["text"] 21 | 22 | assert tru_message == exp_message 23 | -------------------------------------------------------------------------------- /tests_integ/test_bedrock_cache_point.py: 
-------------------------------------------------------------------------------- 1 | from strands import Agent 2 | from strands.types.content import Messages 3 | 4 | 5 | def test_bedrock_cache_point(): 6 | messages: Messages = [ 7 | { 8 | "role": "user", 9 | "content": [ 10 | { 11 | "text": "Some really long text!" * 1000 # Minimum token count for cachePoint is 1024 tokens 12 | }, 13 | {"cachePoint": {"type": "default"}}, 14 | ], 15 | }, 16 | {"role": "assistant", "content": [{"text": "Blue!"}]}, 17 | ] 18 | 19 | cache_point_usage = 0 20 | 21 | def cache_point_callback_handler(**kwargs): 22 | nonlocal cache_point_usage 23 | if "event" in kwargs and kwargs["event"] and "metadata" in kwargs["event"] and kwargs["event"]["metadata"]: 24 | metadata = kwargs["event"]["metadata"] 25 | if "usage" in metadata and metadata["usage"]: 26 | if "cacheReadInputTokens" in metadata["usage"] or "cacheWriteInputTokens" in metadata["usage"]: 27 | cache_point_usage += 1 28 | 29 | agent = Agent(messages=messages, callback_handler=cache_point_callback_handler, load_tools_from_directory=False) 30 | agent("What is favorite color?") 31 | assert cache_point_usage > 0 32 | -------------------------------------------------------------------------------- /tests_integ/test_context_overflow.py: -------------------------------------------------------------------------------- 1 | from strands import Agent 2 | from strands.types.content import Messages 3 | 4 | 5 | def test_context_window_overflow(): 6 | messages: Messages = [ 7 | {"role": "user", "content": [{"text": "Too much text!" 
* 100000}]}, 8 | {"role": "assistant", "content": [{"text": "That was a lot of text!"}]}, 9 | ] 10 | 11 | agent = Agent(messages=messages, load_tools_from_directory=False) 12 | agent("Hi!") 13 | assert len(agent.messages) == 2 14 | -------------------------------------------------------------------------------- /tests_integ/test_function_tools.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Test script for function-based tools 4 | """ 5 | 6 | import logging 7 | from typing import Optional 8 | 9 | from strands import Agent, tool 10 | 11 | logging.getLogger("strands").setLevel(logging.DEBUG) 12 | logging.basicConfig(format="%(levelname)s | %(name)s | %(message)s", handlers=[logging.StreamHandler()]) 13 | 14 | 15 | @tool 16 | def word_counter(text: str) -> str: 17 | """ 18 | Count words in text. 19 | 20 | Args: 21 | text: Text to analyze 22 | """ 23 | count = len(text.split()) 24 | return f"Word count: {count}" 25 | 26 | 27 | @tool(name="count_chars", description="Count characters in text") 28 | def count_chars(text: str, include_spaces: Optional[bool] = True) -> str: 29 | """ 30 | Count characters in text. 
31 | 32 | Args: 33 | text: Text to analyze 34 | include_spaces: Whether to include spaces in the count 35 | """ 36 | if not include_spaces: 37 | text = text.replace(" ", "") 38 | return f"Character count: {len(text)}" 39 | 40 | 41 | # Initialize agent with function tools 42 | agent = Agent(tools=[word_counter, count_chars]) 43 | 44 | print("\n===== Testing Direct Tool Access =====") 45 | # Use the tools directly 46 | word_result = agent.tool.word_counter(text="Hello world, this is a test") 47 | print(f"\nWord counter result: {word_result}") 48 | 49 | char_result = agent.tool.count_chars(text="Hello world!", include_spaces=False) 50 | print(f"\nCharacter counter result: {char_result}") 51 | 52 | print("\n===== Testing Natural Language Access =====") 53 | # Use through natural language 54 | nl_result = agent("Count the words in this sentence: 'The quick brown fox jumps over the lazy dog'") 55 | print(f"\nNL Result: {nl_result}") 56 | -------------------------------------------------------------------------------- /tests_integ/test_hot_tool_reload_decorator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Integration test for hot tool reloading functionality with the @tool decorator. 3 | 4 | This test verifies that the Strands Agent can automatically detect and load 5 | new tools created with the @tool decorator when they are added to a tools directory. 6 | """ 7 | 8 | import logging 9 | import os 10 | import time 11 | from pathlib import Path 12 | 13 | from strands import Agent 14 | 15 | logging.getLogger("strands").setLevel(logging.DEBUG) 16 | logging.basicConfig(format="%(levelname)s | %(name)s | %(message)s", handlers=[logging.StreamHandler()]) 17 | 18 | 19 | def test_hot_reload_decorator(): 20 | """ 21 | Test that the Agent automatically loads tools created with @tool decorator 22 | when added to the current working directory's tools folder. 
23 | """ 24 | # Set up the tools directory in current working directory 25 | tools_dir = Path.cwd() / "tools" 26 | os.makedirs(tools_dir, exist_ok=True) 27 | 28 | # Tool path that will need cleanup 29 | test_tool_path = tools_dir / "uppercase.py" 30 | 31 | try: 32 | # Create an Agent instance without any tools 33 | agent = Agent(load_tools_from_directory=True) 34 | 35 | # Create a test tool using @tool decorator 36 | with open(test_tool_path, "w") as f: 37 | f.write(""" 38 | from strands import tool 39 | 40 | @tool 41 | def uppercase(text: str) -> str: 42 | \"\"\"Convert text to uppercase.\"\"\" 43 | return f"Input: {text}, Output: {text.upper()}" 44 | """) 45 | 46 | # Wait for tool detection 47 | time.sleep(3) 48 | 49 | # Verify the tool was automatically loaded 50 | assert "uppercase" in agent.tool_names, "Agent should have detected and loaded the uppercase tool" 51 | 52 | # Test calling the dynamically loaded tool 53 | result = agent.tool.uppercase(text="hello world") 54 | 55 | # Check that the result is successful 56 | assert result.get("status") == "success", "Tool call should be successful" 57 | 58 | # Check the content of the response 59 | content_list = result.get("content", []) 60 | assert len(content_list) > 0, "Tool response should have content" 61 | 62 | # Check that the expected message is in the content 63 | text_content = next((item.get("text") for item in content_list if "text" in item), "") 64 | assert "Input: hello world, Output: HELLO WORLD" in text_content 65 | 66 | finally: 67 | # Clean up - remove the test file 68 | if test_tool_path.exists(): 69 | os.remove(test_tool_path) 70 | 71 | 72 | def test_hot_reload_decorator_update(): 73 | """ 74 | Test that the Agent detects updates to tools created with @tool decorator. 
75 | """ 76 | # Set up the tools directory in current working directory 77 | tools_dir = Path.cwd() / "tools" 78 | os.makedirs(tools_dir, exist_ok=True) 79 | 80 | # Tool path that will need cleanup - make sure filename matches function name 81 | test_tool_path = tools_dir / "greeting.py" 82 | 83 | try: 84 | # Create an Agent instance 85 | agent = Agent(load_tools_from_directory=True) 86 | 87 | # Create the initial version of the tool 88 | with open(test_tool_path, "w") as f: 89 | f.write(""" 90 | from strands import tool 91 | 92 | @tool 93 | def greeting(name: str) -> str: 94 | \"\"\"Generate a simple greeting.\"\"\" 95 | return f"Hello, {name}!" 96 | """) 97 | 98 | # Wait for tool detection 99 | time.sleep(3) 100 | 101 | # Verify the tool was loaded 102 | assert "greeting" in agent.tool_names, "Agent should have detected and loaded the greeting tool" 103 | 104 | # Test calling the tool 105 | result1 = agent.tool.greeting(name="Strands") 106 | text_content1 = next((item.get("text") for item in result1.get("content", []) if "text" in item), "") 107 | assert "Hello, Strands!" in text_content1, "Tool should return simple greeting" 108 | 109 | # Update the tool with new functionality 110 | with open(test_tool_path, "w") as f: 111 | f.write(""" 112 | from strands import tool 113 | import datetime 114 | 115 | @tool 116 | def greeting(name: str, formal: bool = False) -> str: 117 | \"\"\"Generate a greeting with optional formality.\"\"\" 118 | current_hour = datetime.datetime.now().hour 119 | time_of_day = "morning" if current_hour < 12 else "afternoon" if current_hour < 18 else "evening" 120 | 121 | if formal: 122 | return f"Good {time_of_day}, {name}. It's a pleasure to meet you." 123 | else: 124 | return f"Hey {name}! How's your {time_of_day} going?" 
125 | """) 126 | 127 | # Wait for hot reload to detect the change 128 | time.sleep(3) 129 | 130 | # Test calling the updated tool 131 | result2 = agent.tool.greeting(name="Strands", formal=True) 132 | text_content2 = next((item.get("text") for item in result2.get("content", []) if "text" in item), "") 133 | assert "Good" in text_content2 and "Strands" in text_content2 and "pleasure to meet you" in text_content2 134 | 135 | # Test with informal parameter 136 | result3 = agent.tool.greeting(name="Strands", formal=False) 137 | text_content3 = next((item.get("text") for item in result3.get("content", []) if "text" in item), "") 138 | assert "Hey Strands!" in text_content3 and "going" in text_content3 139 | 140 | finally: 141 | # Clean up - remove the test file 142 | if test_tool_path.exists(): 143 | os.remove(test_tool_path) 144 | -------------------------------------------------------------------------------- /tests_integ/test_mcp_client.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import os 3 | import threading 4 | import time 5 | from typing import List, Literal 6 | 7 | import pytest 8 | from mcp import StdioServerParameters, stdio_client 9 | from mcp.client.sse import sse_client 10 | from mcp.client.streamable_http import streamablehttp_client 11 | from mcp.types import ImageContent as MCPImageContent 12 | 13 | from strands import Agent 14 | from strands.tools.mcp.mcp_client import MCPClient 15 | from strands.tools.mcp.mcp_types import MCPTransport 16 | from strands.types.content import Message 17 | from strands.types.tools import ToolUse 18 | 19 | 20 | def start_calculator_server(transport: Literal["sse", "streamable-http"], port=int): 21 | """ 22 | Initialize and start an MCP calculator server for integration testing. 23 | 24 | This function creates a FastMCP server instance that provides a simple 25 | calculator tool for performing addition operations. 
The server uses 26 | Server-Sent Events (SSE) transport for communication, making it accessible 27 | over HTTP. 28 | """ 29 | from mcp.server import FastMCP 30 | 31 | mcp = FastMCP("Calculator Server", port=port) 32 | 33 | @mcp.tool(description="Calculator tool which performs calculations") 34 | def calculator(x: int, y: int) -> int: 35 | return x + y 36 | 37 | @mcp.tool(description="Generates a custom image") 38 | def generate_custom_image() -> MCPImageContent: 39 | try: 40 | with open("tests_integ/yellow.png", "rb") as image_file: 41 | encoded_image = base64.b64encode(image_file.read()) 42 | return MCPImageContent(type="image", data=encoded_image, mimeType="image/png") 43 | except Exception as e: 44 | print("Error while generating custom image: {}".format(e)) 45 | 46 | mcp.run(transport=transport) 47 | 48 | 49 | def test_mcp_client(): 50 | """ 51 | Test should yield output similar to the following 52 | {'role': 'user', 'content': [{'text': 'add 1 and 2, then echo the result back to me'}]} 53 | {'role': 'assistant', 'content': [{'text': "I'll help you add 1 and 2 and then echo the result back to you.\n\nFirst, I'll calculate 1 + 2:"}, {'toolUse': {'toolUseId': 'tooluse_17ptaKUxQB20ySZxwgiI_w', 'name': 'calculator', 'input': {'x': 1, 'y': 2}}}]} 54 | {'role': 'user', 'content': [{'toolResult': {'status': 'success', 'toolUseId': 'tooluse_17ptaKUxQB20ySZxwgiI_w', 'content': [{'text': '3'}]}}]} 55 | {'role': 'assistant', 'content': [{'text': "\n\nNow I'll echo the result back to you:"}, {'toolUse': {'toolUseId': 'tooluse_GlOc5SN8TE6ti8jVZJMBOg', 'name': 'echo', 'input': {'to_echo': '3'}}}]} 56 | {'role': 'user', 'content': [{'toolResult': {'status': 'success', 'toolUseId': 'tooluse_GlOc5SN8TE6ti8jVZJMBOg', 'content': [{'text': '3'}]}}]} 57 | {'role': 'assistant', 'content': [{'text': '\n\nThe result of adding 1 and 2 is 3.'}]} 58 | """ # noqa: E501 59 | 60 | server_thread = threading.Thread( 61 | target=start_calculator_server, kwargs={"transport": "sse", "port": 
8000}, daemon=True 62 | ) 63 | server_thread.start() 64 | time.sleep(2) # wait for server to startup completely 65 | 66 | sse_mcp_client = MCPClient(lambda: sse_client("http://127.0.0.1:8000/sse")) 67 | stdio_mcp_client = MCPClient( 68 | lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/echo_server.py"])) 69 | ) 70 | with sse_mcp_client, stdio_mcp_client: 71 | agent = Agent(tools=sse_mcp_client.list_tools_sync() + stdio_mcp_client.list_tools_sync()) 72 | agent("add 1 and 2, then echo the result back to me") 73 | 74 | tool_use_content_blocks = _messages_to_content_blocks(agent.messages) 75 | assert any([block["name"] == "echo" for block in tool_use_content_blocks]) 76 | assert any([block["name"] == "calculator" for block in tool_use_content_blocks]) 77 | 78 | image_prompt = """ 79 | Generate a custom image, then tell me if the image is red, blue, yellow, pink, orange, or green. 80 | RESPOND ONLY WITH THE COLOR 81 | """ 82 | assert any( 83 | [ 84 | "yellow".casefold() in block["text"].casefold() 85 | for block in agent(image_prompt).message["content"] 86 | if "text" in block 87 | ] 88 | ) 89 | 90 | 91 | def test_can_reuse_mcp_client(): 92 | stdio_mcp_client = MCPClient( 93 | lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/echo_server.py"])) 94 | ) 95 | with stdio_mcp_client: 96 | stdio_mcp_client.list_tools_sync() 97 | pass 98 | with stdio_mcp_client: 99 | agent = Agent(tools=stdio_mcp_client.list_tools_sync()) 100 | agent("echo the following to me <to_echo>DOG</to_echo>") 101 | 102 | tool_use_content_blocks = _messages_to_content_blocks(agent.messages) 103 | assert any([block["name"] == "echo" for block in tool_use_content_blocks]) 104 | 105 | 106 | @pytest.mark.skipif( 107 | condition=os.environ.get("GITHUB_ACTIONS") == "true", 108 | reason="streamable transport is failing in GitHub actions, debugging if linux compatibility issue", 109 | ) 110 | def test_streamable_http_mcp_client(): 111 | server_thread 
= threading.Thread( 112 | target=start_calculator_server, kwargs={"transport": "streamable-http", "port": 8001}, daemon=True 113 | ) 114 | server_thread.start() 115 | time.sleep(2) # wait for server to startup completely 116 | 117 | def transport_callback() -> MCPTransport: 118 | return streamablehttp_client(url="http://127.0.0.1:8001/mcp") 119 | 120 | streamable_http_client = MCPClient(transport_callback) 121 | with streamable_http_client: 122 | agent = Agent(tools=streamable_http_client.list_tools_sync()) 123 | agent("add 1 and 2 using a calculator") 124 | 125 | tool_use_content_blocks = _messages_to_content_blocks(agent.messages) 126 | assert any([block["name"] == "calculator" for block in tool_use_content_blocks]) 127 | 128 | 129 | def _messages_to_content_blocks(messages: List[Message]) -> List[ToolUse]: 130 | return [block["toolUse"] for message in messages for block in message["content"] if "toolUse" in block] 131 | -------------------------------------------------------------------------------- /tests_integ/test_multiagent_swarm.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from strands import Agent, tool 4 | from strands.multiagent.swarm import Swarm 5 | from strands.types.content import ContentBlock 6 | 7 | 8 | @tool 9 | def web_search(query: str) -> str: 10 | """Search the web for information.""" 11 | # Mock implementation 12 | return f"Results for '{query}': 25% yearly growth assumption, reaching $1.81 trillion by 2030" 13 | 14 | 15 | @tool 16 | def calculate(expression: str) -> str: 17 | """Calculate the result of a mathematical expression.""" 18 | try: 19 | return f"The result of {expression} is {eval(expression)}" 20 | except Exception as e: 21 | return f"Error calculating {expression}: {str(e)}" 22 | 23 | 24 | @pytest.fixture 25 | def researcher_agent(): 26 | """Create an agent specialized in research.""" 27 | return Agent( 28 | name="researcher", 29 | system_prompt=( 30 | "You are a 
research specialist who excels at finding information. When you need to perform calculations or" 31 | " format documents, hand off to the appropriate specialist." 32 | ), 33 | tools=[web_search], 34 | ) 35 | 36 | 37 | @pytest.fixture 38 | def analyst_agent(): 39 | """Create an agent specialized in data analysis.""" 40 | return Agent( 41 | name="analyst", 42 | system_prompt=( 43 | "You are a data analyst who excels at calculations and numerical analysis. When you need" 44 | " research or document formatting, hand off to the appropriate specialist." 45 | ), 46 | tools=[calculate], 47 | ) 48 | 49 | 50 | @pytest.fixture 51 | def writer_agent(): 52 | """Create an agent specialized in writing and formatting.""" 53 | return Agent( 54 | name="writer", 55 | system_prompt=( 56 | "You are a professional writer who excels at formatting and presenting information. When you need research" 57 | " or calculations, hand off to the appropriate specialist." 58 | ), 59 | ) 60 | 61 | 62 | def test_swarm_execution_with_string(researcher_agent, analyst_agent, writer_agent): 63 | """Test swarm execution with string input.""" 64 | # Create the swarm 65 | swarm = Swarm([researcher_agent, analyst_agent, writer_agent]) 66 | 67 | # Define a task that requires collaboration 68 | task = ( 69 | "Research the current AI agent market trends, calculate the growth rate assuming 25% yearly growth, " 70 | "and create a basic report" 71 | ) 72 | 73 | # Execute the swarm 74 | result = swarm(task) 75 | 76 | # Verify results 77 | assert result.status.value == "completed" 78 | assert len(result.results) > 0 79 | assert result.execution_time > 0 80 | assert result.execution_count > 0 81 | 82 | # Verify agent history - at least one agent should have been used 83 | assert len(result.node_history) > 0 84 | 85 | 86 | @pytest.mark.asyncio 87 | async def test_swarm_execution_with_image(researcher_agent, analyst_agent, writer_agent, yellow_img): 88 | """Test swarm execution with image input.""" 89 | # Create the 
swarm 90 | swarm = Swarm([researcher_agent, analyst_agent, writer_agent]) 91 | 92 | # Create content blocks with text and image 93 | content_blocks: list[ContentBlock] = [ 94 | {"text": "Analyze this image and create a report about what you see:"}, 95 | {"image": {"format": "png", "source": {"bytes": yellow_img}}}, 96 | ] 97 | 98 | # Execute the swarm with multi-modal input 99 | result = await swarm.invoke_async(content_blocks) 100 | 101 | # Verify results 102 | assert result.status.value == "completed" 103 | assert len(result.results) > 0 104 | assert result.execution_time > 0 105 | assert result.execution_count > 0 106 | 107 | # Verify agent history - at least one agent should have been used 108 | assert len(result.node_history) > 0 109 | -------------------------------------------------------------------------------- /tests_integ/test_stream_agent.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test script for Strands' custom callback handler functionality. 3 | Demonstrates different patterns of callback handling and processing. 4 | """ 5 | 6 | import logging 7 | 8 | from strands import Agent 9 | 10 | logging.getLogger("strands").setLevel(logging.DEBUG) 11 | logging.basicConfig(format="%(levelname)s | %(name)s | %(message)s", handlers=[logging.StreamHandler()]) 12 | 13 | 14 | class ToolCountingCallbackHandler: 15 | def __init__(self): 16 | self.tool_count = 0 17 | self.message_count = 0 18 | 19 | def callback_handler(self, **kwargs) -> None: 20 | """ 21 | Custom callback handler that processes and displays different types of events. 
22 | 23 | Args: 24 | **kwargs: Callback event data including: 25 | - data: Regular output 26 | - complete: Completion status 27 | - message: Message processing 28 | - current_tool_use: Tool execution 29 | """ 30 | # Extract event data 31 | data = kwargs.get("data", "") 32 | complete = kwargs.get("complete", False) 33 | message = kwargs.get("message", {}) 34 | current_tool_use = kwargs.get("current_tool_use", {}) 35 | 36 | # Handle regular data output 37 | if data: 38 | print(f"🔄 Data: {data}") 39 | 40 | # Handle tool execution events 41 | if current_tool_use: 42 | self.tool_count += 1 43 | tool_name = current_tool_use.get("name", "") 44 | tool_input = current_tool_use.get("input", {}) 45 | print(f"🛠️ Tool Execution #{self.tool_count}\nTool: {tool_name}\nInput: {tool_input}") 46 | 47 | # Handle message processing 48 | if message: 49 | self.message_count += 1 50 | print(f"📝 Message #{self.message_count}") 51 | 52 | # Handle completion 53 | if complete: 54 | self.console.print("✨ Callback Complete", style="bold green") 55 | 56 | 57 | def test_basic_interaction(): 58 | """Test basic AGI interaction with custom callback handler.""" 59 | print("\nTesting Basic Interaction") 60 | 61 | # Initialize agent with custom handler 62 | agent = Agent( 63 | callback_handler=ToolCountingCallbackHandler().callback_handler, 64 | load_tools_from_directory=False, 65 | ) 66 | 67 | # Simple prompt to test callbacking 68 | agent("Tell me a short joke from your general knowledge") 69 | 70 | print("\nBasic Interaction Complete") 71 | -------------------------------------------------------------------------------- /tests_integ/yellow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/strands-agents/sdk-python/f5e24d402bb22dc1a54c94dd030b4be0f7e73261/tests_integ/yellow.png --------------------------------------------------------------------------------