The response has been limited to 50k tokens of the smallest files in the repo. You can remove this limitation by removing the max tokens filter.
├── .github
    ├── ISSUE_TEMPLATE
    │   ├── bug_report.yml
    │   ├── config.yml
    │   └── feature_request.yml
    ├── PULL_REQUEST_TEMPLATE.md
    ├── dependabot.yml
    └── workflows
    │   ├── integration-test.yml
    │   ├── pr-and-push.yml
    │   ├── pypi-publish-on-release.yml
    │   └── test-lint.yml
├── .gitignore
├── .pre-commit-config.yaml
├── CONTRIBUTING.md
├── LICENSE
├── NOTICE
├── README.md
├── STYLE_GUIDE.md
├── pyproject.toml
├── src
    └── strands
    │   ├── __init__.py
    │   ├── agent
    │       ├── __init__.py
    │       ├── agent.py
    │       ├── agent_result.py
    │       ├── conversation_manager
    │       │   ├── __init__.py
    │       │   ├── conversation_manager.py
    │       │   ├── null_conversation_manager.py
    │       │   ├── sliding_window_conversation_manager.py
    │       │   └── summarizing_conversation_manager.py
    │       └── state.py
    │   ├── event_loop
    │       ├── __init__.py
    │       ├── event_loop.py
    │       └── streaming.py
    │   ├── experimental
    │       ├── __init__.py
    │       └── hooks
    │       │   ├── __init__.py
    │       │   └── events.py
    │   ├── handlers
    │       ├── __init__.py
    │       └── callback_handler.py
    │   ├── hooks
    │       ├── __init__.py
    │       ├── events.py
    │       ├── registry.py
    │       └── rules.md
    │   ├── models
    │       ├── __init__.py
    │       ├── anthropic.py
    │       ├── bedrock.py
    │       ├── litellm.py
    │       ├── llamaapi.py
    │       ├── mistral.py
    │       ├── model.py
    │       ├── ollama.py
    │       ├── openai.py
    │       └── writer.py
    │   ├── multiagent
    │       ├── __init__.py
    │       ├── a2a
    │       │   ├── __init__.py
    │       │   ├── executor.py
    │       │   └── server.py
    │       ├── base.py
    │       ├── graph.py
    │       └── swarm.py
    │   ├── py.typed
    │   ├── session
    │       ├── __init__.py
    │       ├── file_session_manager.py
    │       ├── repository_session_manager.py
    │       ├── s3_session_manager.py
    │       ├── session_manager.py
    │       └── session_repository.py
    │   ├── telemetry
    │       ├── __init__.py
    │       ├── config.py
    │       ├── metrics.py
    │       ├── metrics_constants.py
    │       └── tracer.py
    │   ├── tools
    │       ├── __init__.py
    │       ├── decorator.py
    │       ├── executor.py
    │       ├── loader.py
    │       ├── mcp
    │       │   ├── __init__.py
    │       │   ├── mcp_agent_tool.py
    │       │   ├── mcp_client.py
    │       │   └── mcp_types.py
    │       ├── registry.py
    │       ├── structured_output.py
    │       ├── tools.py
    │       └── watcher.py
    │   └── types
    │       ├── __init__.py
    │       ├── collections.py
    │       ├── content.py
    │       ├── event_loop.py
    │       ├── exceptions.py
    │       ├── guardrails.py
    │       ├── media.py
    │       ├── session.py
    │       ├── streaming.py
    │       ├── tools.py
    │       └── traces.py
├── tests
    ├── __init__.py
    ├── conftest.py
    ├── fixtures
    │   ├── mock_hook_provider.py
    │   ├── mock_session_repository.py
    │   └── mocked_model_provider.py
    └── strands
    │   ├── __init__.py
    │   ├── agent
    │       ├── __init__.py
    │       ├── test_agent.py
    │       ├── test_agent_hooks.py
    │       ├── test_agent_result.py
    │       ├── test_agent_state.py
    │       ├── test_conversation_manager.py
    │       └── test_summarizing_conversation_manager.py
    │   ├── event_loop
    │       ├── __init__.py
    │       ├── test_event_loop.py
    │       └── test_streaming.py
    │   ├── experimental
    │       ├── __init__.py
    │       └── hooks
    │       │   ├── __init__.py
    │       │   ├── test_events.py
    │       │   └── test_hook_registry.py
    │   ├── handlers
    │       ├── __init__.py
    │       └── test_callback_handler.py
    │   ├── models
    │       ├── __init__.py
    │       ├── test_anthropic.py
    │       ├── test_bedrock.py
    │       ├── test_litellm.py
    │       ├── test_llamaapi.py
    │       ├── test_mistral.py
    │       ├── test_model.py
    │       ├── test_ollama.py
    │       ├── test_openai.py
    │       └── test_writer.py
    │   ├── multiagent
    │       ├── __init__.py
    │       ├── a2a
    │       │   ├── __init__.py
    │       │   ├── conftest.py
    │       │   ├── test_executor.py
    │       │   └── test_server.py
    │       ├── test_base.py
    │       ├── test_graph.py
    │       └── test_swarm.py
    │   ├── session
    │       ├── __init__.py
    │       ├── test_file_session_manager.py
    │       ├── test_repository_session_manager.py
    │       └── test_s3_session_manager.py
    │   ├── telemetry
    │       ├── test_config.py
    │       ├── test_metrics.py
    │       └── test_tracer.py
    │   ├── tools
    │       ├── __init__.py
    │       ├── mcp
    │       │   ├── __init__.py
    │       │   ├── test_mcp_agent_tool.py
    │       │   └── test_mcp_client.py
    │       ├── test_decorator.py
    │       ├── test_executor.py
    │       ├── test_loader.py
    │       ├── test_registry.py
    │       ├── test_structured_output.py
    │       ├── test_tools.py
    │       └── test_watcher.py
    │   └── types
    │       └── test_session.py
└── tests_integ
    ├── __init__.py
    ├── conftest.py
    ├── echo_server.py
    ├── models
        ├── __init__.py
        ├── conformance.py
        ├── providers.py
        ├── test_model_anthropic.py
        ├── test_model_bedrock.py
        ├── test_model_cohere.py
        ├── test_model_litellm.py
        ├── test_model_llamaapi.py
        ├── test_model_mistral.py
        ├── test_model_ollama.py
        ├── test_model_openai.py
        └── test_model_writer.py
    ├── test_agent_async.py
    ├── test_bedrock_cache_point.py
    ├── test_bedrock_guardrails.py
    ├── test_context_overflow.py
    ├── test_function_tools.py
    ├── test_hot_tool_reload_decorator.py
    ├── test_mcp_client.py
    ├── test_multiagent_graph.py
    ├── test_multiagent_swarm.py
    ├── test_session.py
    ├── test_stream_agent.py
    ├── test_summarizing_conversation_manager_integration.py
    └── yellow.png


/.github/ISSUE_TEMPLATE/bug_report.yml:
--------------------------------------------------------------------------------
 1 | name: Bug Report
 2 | description: Report a bug in the Strands Agents SDK
 3 | title: "[BUG] "
 4 | labels: ["bug", "triage"]
 5 | assignees: []
 6 | body:
 7 |   - type: markdown
 8 |     attributes:
 9 |       value: |
10 |         Thanks for taking the time to fill out this bug report for Strands SDK!
11 |   - type: checkboxes
12 |     id: "checks"
13 |     attributes:
14 |       label: "Checks"
15 |       options:
16 |         - label: "I have updated to the latest minor and patch version of Strands"
17 |           required: true
18 |         - label: "I have checked the documentation and this is not expected behavior"
19 |           required: true
20 |         - label: "I have searched [./issues](./issues?q=) and there are no duplicates of my issue"
21 |           required: true
22 |   - type: input
23 |     id: strands-version
24 |     attributes:
25 |       label: Strands Version
26 |       description: Which version of Strands are you using?
27 |       placeholder: e.g., 0.5.2
28 |     validations:
29 |       required: true
30 |   - type: input
31 |     id: python-version
32 |     attributes:
33 |       label: Python Version
34 |       description: Which version of Python are you using?
35 |       placeholder: e.g., 3.10.5
36 |     validations:
37 |       required: true
38 |   - type: input
39 |     id: os
40 |     attributes:
41 |       label: Operating System
42 |       description: Which operating system are you using?
43 |       placeholder: e.g., macOS 12.6
44 |     validations:
45 |       required: true
46 |   - type: dropdown
47 |     id: installation-method
48 |     attributes:
49 |       label: Installation Method
50 |       description: How did you install Strands?
51 |       options:
52 |         - pip
53 |         - git clone
54 |         - binary
55 |         - other
56 |     validations:
57 |       required: true
58 |   - type: textarea
59 |     id: steps-to-reproduce
60 |     attributes:
61 |       label: Steps to Reproduce
62 |       description: Detailed steps to reproduce the behavior
63 |       placeholder: |
64 |         1. Install Strands using...
65 |         2. Run the command...
66 |         3. See error...
67 |     validations:
68 |       required: true
69 |   - type: textarea
70 |     id: expected-behavior
71 |     attributes:
72 |       label: Expected Behavior
73 |       description: A clear description of what you expected to happen
74 |     validations:
75 |       required: true
76 |   - type: textarea
77 |     id: actual-behavior
78 |     attributes:
79 |       label: Actual Behavior
80 |       description: What actually happened
81 |     validations:
82 |       required: true
83 |   - type: textarea
84 |     id: additional-context
85 |     attributes:
86 |       label: Additional Context
87 |       description: Any other relevant information, logs, screenshots, etc.
88 |   - type: textarea
89 |     id: possible-solution
90 |     attributes:
91 |       label: Possible Solution
92 |       description: Optional - If you have suggestions on how to fix the bug
93 |   - type: input
94 |     id: related-issues
95 |     attributes:
96 |       label: Related Issues
97 |       description: Optional - Link to related issues if applicable
98 | 


--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/config.yml:
--------------------------------------------------------------------------------
1 | blank_issues_enabled: false
2 | contact_links:
3 |   - name: Strands Agents SDK Support
4 |     url: https://github.com/strands-agents/sdk-python/discussions
5 |     about: Please ask and answer questions here
6 |   - name: Strands Agents SDK Documentation
7 |     url: https://github.com/strands-agents/docs
8 |     about: Visit our documentation for help
9 | 


--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.yml:
--------------------------------------------------------------------------------
 1 | name: Feature Request
 2 | description: Suggest a new feature or enhancement for Strands Agents SDK
 3 | title: "[FEATURE] "
 4 | labels: ["enhancement", "triage"]
 5 | assignees: []
 6 | body:
 7 |   - type: markdown
 8 |     attributes:
 9 |       value: |
10 |         Thanks for suggesting a new feature for Strands Agents SDK!
11 |   - type: textarea
12 |     id: problem-statement
13 |     attributes:
14 |       label: Problem Statement
15 |       description: Describe the problem you're trying to solve. What is currently difficult or impossible to do?
16 |       placeholder: I would like Strands to...
17 |     validations:
18 |       required: true
19 |   - type: textarea
20 |     id: proposed-solution
21 |     attributes:
22 |       label: Proposed Solution
23 |       description: Optional - Describe your proposed solution in detail. How would this feature work?
24 |   - type: textarea
25 |     id: use-case
26 |     attributes:
27 |       label: Use Case
28 |       description: Provide specific use cases for the feature. How would people use it?
29 |       placeholder: This would help with...
30 |     validations:
31 |       required: true
32 |   - type: textarea
33 |     id: alternatives-solutions
34 |     attributes:
35 |       label: Alternative Solutions
36 |       description: Optional - Have you considered alternative approaches? What are their pros and cons?
37 |   - type: textarea
38 |     id: additional-context
39 |     attributes:
40 |       label: Additional Context
41 |       description: Include any other context, screenshots, code examples, or references that might help understand the feature request.
42 | 


--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
 1 | ## Description
 2 | <!-- Provide a detailed description of the changes in this PR -->
 3 | 
 4 | ## Related Issues
 5 | 
 6 | <!-- Link to related issues using #issue-number format -->
 7 | 
 8 | ## Documentation PR
 9 | 
10 | <!-- Link to the associated PR in the agent-docs repo -->
11 | 
12 | ## Type of Change
13 | 
14 | <!-- Choose one of the following types of changes, delete the rest -->
15 | 
16 | Bug fix
17 | New feature
18 | Breaking change
19 | Documentation update
20 | Other (please describe):
21 | 
22 | ## Testing
23 | 
24 | How have you tested the change?  Verify that the changes do not break functionality or introduce warnings in consuming repositories: agents-docs, agents-tools, agents-cli
25 | 
26 | - [ ] I ran `hatch run prepare`
27 | 
28 | ## Checklist
29 | - [ ] I have read the CONTRIBUTING document
30 | - [ ] I have added any necessary tests that prove my fix is effective or my feature works
31 | - [ ] I have updated the documentation accordingly
32 | - [ ] I have added an appropriate example to the documentation to outline the feature, or no new docs are needed
33 | - [ ] My changes generate no new warnings
34 | - [ ] Any dependent changes have been merged and published
35 | 
36 | ----
37 | 
38 | By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.
39 | 


--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
 1 | version: 2
 2 | updates:
 3 |   - package-ecosystem: "pip"
 4 |     directory: "/"
 5 |     schedule:
 6 |       interval: "daily"
 7 |     open-pull-requests-limit: 100
 8 |     commit-message:
 9 |       prefix: ci
10 |     groups:
11 |       dev-dependencies:
12 |         patterns:
13 |           - "pytest"
14 |   - package-ecosystem: "github-actions"
15 |     directory: "/"
16 |     schedule:
17 |       interval: "daily"
18 |     open-pull-requests-limit: 100
19 |     commit-message:
20 |       prefix: ci
21 | 


--------------------------------------------------------------------------------
/.github/workflows/integration-test.yml:
--------------------------------------------------------------------------------
 1 | name: Secure Integration test
 2 | 
 3 | on:
 4 |   pull_request_target:
 5 |     branches: main
 6 | 
 7 | jobs:
 8 |   authorization-check:
 9 |     permissions: read-all
10 |     runs-on: ubuntu-latest
11 |     outputs:
12 |       approval-env: ${{ steps.collab-check.outputs.result }}
13 |     steps:
14 |       - name: Collaborator Check
15 |         uses: actions/github-script@v7
16 |         id: collab-check
17 |         with:
18 |           result-encoding: string
19 |           script: |
20 |             try {
21 |               const permissionResponse = await github.rest.repos.getCollaboratorPermissionLevel({
22 |                 owner: context.repo.owner,
23 |                 repo: context.repo.repo,
24 |                 username: context.payload.pull_request.user.login,
25 |               });
26 |               const permission = permissionResponse.data.permission;
27 |               const hasWriteAccess = ['write', 'admin'].includes(permission);
28 |               if (!hasWriteAccess) {
29 |                 console.log(`User ${context.payload.pull_request.user.login} does not have write access to the repository (permission: ${permission})`);
30 |                 return "manual-approval"
31 |               } else {
32 |                 console.log(`Verifed ${context.payload.pull_request.user.login} has write access. Auto Approving PR Checks.`)
33 |                 return "auto-approve"
34 |               }
35 |             } catch (error) {
36 |               console.log(`${context.payload.pull_request.user.login} does not have write access. Requiring Manual Approval to run PR Checks.`)
37 |               return "manual-approval"
38 |             }
39 |   check-access-and-checkout:
40 |     runs-on: ubuntu-latest
41 |     needs: authorization-check
42 |     environment: ${{ needs.authorization-check.outputs.approval-env }}
43 |     permissions:
44 |       id-token: write
45 |       pull-requests: read
46 |       contents: read
47 |     steps:
48 |       - name: Configure Credentials
49 |         uses: aws-actions/configure-aws-credentials@v4
50 |         with:
51 |          role-to-assume: ${{ secrets.STRANDS_INTEG_TEST_ROLE }}
52 |          aws-region: us-east-1
53 |          mask-aws-account-id: true
54 |       - name: Checkout head commit
55 |         uses: actions/checkout@v4
56 |         with:
57 |           ref: ${{ github.event.pull_request.head.sha }} # Pull the commit from the forked repo
58 |           persist-credentials: false  # Don't persist credentials for subsequent actions
59 |       - name: Set up Python
60 |         uses: actions/setup-python@v5
61 |         with:
62 |           python-version: '3.10'
63 |       - name: Install dependencies
64 |         run: |
65 |           pip install --no-cache-dir hatch
66 |       - name: Run integration tests
67 |         env:
68 |           AWS_REGION: us-east-1
69 |           AWS_REGION_NAME: us-east-1 # Needed for LiteLLM
70 |         id: tests
71 |         run: |
72 |           hatch test tests_integ
73 | 


--------------------------------------------------------------------------------
/.github/workflows/pr-and-push.yml:
--------------------------------------------------------------------------------
 1 | name: Pull Request and Push Action
 2 | 
 3 | on:
 4 |   pull_request:  # Safer than pull_request_target for untrusted code
 5 |     branches: [ main ]
 6 |     types: [opened, synchronize, reopened, ready_for_review, review_requested, review_request_removed]
 7 |   push:
 8 |     branches: [ main ]  # Also run on direct pushes to main
 9 | concurrency:
10 |   group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
11 |   cancel-in-progress: true
12 | 
13 | jobs:
14 |   call-test-lint:
15 |     uses: ./.github/workflows/test-lint.yml
16 |     permissions:
17 |       contents: read
18 |     with:
19 |       ref: ${{ github.event.pull_request.head.sha }}
20 | 


--------------------------------------------------------------------------------
/.github/workflows/pypi-publish-on-release.yml:
--------------------------------------------------------------------------------
 1 | name: Publish Python Package
 2 | 
 3 | on:
 4 |   release:
 5 |     types:
 6 |       - published
 7 | 
 8 | jobs:
 9 |   call-test-lint:
10 |     uses: ./.github/workflows/test-lint.yml
11 |     permissions:
12 |       contents: read
13 |     with:
14 |       ref: ${{ github.event.release.target_commitish }}
15 | 
16 |   build:
17 |     name: Build distribution 📦
18 |     permissions:
19 |       contents: read
20 |     needs:
21 |       - call-test-lint
22 |     runs-on: ubuntu-latest
23 | 
24 |     steps:
25 |     - uses: actions/checkout@v4
26 |       with:
27 |         persist-credentials: false
28 | 
29 |     - name: Set up Python
30 |       uses: actions/setup-python@v5
31 |       with:
32 |         python-version: '3.10'
33 | 
34 |     - name: Install dependencies
35 |       run: |
36 |         python -m pip install --upgrade pip
37 |         pip install hatch twine
38 | 
39 |     - name: Validate version
40 |       run: |
41 |         version=$(hatch version)
42 |         if [[ $version =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
43 |             echo "Valid version format"
44 |             exit 0
45 |         else
46 |             echo "Invalid version format"
47 |             exit 1
48 |         fi
49 | 
50 |     - name: Build
51 |       run: |
52 |         hatch build
53 | 
54 |     - name: Store the distribution packages
55 |       uses: actions/upload-artifact@v4
56 |       with:
57 |         name: python-package-distributions
58 |         path: dist/
59 | 
60 |   deploy:
61 |     name: Upload release to PyPI
62 |     needs:
63 |     - build
64 |     runs-on: ubuntu-latest
65 | 
66 |     # environment is used by PyPI Trusted Publisher and is strongly encouraged
67 |     # https://docs.pypi.org/trusted-publishers/adding-a-publisher/
68 |     environment:
69 |       name: pypi
70 |       url: https://pypi.org/p/strands-agents
71 |     permissions:
72 |       # IMPORTANT: this permission is mandatory for Trusted Publishing
73 |       id-token: write
74 |     
75 |     steps:
76 |     - name: Download all the dists
77 |       uses: actions/download-artifact@v4
78 |       with:
79 |         name: python-package-distributions
80 |         path: dist/
81 |     - name: Publish distribution 📦 to PyPI
82 |       uses: pypa/gh-action-pypi-publish@release/v1
83 | 


--------------------------------------------------------------------------------
/.github/workflows/test-lint.yml:
--------------------------------------------------------------------------------
 1 | name: Test and Lint
 2 | 
 3 | on:
 4 |   workflow_call:
 5 |     inputs:
 6 |       ref:
 7 |         required: true
 8 |         type: string
 9 | 
10 | jobs:
11 |   unit-test:
12 |     name: Unit Tests - Python ${{ matrix.python-version }} - ${{ matrix.os-name }}
13 |     permissions:
14 |       contents: read
15 |     strategy:
16 |       matrix:
17 |        include:
18 |         # Linux
19 |         - os: ubuntu-latest
20 |           os-name: 'linux'
21 |           python-version: "3.10"
22 |         - os: ubuntu-latest
23 |           os-name: 'linux'
24 |           python-version: "3.11"
25 |         - os: ubuntu-latest
26 |           os-name: 'linux'
27 |           python-version: "3.12"
28 |         - os: ubuntu-latest
29 |           os-name: 'linux'
30 |           python-version: "3.13"
31 |         # Windows
32 |         - os: windows-latest
33 |           os-name: 'windows'
34 |           python-version: "3.10"
35 |         - os: windows-latest
36 |           os-name: 'windows'
37 |           python-version: "3.11"
38 |         - os: windows-latest
39 |           os-name: 'windows'
40 |           python-version: "3.12"
41 |         - os: windows-latest
42 |           os-name: 'windows'
43 |           python-version: "3.13"
44 |         # macOS - latest only; not enough runners for macOS
45 |         - os: macos-latest
46 |           os-name: 'macOS'
47 |           python-version: "3.13"
48 |       fail-fast: true
49 |     runs-on: ${{ matrix.os }}
50 |     env:
51 |       LOG_LEVEL: DEBUG
52 |     steps:
53 |       - name: Checkout code
54 |         uses: actions/checkout@v4
55 |         with:
56 |           ref: ${{ inputs.ref }}  # Explicitly define which commit to check out
57 |           persist-credentials: false  # Don't persist credentials for subsequent actions
58 |       - name: Set up Python
59 |         uses: actions/setup-python@v5
60 |         with:
61 |           python-version: ${{ matrix.python-version }}
62 |       - name: Install dependencies
63 |         run: |
64 |           pip install --no-cache-dir hatch
65 |       - name: Run Unit tests
66 |         id: tests
67 |         run: hatch test tests --cover
68 |         continue-on-error: false
69 |   lint:
70 |     name: Lint
71 |     runs-on: ubuntu-latest
72 |     permissions:
73 |       contents: read
74 |     steps:
75 |       - name: Checkout code
76 |         uses: actions/checkout@v4
77 |         with:
78 |           ref: ${{ inputs.ref }}
79 |           persist-credentials: false
80 | 
81 |       - name: Set up Python
82 |         uses: actions/setup-python@v5
83 |         with:
84 |           python-version: '3.10'
85 |           cache: 'pip'
86 | 
87 |       - name: Install dependencies
88 |         run: |
89 |           pip install --no-cache-dir hatch
90 | 
91 |       - name: Run lint
92 |         id: lint
93 |         run: hatch run test-lint
94 |         continue-on-error: false
95 | 


--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
 1 | build
 2 | __pycache__*
 3 | .coverage*
 4 | .env
 5 | .venv
 6 | .mypy_cache
 7 | .pytest_cache
 8 | .ruff_cache
 9 | *.bak
10 | .vscode
11 | dist
12 | repl_state


--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
 1 | repos:
 2 |   - repo: local
 3 |     hooks:
 4 |       - id: hatch-format
 5 |         name: Format code
 6 |         entry: hatch fmt --formatter
 7 |         language: system
 8 |         pass_filenames: false
 9 |         types: [python]
10 |         stages: [pre-commit]
11 |       - id: hatch-lint
12 |         name: Lint code
13 |         entry: hatch run test-lint
14 |         language: system
15 |         pass_filenames: false
16 |         types: [python]
17 |         stages: [pre-commit]
18 |       - id: hatch-test-lint
19 |         name: Type linting
20 |         entry: hatch run test-lint
21 |         language: system
22 |         pass_filenames: false
23 |         types: [ python ]
24 |         stages: [ pre-commit ]
25 |       - id: hatch-test
26 |         name: Unit tests
27 |         entry: hatch test
28 |         language: system
29 |         pass_filenames: false
30 |         types: [python]
31 |         stages: [pre-commit]
32 |       - id: commitizen-check
33 |         name: Check commit message
34 |         entry: hatch run cz check --commit-msg-file
35 |         language: system
36 |         stages: [commit-msg]


--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
  1 | # Contributing Guidelines
  2 | 
  3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional
  4 | documentation, we greatly value feedback and contributions from our community.
  5 | 
  6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary
  7 | information to effectively respond to your bug report or contribution.
  8 | 
  9 | 
 10 | ## Reporting Bugs/Feature Requests
 11 | 
 12 | We welcome you to use the [Bug Reports](../../issues/new?template=bug_report.yml) template to report bugs or the [Feature Requests](../../issues/new?template=feature_request.yml) template to suggest features.
 13 | 
 14 | For a list of known bugs and feature requests:
 15 | - Check [Bug Reports](../../issues?q=is%3Aissue%20state%3Aopen%20label%3Abug) for currently tracked issues
 16 | - See [Feature Requests](../../issues?q=is%3Aissue%20state%3Aopen%20label%3Aenhancement) for requested enhancements
 17 | 
 18 | When filing an issue, please check for already tracked items.
 19 | 
 20 | Please try to include as much information as you can. Details like these are incredibly useful:
 21 | 
 22 | * A reproducible test case or series of steps
 23 | * The version of our code being used (commit ID)
 24 | * Any modifications you've made relevant to the bug
 25 | * Anything unusual about your environment or deployment
 26 | 
 27 | 
 28 | ## Development Environment
 29 | 
 30 | This project uses [hatchling](https://hatch.pypa.io/latest/build/#hatchling) as the build backend and [hatch](https://hatch.pypa.io/latest/) for development workflow management.
 31 | 
 32 | ### Setting Up Your Development Environment
 33 | 
 34 | 1. Enter a virtual environment using `hatch` (recommended), then launch your IDE in the new shell.
 35 |    ```bash
 36 |    hatch shell dev
 37 |    ```
 38 | 
 39 |    Alternatively, install development dependencies in a manually created virtual environment:
 40 |    ```bash
 41 |    pip install -e ".[dev]" && pip install -e ".[litellm]"
 42 |    ```
 43 | 
 44 | 
 45 | 2. Set up pre-commit hooks:
 46 |    ```bash
 47 |    pre-commit install -t pre-commit -t commit-msg
 48 |    ```
 49 |    This will automatically run formatters and conventional commit checks on your code before each commit.
 50 | 
 51 | 3. Run code formatters manually:
 52 |    ```bash
 53 |    hatch fmt --formatter
 54 |    ```
 55 | 
 56 | 4. Run linters:
 57 |    ```bash
 58 |    hatch fmt --linter
 59 |    ```
 60 | 
 61 | 5. Run unit tests:
 62 |    ```bash
 63 |    hatch test
 64 |    ```
 65 | 
 66 | 6. Run integration tests:
 67 |    ```bash
 68 |    hatch run test-integ
 69 |    ```
 70 | 
 71 | ### Pre-commit Hooks
 72 | 
 73 | We use [pre-commit](https://pre-commit.com/) to automatically run quality checks before each commit. The hooks run `hatch fmt --formatter`, `hatch run test-lint`, `hatch test`, and `hatch run cz check` when you make a commit, ensuring code consistency.
 74 | 
 75 | The pre-commit hook is installed with:
 76 | 
 77 | ```bash
 78 | pre-commit install -t pre-commit -t commit-msg
 79 | ```
 80 | 
 81 | You can also run the hooks manually on all files:
 82 | 
 83 | ```bash
 84 | pre-commit run --all-files
 85 | ```
 86 | 
 87 | ### Code Formatting and Style Guidelines
 88 | 
 89 | We use the following tools to ensure code quality:
 90 | 1. **ruff** - For formatting and linting
 91 | 2. **mypy** - For static type checking
 92 | 
 93 | These tools are configured in the [pyproject.toml](./pyproject.toml) file. Please ensure your code passes all linting and type checks before submitting a pull request:
 94 | 
 95 | ```bash
 96 | # Run all checks
 97 | hatch fmt --formatter
 98 | hatch fmt --linter
 99 | ```
100 | 
101 | If you're using an IDE like VS Code or PyCharm, consider configuring it to use these tools automatically.
102 | 
103 | For additional details on styling, please see our dedicated [Style Guide](./STYLE_GUIDE.md).
104 | 
105 | 
106 | ## Contributing via Pull Requests
107 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that:
108 | 
109 | 1. You are working against the latest source on the *main* branch.
110 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already.
111 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted.
112 | 
113 | To send us a pull request, please:
114 | 
115 | 1. Create a branch.
116 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change.
117 | 3. Format your code using `hatch fmt --formatter`.
118 | 4. Run linting checks with `hatch fmt --linter`.
119 | 5. Ensure local tests pass with `hatch test` and `hatch run test-integ`.
120 | 6. Commit to your branch using clear commit messages following the [Conventional Commits](https://www.conventionalcommits.org) specification.
121 | 7. Send us a pull request, answering any default questions in the pull request interface.
122 | 8. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation.
123 | 
124 | 
125 | ## Finding contributions to work on
126 | Looking at the existing issues is a great way to find something to contribute to.
127 | 
128 | You can check:
129 | - Our known bugs list in [Bug Reports](../../issues?q=is%3Aissue%20state%3Aopen%20label%3Abug) for issues that need fixing
130 | - Feature requests in [Feature Requests](../../issues?q=is%3Aissue%20state%3Aopen%20label%3Aenhancement) for new functionality to implement
131 | 
132 | 
133 | ## Code of Conduct
134 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
135 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
136 | opensource-codeofconduct@amazon.com with any additional questions or comments.
137 | 
138 | 
139 | ## Security issue notifications
140 | If you discover a potential security issue in this project, we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public GitHub issue.
141 | 
142 | 
143 | ## Licensing
144 | 
145 | See the [LICENSE](./LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution.
146 | 


--------------------------------------------------------------------------------
/NOTICE:
--------------------------------------------------------------------------------
1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.


--------------------------------------------------------------------------------
/STYLE_GUIDE.md:
--------------------------------------------------------------------------------
 1 | # Style Guide
 2 | 
 3 | ## Overview
 4 | 
 5 | The Strands Agents style guide aims to establish consistent formatting, naming conventions, and structure across all code in the repository. We strive to make our code clean, readable, and maintainable.
 6 | 
 7 | Where possible, we will codify these style guidelines into our linting rules and pre-commit hooks to automate enforcement and reduce the manual review burden.
 8 | 
 9 | ## Log Formatting
10 | 
11 | The format for Strands Agents logs is as follows:
12 | 
13 | ```python
14 | logger.debug("field1=<%s>, field2=<%s>, ... | human readable message", field1, field2, ...)
15 | ```
16 | 
17 | ### Guidelines
18 | 
19 | 1. **Context**:
20 |    - Add context as `<FIELD>=<VALUE>` pairs at the beginning of the log
21 |      - Many log services (CloudWatch, Splunk, etc.) look for these patterns to extract fields for searching
22 |    - Use commas (`,`) to separate pairs
23 |    - Enclose values in `<>` for readability
24 |      - This is particularly helpful in displaying empty values (`field=` vs `field=<>`)
25 |    - Use `%s` for string interpolation as recommended by Python logging
26 |      - This is an optimization to skip string interpolation when the log level is not enabled
27 | 
28 | 1. **Messages**:
29 |    - Add human-readable messages at the end of the log
30 |    - Use lowercase for consistency
31 |    - Avoid punctuation (periods, exclamation points, etc.) to reduce clutter
32 |    - Keep messages concise and focused on a single statement
33 |    - If multiple statements are needed, separate them with the pipe character (`|`)
34 |      - Example: `"processing request | starting validation"`
35 | 
36 | ### Examples
37 | 
38 | #### Good
39 | 
40 | ```python
41 | logger.debug("user_id=<%s>, action=<%s> | user performed action", user_id, action)
42 | logger.info("request_id=<%s>, duration_ms=<%d> | request completed", request_id, duration)
43 | logger.warning("attempt=<%d>, max_attempts=<%d> | retry limit approaching", attempt, max_attempts)
44 | ```
45 | 
46 | #### Poor
47 | 
48 | ```python
49 | # Avoid: No structured fields, direct variable interpolation in message
50 | logger.debug(f"User {user_id} performed action {action}")
51 | 
52 | # Avoid: Inconsistent formatting, punctuation
53 | logger.info("Request completed in %d ms.", duration)
54 | 
55 | # Avoid: No separation between fields and message
56 | logger.warning("Retry limit approaching! attempt=%d max_attempts=%d", attempt, max_attempts)
57 | ```
58 | 
59 | By following these log formatting guidelines, we ensure that logs are both human-readable and machine-parseable, making debugging and monitoring more efficient.
60 | 


--------------------------------------------------------------------------------
/src/strands/__init__.py:
--------------------------------------------------------------------------------
1 | """A framework for building, deploying, and managing AI agents."""
2 | 
3 | from . import agent, models, telemetry, types
4 | from .agent.agent import Agent
5 | from .tools.decorator import tool
6 | 
7 | __all__ = ["Agent", "agent", "models", "tool", "types", "telemetry"]
8 | 


--------------------------------------------------------------------------------
/src/strands/agent/__init__.py:
--------------------------------------------------------------------------------
 1 | """This package provides the core Agent interface and supporting components for building AI agents with the SDK.
 2 | 
 3 | It includes:
 4 | 
 5 | - Agent: The main interface for interacting with AI models and tools
 6 | - ConversationManager: Classes for managing conversation history and context windows
 7 | """
 8 | 
 9 | from .agent import Agent
10 | from .agent_result import AgentResult
11 | from .conversation_manager import (
12 |     ConversationManager,
13 |     NullConversationManager,
14 |     SlidingWindowConversationManager,
15 |     SummarizingConversationManager,
16 | )
17 | 
18 | __all__ = [
19 |     "Agent",
20 |     "AgentResult",
21 |     "ConversationManager",
22 |     "NullConversationManager",
23 |     "SlidingWindowConversationManager",
24 |     "SummarizingConversationManager",
25 | ]
26 | 


--------------------------------------------------------------------------------
/src/strands/agent/agent_result.py:
--------------------------------------------------------------------------------
 1 | """Agent result handling for SDK.
 2 | 
 3 | This module defines the AgentResult class which encapsulates the complete response from an agent's processing cycle.
 4 | """
 5 | 
 6 | from dataclasses import dataclass
 7 | from typing import Any
 8 | 
 9 | from ..telemetry.metrics import EventLoopMetrics
10 | from ..types.content import Message
11 | from ..types.streaming import StopReason
12 | 
13 | 
14 | @dataclass
15 | class AgentResult:
16 |     """Represents the last result of invoking an agent with a prompt.
17 | 
18 |     Attributes:
19 |         stop_reason: The reason why the agent's processing stopped.
20 |         message: The last message generated by the agent.
21 |         metrics: Performance metrics collected during processing.
22 |         state: Additional state information from the event loop.
23 |     """
24 | 
25 |     stop_reason: StopReason
26 |     message: Message
27 |     metrics: EventLoopMetrics
28 |     state: Any
29 | 
30 |     def __str__(self) -> str:
31 |         """Get the agent's last message as a string.
32 | 
33 |         This method extracts and concatenates all text content from the final message, ignoring any non-text content
34 |         like images or structured data.
35 | 
36 |         Returns:
37 |             The agent's last message as a string.
38 |         """
39 |         content_array = self.message.get("content", [])
40 | 
41 |         result = ""
42 |         for item in content_array:
43 |             if isinstance(item, dict) and "text" in item:
44 |                 result += item.get("text", "") + "\n"
45 | 
46 |         return result
47 | 


--------------------------------------------------------------------------------
/src/strands/agent/conversation_manager/__init__.py:
--------------------------------------------------------------------------------
 1 | """This package provides classes for managing conversation history during agent execution.
 2 | 
 3 | It includes:
 4 | 
 5 | - ConversationManager: Abstract base class defining the conversation management interface
 6 | - NullConversationManager: A no-op implementation that does not modify conversation history
 7 | - SlidingWindowConversationManager: An implementation that maintains a sliding window of messages to control context
 8 |   size while preserving conversation coherence
 9 | - SummarizingConversationManager: An implementation that summarizes older context instead
10 |   of simply trimming it
11 | 
12 | Conversation managers help control memory usage and context length while maintaining relevant conversation state, which
13 | is critical for effective agent interactions.
14 | """
15 | 
16 | from .conversation_manager import ConversationManager
17 | from .null_conversation_manager import NullConversationManager
18 | from .sliding_window_conversation_manager import SlidingWindowConversationManager
19 | from .summarizing_conversation_manager import SummarizingConversationManager
20 | 
21 | __all__ = [
22 |     "ConversationManager",
23 |     "NullConversationManager",
24 |     "SlidingWindowConversationManager",
25 |     "SummarizingConversationManager",
26 | ]
27 | 


--------------------------------------------------------------------------------
/src/strands/agent/conversation_manager/conversation_manager.py:
--------------------------------------------------------------------------------
 1 | """Abstract interface for conversation history management."""
 2 | 
 3 | from abc import ABC, abstractmethod
 4 | from typing import TYPE_CHECKING, Any, Optional
 5 | 
 6 | from ...types.content import Message
 7 | 
 8 | if TYPE_CHECKING:
 9 |     from ...agent.agent import Agent
10 | 
11 | 
12 | class ConversationManager(ABC):
13 |     """Abstract base class for managing conversation history.
14 | 
15 |     This class provides an interface for implementing conversation management strategies to control the size of message
16 |     arrays/conversation histories, helping to:
17 | 
18 |     - Manage memory usage
19 |     - Control context length
20 |     - Maintain relevant conversation state
21 |     """
22 | 
23 |     def __init__(self) -> None:
24 |         """Initialize the ConversationManager.
25 | 
26 |         Attributes:
27 |           removed_message_count: The number of messages that have been removed from the agent's messages array.
28 |               These represent messages provided by the user or LLM that have been removed, not messages
29 |               included by the conversation manager through something like summarization.
30 |         """
31 |         self.removed_message_count = 0
32 | 
33 |     def restore_from_session(self, state: dict[str, Any]) -> Optional[list[Message]]:
34 |         """Restore the Conversation Manager's state from a session.
35 | 
36 |         Args:
37 |             state: Previous state of the conversation manager
38 |         Returns:
39 |             Optional list of messages to prepend to the agent's messages. By default returns None.
40 |         """
41 |         if state.get("__name__") != self.__class__.__name__:
42 |             raise ValueError("Invalid conversation manager state.")
43 |         self.removed_message_count = state["removed_message_count"]
44 |         return None
45 | 
46 |     def get_state(self) -> dict[str, Any]:
47 |         """Get the current state of a Conversation Manager as a JSON-serializable dictionary."""
48 |         return {
49 |             "__name__": self.__class__.__name__,
50 |             "removed_message_count": self.removed_message_count,
51 |         }
52 | 
53 |     @abstractmethod
54 |     def apply_management(self, agent: "Agent", **kwargs: Any) -> None:
55 |         """Applies management strategy to the provided agent.
56 | 
57 |         Processes the conversation history to maintain appropriate size by modifying the messages list in-place.
58 |         Implementations should handle message pruning, summarization, or other size management techniques to keep the
59 |         conversation context within desired bounds.
60 | 
61 |         Args:
62 |             agent: The agent whose conversation history will be managed.
63 |                 The agent's messages list is modified in-place.
64 |             **kwargs: Additional keyword arguments for future extensibility.
65 |         """
66 |         pass
67 | 
68 |     @abstractmethod
69 |     def reduce_context(self, agent: "Agent", e: Optional[Exception] = None, **kwargs: Any) -> None:
70 |         """Called when the model's context window is exceeded.
71 | 
72 |         This method should implement the specific strategy for reducing the window size when a context overflow occurs.
73 |         It is typically called after a ContextWindowOverflowException is caught.
74 | 
75 |         Implementations might use strategies such as:
76 | 
77 |         - Removing the N oldest messages
78 |         - Summarizing older context
79 |         - Applying importance-based filtering
80 |         - Maintaining critical conversation markers
81 | 
82 |         Args:
83 |             agent: The agent whose conversation history will be reduced.
84 |                 The agent's messages list is modified in-place.
85 |             e: The exception that triggered the context reduction, if any.
86 |             **kwargs: Additional keyword arguments for future extensibility.
87 |         """
88 |         pass
89 | 


--------------------------------------------------------------------------------
/src/strands/agent/conversation_manager/null_conversation_manager.py:
--------------------------------------------------------------------------------
 1 | """Null implementation of conversation management."""
 2 | 
 3 | from typing import TYPE_CHECKING, Any, Optional
 4 | 
 5 | if TYPE_CHECKING:
 6 |     from ...agent.agent import Agent
 7 | 
 8 | from ...types.exceptions import ContextWindowOverflowException
 9 | from .conversation_manager import ConversationManager
10 | 
11 | 
12 | class NullConversationManager(ConversationManager):
13 |     """A no-op conversation manager that does not modify the conversation history.
14 | 
15 |     Useful for:
16 | 
17 |     - Testing scenarios where conversation management should be disabled
18 |     - Cases where conversation history is managed externally
19 |     - Situations where the full conversation history should be preserved
20 |     """
21 | 
22 |     def apply_management(self, agent: "Agent", **kwargs: Any) -> None:
23 |         """Does nothing to the conversation history.
24 | 
25 |         Args:
26 |             agent: The agent whose conversation history will remain unmodified.
27 |             **kwargs: Additional keyword arguments for future extensibility.
28 |         """
29 |         pass
30 | 
31 |     def reduce_context(self, agent: "Agent", e: Optional[Exception] = None, **kwargs: Any) -> None:
32 |         """Does not reduce context and raises an exception.
33 | 
34 |         Args:
35 |             agent: The agent whose conversation history will remain unmodified.
36 |             e: The exception that triggered the context reduction, if any.
37 |             **kwargs: Additional keyword arguments for future extensibility.
38 | 
39 |         Raises:
40 |             e: If provided.
41 |             ContextWindowOverflowException: If e is None.
42 |         """
43 |         if e:
44 |             raise e
45 |         else:
46 |             raise ContextWindowOverflowException("Context window overflowed!")
47 | 


--------------------------------------------------------------------------------
/src/strands/agent/state.py:
--------------------------------------------------------------------------------
 1 | """Agent state management."""
 2 | 
 3 | import copy
 4 | import json
 5 | from typing import Any, Dict, Optional
 6 | 
 7 | 
 8 | class AgentState:
 9 |     """Represents an Agent's stateful information outside of context provided to a model.
10 | 
11 |     Provides a key-value store for agent state with JSON serialization validation and persistence support.
12 |     Key features:
13 |     - JSON serialization validation on assignment
14 |     - Get/set/delete operations
15 |     """
16 | 
17 |     def __init__(self, initial_state: Optional[Dict[str, Any]] = None):
18 |         """Initialize AgentState."""
19 |         self._state: Dict[str, Dict[str, Any]]
20 |         if initial_state:
21 |             self._validate_json_serializable(initial_state)
22 |             self._state = copy.deepcopy(initial_state)
23 |         else:
24 |             self._state = {}
25 | 
26 |     def set(self, key: str, value: Any) -> None:
27 |         """Set a value in the state.
28 | 
29 |         Args:
30 |             key: The key to store the value under
31 |             value: The value to store (must be JSON serializable)
32 | 
33 |         Raises:
34 |             ValueError: If key is invalid, or if value is not JSON serializable
35 |         """
36 |         self._validate_key(key)
37 |         self._validate_json_serializable(value)
38 | 
39 |         self._state[key] = copy.deepcopy(value)
40 | 
41 |     def get(self, key: Optional[str] = None) -> Any:
42 |         """Get a value or entire state.
43 | 
44 |         Args:
45 |             key: The key to retrieve (if None, returns entire state object)
46 | 
47 |         Returns:
48 |             The stored value, entire state dict, or None if not found
49 |         """
50 |         if key is None:
51 |             return copy.deepcopy(self._state)
52 |         else:
53 |             # Return specific key
54 |             return copy.deepcopy(self._state.get(key))
55 | 
56 |     def delete(self, key: str) -> None:
57 |         """Delete a specific key from the state.
58 | 
59 |         Args:
60 |             key: The key to delete
61 |         """
62 |         self._validate_key(key)
63 | 
64 |         self._state.pop(key, None)
65 | 
66 |     def _validate_key(self, key: str) -> None:
67 |         """Validate that a key is valid.
68 | 
69 |         Args:
70 |             key: The key to validate
71 | 
72 |         Raises:
73 |             ValueError: If key is invalid
74 |         """
75 |         if key is None:
76 |             raise ValueError("Key cannot be None")
77 |         if not isinstance(key, str):
78 |             raise ValueError("Key must be a string")
79 |         if not key.strip():
80 |             raise ValueError("Key cannot be empty")
81 | 
82 |     def _validate_json_serializable(self, value: Any) -> None:
83 |         """Validate that a value is JSON serializable.
84 | 
85 |         Args:
86 |             value: The value to validate
87 | 
88 |         Raises:
89 |             ValueError: If value is not JSON serializable
90 |         """
91 |         try:
92 |             json.dumps(value)
93 |         except (TypeError, ValueError) as e:
94 |             raise ValueError(
95 |                 f"Value is not JSON serializable: {type(value).__name__}. "
96 |                 f"Only JSON-compatible types (str, int, float, bool, list, dict, None) are allowed."
97 |             ) from e
98 | 


--------------------------------------------------------------------------------
/src/strands/event_loop/__init__.py:
--------------------------------------------------------------------------------
 1 | """This package provides the core event loop implementation for the agents SDK.
 2 | 
 3 | The event loop enables conversational AI agents to process messages, execute tools, and handle errors in a controlled,
 4 | iterative manner.
 5 | """
 6 | 
 7 | from . import event_loop
 8 | 
 9 | __all__ = ["event_loop"]
10 | 


--------------------------------------------------------------------------------
/src/strands/experimental/__init__.py:
--------------------------------------------------------------------------------
1 | """Experimental features.
2 | 
3 | This module implements experimental features that are subject to change in future revisions without notice.
4 | """
5 | 


--------------------------------------------------------------------------------
/src/strands/experimental/hooks/__init__.py:
--------------------------------------------------------------------------------
 1 | """Experimental hook functionality that has not yet reached stability."""
 2 | 
 3 | from .events import (
 4 |     AfterModelInvocationEvent,
 5 |     AfterToolInvocationEvent,
 6 |     BeforeModelInvocationEvent,
 7 |     BeforeToolInvocationEvent,
 8 | )
 9 | 
10 | __all__ = [
11 |     "BeforeToolInvocationEvent",
12 |     "AfterToolInvocationEvent",
13 |     "BeforeModelInvocationEvent",
14 |     "AfterModelInvocationEvent",
15 | ]
16 | 


--------------------------------------------------------------------------------
/src/strands/experimental/hooks/events.py:
--------------------------------------------------------------------------------
  1 | """Experimental hook events emitted as part of invoking Agents.
  2 | 
  3 | This module defines the events that are emitted as Agents run through the lifecycle of a request.
  4 | """
  5 | 
  6 | from dataclasses import dataclass
  7 | from typing import Any, Optional
  8 | 
  9 | from ...hooks import HookEvent
 10 | from ...types.content import Message
 11 | from ...types.streaming import StopReason
 12 | from ...types.tools import AgentTool, ToolResult, ToolUse
 13 | 
 14 | 
 15 | @dataclass
 16 | class BeforeToolInvocationEvent(HookEvent):
 17 |     """Event triggered before a tool is invoked.
 18 | 
 19 |     This event is fired just before the agent executes a tool, allowing hook
 20 |     providers to inspect, modify, or replace the tool that will be executed.
 21 |     The selected_tool can be modified by hook callbacks to change which tool
 22 |     gets executed.
 23 | 
 24 |     Attributes:
 25 |         selected_tool: The tool that will be invoked. Can be modified by hooks
 26 |             to change which tool gets executed. This may be None if tool lookup failed.
 27 |         tool_use: The tool parameters that will be passed to selected_tool.
 28 |         invocation_state: Keyword arguments that will be passed to the tool.
 29 |     """
 30 | 
 31 |     selected_tool: Optional[AgentTool]
 32 |     tool_use: ToolUse
 33 |     invocation_state: dict[str, Any]
 34 | 
 35 |     def _can_write(self, name: str) -> bool:
 36 |         return name in ["selected_tool", "tool_use"]
 37 | 
 38 | 
 39 | @dataclass
 40 | class AfterToolInvocationEvent(HookEvent):
 41 |     """Event triggered after a tool invocation completes.
 42 | 
 43 |     This event is fired after the agent has finished executing a tool,
 44 |     regardless of whether the execution was successful or resulted in an error.
 45 |     Hook providers can use this event for cleanup, logging, or post-processing.
 46 | 
 47 |     Note: This event uses reverse callback ordering, meaning callbacks registered
 48 |     later will be invoked first during cleanup.
 49 | 
 50 |     Attributes:
 51 |         selected_tool: The tool that was invoked. It may be None if tool lookup failed.
 52 |         tool_use: The tool parameters that were passed to the tool invoked.
 53 |         invocation_state: Keyword arguments that were passed to the tool.
 54 |         result: The result of the tool invocation.
 55 |         exception: Exception if the tool execution failed, None if successful.
 56 |     """
 57 | 
 58 |     selected_tool: Optional[AgentTool]
 59 |     tool_use: ToolUse
 60 |     invocation_state: dict[str, Any]
 61 |     result: ToolResult
 62 |     exception: Optional[Exception] = None
 63 | 
 64 |     def _can_write(self, name: str) -> bool:
 65 |         return name == "result"
 66 | 
 67 |     @property
 68 |     def should_reverse_callbacks(self) -> bool:
 69 |         """True to invoke callbacks in reverse order."""
 70 |         return True
 71 | 
 72 | 
 73 | @dataclass
 74 | class BeforeModelInvocationEvent(HookEvent):
 75 |     """Event triggered before the model is invoked.
 76 | 
 77 |     This event is fired just before the agent calls the model for inference,
 78 |     allowing hook providers to inspect or modify the messages and configuration
 79 |     that will be sent to the model.
 80 | 
 81 |     Note: This event is not fired for invocations to structured_output.
 82 |     """
 83 | 
 84 |     pass
 85 | 
 86 | 
 87 | @dataclass
 88 | class AfterModelInvocationEvent(HookEvent):
 89 |     """Event triggered after the model invocation completes.
 90 | 
 91 |     This event is fired after the agent has finished calling the model,
 92 |     regardless of whether the invocation was successful or resulted in an error.
 93 |     Hook providers can use this event for cleanup, logging, or post-processing.
 94 | 
 95 |     Note: This event uses reverse callback ordering, meaning callbacks registered
 96 |     later will be invoked first during cleanup.
 97 | 
 98 |     Note: This event is not fired for invocations to structured_output.
 99 | 
100 |     Attributes:
101 |         stop_response: The model response data if invocation was successful, None if failed.
102 |         exception: Exception if the model invocation failed, None if successful.
103 |     """
104 | 
105 |     @dataclass
106 |     class ModelStopResponse:
107 |         """Model response data from successful invocation.
108 | 
109 |         Attributes:
110 |             stop_reason: The reason the model stopped generating.
111 |             message: The generated message from the model.
112 |         """
113 | 
114 |         message: Message
115 |         stop_reason: StopReason
116 | 
117 |     stop_response: Optional[ModelStopResponse] = None
118 |     exception: Optional[Exception] = None
119 | 
120 |     @property
121 |     def should_reverse_callbacks(self) -> bool:
122 |         """True to invoke callbacks in reverse order."""
123 |         return True
124 | 


--------------------------------------------------------------------------------
/src/strands/handlers/__init__.py:
--------------------------------------------------------------------------------
 1 | """Various handlers for performing custom actions on agent state.
 2 | 
 3 | Examples include:
 4 | 
 5 | - Displaying events from the event stream
 6 | """
 7 | 
 8 | from .callback_handler import CompositeCallbackHandler, PrintingCallbackHandler, null_callback_handler
 9 | 
10 | __all__ = ["CompositeCallbackHandler", "null_callback_handler", "PrintingCallbackHandler"]
11 | 


--------------------------------------------------------------------------------
/src/strands/handlers/callback_handler.py:
--------------------------------------------------------------------------------
 1 | """This module provides handlers for formatting and displaying events from the agent."""
 2 | 
 3 | from collections.abc import Callable
 4 | from typing import Any
 5 | 
 6 | 
 7 | class PrintingCallbackHandler:
 8 |     """Handler for streaming text output and tool invocations to stdout."""
 9 | 
10 |     def __init__(self) -> None:
11 |         """Initialize handler."""
12 |         self.tool_count = 0
13 |         self.previous_tool_use = None
14 | 
15 |     def __call__(self, **kwargs: Any) -> None:
16 |         """Stream text output and tool invocations to stdout.
17 | 
18 |         Args:
19 |             **kwargs: Callback event data including:
20 |                 - reasoningText (Optional[str]): Reasoning text to print if provided.
21 |                 - data (str): Text content to stream.
22 |                 - complete (bool): Whether this is the final chunk of a response.
23 |                 - current_tool_use (dict): Information about the current tool being used.
24 |         """
25 |         reasoningText = kwargs.get("reasoningText", False)
26 |         data = kwargs.get("data", "")
27 |         complete = kwargs.get("complete", False)
28 |         current_tool_use = kwargs.get("current_tool_use", {})
29 | 
30 |         if reasoningText:
31 |             print(reasoningText, end="")
32 | 
33 |         if data:
34 |             print(data, end="" if not complete else "\n")
35 | 
36 |         if current_tool_use and current_tool_use.get("name"):
37 |             tool_name = current_tool_use.get("name", "Unknown tool")
38 |             if self.previous_tool_use != current_tool_use:
39 |                 self.previous_tool_use = current_tool_use
40 |                 self.tool_count += 1
41 |                 print(f"\nTool #{self.tool_count}: {tool_name}")
42 | 
43 |         if complete and data:
44 |             print("\n")
45 | 
46 | 
47 | class CompositeCallbackHandler:
48 |     """Class-based callback handler that combines multiple callback handlers.
49 | 
50 |     This handler allows multiple callback handlers to be invoked for the same events,
51 |     enabling different processing or output formats for the same stream data.
52 |     """
53 | 
54 |     def __init__(self, *handlers: Callable) -> None:
55 |         """Initialize handler."""
56 |         self.handlers = handlers
57 | 
58 |     def __call__(self, **kwargs: Any) -> None:
59 |         """Invoke all handlers in the chain."""
60 |         for handler in self.handlers:
61 |             handler(**kwargs)
62 | 
63 | 
64 | def null_callback_handler(**_kwargs: Any) -> None:
65 |     """Callback handler that discards all output.
66 | 
67 |     Args:
68 |         **_kwargs: Event data (ignored).
69 |     """
70 |     return None
71 | 


--------------------------------------------------------------------------------
/src/strands/hooks/__init__.py:
--------------------------------------------------------------------------------
 1 | """Typed hook system for extending agent functionality.
 2 | 
 3 | This module provides a composable mechanism for building objects that can hook
 4 | into specific events during the agent lifecycle. The hook system enables both
 5 | built-in SDK components and user code to react to or modify agent behavior
 6 | through strongly-typed event callbacks.
 7 | 
 8 | Example Usage:
 9 |     ```python
10 |     from strands import Agent
11 |     from strands.hooks import AfterInvocationEvent, BeforeInvocationEvent, HookProvider, HookRegistry
12 | 
13 |     class LoggingHooks(HookProvider):
14 |         def register_hooks(self, registry: HookRegistry) -> None:
15 |             registry.add_callback(BeforeInvocationEvent, self.log_start)
16 |             registry.add_callback(AfterInvocationEvent, self.log_end)
17 | 
18 |         def log_start(self, event: BeforeInvocationEvent) -> None:
19 |             print(f"Request started for {event.agent.name}")
20 | 
21 |         def log_end(self, event: AfterInvocationEvent) -> None:
22 |             print(f"Request completed for {event.agent.name}")
23 | 
24 |     # Use with agent
25 |     agent = Agent(hooks=[LoggingHooks()])
26 |     ```
27 | 
28 | This replaces the older callback_handler approach with a more composable,
29 | type-safe system that supports multiple subscribers per event type.
30 | """
31 | 
32 | from .events import (
33 |     AfterInvocationEvent,
34 |     AgentInitializedEvent,
35 |     BeforeInvocationEvent,
36 |     MessageAddedEvent,
37 | )
38 | from .registry import HookCallback, HookEvent, HookProvider, HookRegistry
39 | 
40 | __all__ = [
41 |     "AgentInitializedEvent",
42 |     "BeforeInvocationEvent",
43 |     "AfterInvocationEvent",
44 |     "MessageAddedEvent",
45 |     "HookEvent",
46 |     "HookProvider",
47 |     "HookCallback",
48 |     "HookRegistry",
49 | ]
50 | 


--------------------------------------------------------------------------------
/src/strands/hooks/events.py:
--------------------------------------------------------------------------------
 1 | """Hook events emitted as part of invoking Agents.
 2 | 
 3 | This module defines the events that are emitted as Agents run through the lifecycle of a request.
 4 | """
 5 | 
 6 | from dataclasses import dataclass
 7 | 
 8 | from ..types.content import Message
 9 | from .registry import HookEvent
10 | 
11 | 
12 | @dataclass
13 | class AgentInitializedEvent(HookEvent):
14 |     """Event triggered when an agent has finished initialization.
15 | 
16 |     This event is fired after the agent has been fully constructed and all
17 |     built-in components have been initialized. Hook providers can use this
18 |     event to perform setup tasks that require a fully initialized agent.
19 |     """
20 | 
21 |     pass
22 | 
23 | 
24 | @dataclass
25 | class BeforeInvocationEvent(HookEvent):
26 |     """Event triggered at the beginning of a new agent request.
27 | 
28 |     This event is fired before the agent begins processing a new user request,
29 |     before any model inference or tool execution occurs. Hook providers can
30 |     use this event to perform request-level setup, logging, or validation.
31 | 
32 |     This event is triggered at the beginning of the following api calls:
33 |       - Agent.__call__
34 |       - Agent.stream_async
35 |       - Agent.structured_output
36 |     """
37 | 
38 |     pass
39 | 
40 | 
41 | @dataclass
42 | class AfterInvocationEvent(HookEvent):
43 |     """Event triggered at the end of an agent request.
44 | 
45 |     This event is fired after the agent has completed processing a request,
46 |     regardless of whether it completed successfully or encountered an error.
47 |     Hook providers can use this event for cleanup, logging, or state persistence.
48 | 
49 |     Note: This event uses reverse callback ordering, meaning callbacks registered
50 |     later will be invoked first during cleanup.
51 | 
52 |     This event is triggered at the end of the following api calls:
53 |       - Agent.__call__
54 |       - Agent.stream_async
55 |       - Agent.structured_output
56 |     """
57 | 
58 |     @property
59 |     def should_reverse_callbacks(self) -> bool:
60 |         """True to invoke callbacks in reverse order."""
61 |         return True
62 | 
63 | 
64 | @dataclass
65 | class MessageAddedEvent(HookEvent):
66 |     """Event triggered when a message is added to the agent's conversation.
67 | 
68 |     This event is fired whenever the agent adds a new message to its internal
69 |     message history, including user messages, assistant responses, and tool
70 |     results. Hook providers can use this event for logging, monitoring, or
71 |     implementing custom message processing logic.
72 | 
73 |     Note: This event is only triggered for messages added by the framework
74 |     itself, not for messages manually added by tools or external code.
75 | 
76 |     Attributes:
77 |         message: The message that was added to the conversation history.
78 |     """
79 | 
80 |     message: Message
81 | 


--------------------------------------------------------------------------------
/src/strands/hooks/rules.md:
--------------------------------------------------------------------------------
 1 | # Hook System Rules
 2 | 
 3 | ## Terminology
 4 | 
 5 | - **Paired events**: Events that denote the beginning and end of an operation
 6 | - **Hook callback**: A function that receives a strongly-typed event argument and performs some action in response
 7 | 
 8 | ## Naming Conventions
 9 | 
10 | - All hook events have a suffix of `Event`
11 | - Paired events follow the naming convention of `Before{Item}Event` and `After{Item}Event`
12 | 
13 | ## Paired Events
14 | 
15 | - The final event in a pair returns `True` for `should_reverse_callbacks`
16 | - For every `Before` event there is a corresponding `After` event, even if an exception occurs
17 | 
18 | ## Writable Properties
19 | 
20 | For events with writable properties, the values are re-read after the hook callbacks have been invoked and are then used in subsequent processing. For example, `BeforeToolInvocationEvent.selected_tool` is writable: after the callbacks for `BeforeToolInvocationEvent` have run, the value of `selected_tool` is the tool used for the tool call.


--------------------------------------------------------------------------------
/src/strands/models/__init__.py:
--------------------------------------------------------------------------------
 1 | """SDK model providers.
 2 | 
 3 | This package includes an abstract base Model class along with concrete implementations for specific providers.
 4 | """
 5 | 
 6 | from . import bedrock, model
 7 | from .bedrock import BedrockModel
 8 | from .model import Model
 9 | 
10 | __all__ = ["bedrock", "model", "BedrockModel", "Model"]
11 | 


--------------------------------------------------------------------------------
/src/strands/models/model.py:
--------------------------------------------------------------------------------
 1 | """Abstract base class for Agent model providers."""
 2 | 
 3 | import abc
 4 | import logging
 5 | from typing import Any, AsyncGenerator, AsyncIterable, Optional, Type, TypeVar, Union
 6 | 
 7 | from pydantic import BaseModel
 8 | 
 9 | from ..types.content import Messages
10 | from ..types.streaming import StreamEvent
11 | from ..types.tools import ToolSpec
12 | 
13 | logger = logging.getLogger(__name__)
14 | 
15 | T = TypeVar("T", bound=BaseModel)
16 | 
17 | 
18 | class Model(abc.ABC):
19 |     """Abstract base class for Agent model providers.
20 | 
21 |     This class defines the interface for all model implementations in the Strands Agents SDK. It provides a
22 |     standardized way to configure and process requests for different AI model providers.
23 |     """
24 | 
25 |     @abc.abstractmethod
26 |     # pragma: no cover
27 |     def update_config(self, **model_config: Any) -> None:
28 |         """Update the model configuration with the provided arguments.
29 | 
30 |         Args:
31 |             **model_config: Configuration overrides.
32 |         """
33 |         pass
34 | 
35 |     @abc.abstractmethod
36 |     # pragma: no cover
37 |     def get_config(self) -> Any:
38 |         """Return the model configuration.
39 | 
40 |         Returns:
41 |             The model's configuration.
42 |         """
43 |         pass
44 | 
45 |     @abc.abstractmethod
46 |     # pragma: no cover
47 |     def structured_output(
48 |         self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any
49 |     ) -> AsyncGenerator[dict[str, Union[T, Any]], None]:
50 |         """Get structured output from the model.
51 | 
52 |         Args:
53 |             output_model: The output model to use for the agent.
54 |             prompt: The prompt messages to use for the agent.
55 |             system_prompt: System prompt to provide context to the model.
56 |             **kwargs: Additional keyword arguments for future extensibility.
57 | 
58 |         Yields:
59 |             Model events with the last being the structured output.
60 | 
61 |         Raises:
62 |             ValidationException: The response format from the model does not match the output_model
63 |         """
64 |         pass
65 | 
66 |     @abc.abstractmethod
67 |     # pragma: no cover
68 |     def stream(
69 |         self,
70 |         messages: Messages,
71 |         tool_specs: Optional[list[ToolSpec]] = None,
72 |         system_prompt: Optional[str] = None,
73 |         **kwargs: Any,
74 |     ) -> AsyncIterable[StreamEvent]:
75 |         """Stream conversation with the model.
76 | 
77 |         This method handles the full lifecycle of conversing with the model:
78 | 
79 |         1. Format the messages, tool specs, and configuration into a streaming request
80 |         2. Send the request to the model
81 |         3. Yield the formatted message chunks
82 | 
83 |         Args:
84 |             messages: List of message objects to be processed by the model.
85 |             tool_specs: List of tool specifications to make available to the model.
86 |             system_prompt: System prompt to provide context to the model.
87 |             **kwargs: Additional keyword arguments for future extensibility.
88 | 
89 |         Yields:
90 |             Formatted message chunks from the model.
91 | 
92 |         Raises:
93 |             ModelThrottledException: When the model service is throttling requests from the client.
94 |         """
95 |         pass
96 | 


--------------------------------------------------------------------------------
/src/strands/multiagent/__init__.py:
--------------------------------------------------------------------------------
 1 | """Multiagent capabilities for Strands Agents.
 2 | 
 3 | This module provides support for multiagent systems, including agent-to-agent (A2A)
 4 | communication protocols and coordination mechanisms.
 5 | 
 6 | Submodules:
 7 |     a2a: Implementation of the Agent-to-Agent (A2A) protocol, which enables
 8 |          standardized communication between agents.
 9 | """
10 | 
11 | from .base import MultiAgentBase, MultiAgentResult
12 | from .graph import GraphBuilder, GraphResult
13 | from .swarm import Swarm, SwarmResult
14 | 
15 | __all__ = [
16 |     "GraphBuilder",
17 |     "GraphResult",
18 |     "MultiAgentBase",
19 |     "MultiAgentResult",
20 |     "Swarm",
21 |     "SwarmResult",
22 | ]
23 | 


--------------------------------------------------------------------------------
/src/strands/multiagent/a2a/__init__.py:
--------------------------------------------------------------------------------
 1 | """Agent-to-Agent (A2A) communication protocol implementation for Strands Agents.
 2 | 
 3 | This module provides classes and utilities for enabling Strands Agents to communicate
 4 | with other agents using the Agent-to-Agent (A2A) protocol.
 5 | 
 6 | Docs: https://google-a2a.github.io/A2A/latest/
 7 | 
 8 | Classes:
 9 |     A2AServer: A wrapper that adapts a Strands Agent to be A2A-compatible.
10 | """
11 | 
12 | from .executor import StrandsA2AExecutor
13 | from .server import A2AServer
14 | 
15 | __all__ = ["A2AServer", "StrandsA2AExecutor"]
16 | 


--------------------------------------------------------------------------------
/src/strands/multiagent/base.py:
--------------------------------------------------------------------------------
 1 | """Multi-Agent Base Class.
 2 | 
 3 | Provides minimal foundation for multi-agent patterns (Swarm, Graph).
 4 | """
 5 | 
 6 | from abc import ABC, abstractmethod
 7 | from dataclasses import dataclass, field
 8 | from enum import Enum
 9 | from typing import Any, Union
10 | 
11 | from ..agent import AgentResult
12 | from ..types.content import ContentBlock
13 | from ..types.event_loop import Metrics, Usage
14 | 
15 | 
16 | class Status(Enum):
17 |     """Execution status for both graphs and nodes."""
18 | 
19 |     PENDING = "pending"
20 |     EXECUTING = "executing"
21 |     COMPLETED = "completed"
22 |     FAILED = "failed"
23 | 
24 | 
25 | @dataclass
26 | class NodeResult:
27 |     """Unified result from node execution - handles both Agent and nested MultiAgentBase results.
28 | 
29 |     The status field represents the semantic outcome of the node's work:
30 |     - COMPLETED: The node's task was successfully accomplished
31 |     - FAILED: The node's task failed or produced an error
32 |     """
33 | 
34 |     # Core result data - single AgentResult, nested MultiAgentResult, or Exception
35 |     result: Union[AgentResult, "MultiAgentResult", Exception]
36 | 
37 |     # Execution metadata
38 |     execution_time: int = 0
39 |     status: Status = Status.PENDING
40 | 
41 |     # Accumulated metrics from this node and all children
42 |     accumulated_usage: Usage = field(default_factory=lambda: Usage(inputTokens=0, outputTokens=0, totalTokens=0))
43 |     accumulated_metrics: Metrics = field(default_factory=lambda: Metrics(latencyMs=0))
44 |     execution_count: int = 0
45 | 
46 |     def get_agent_results(self) -> list[AgentResult]:
47 |         """Get all AgentResult objects from this node, flattened if nested."""
48 |         if isinstance(self.result, Exception):
49 |             return []  # No agent results for exceptions
50 |         elif isinstance(self.result, AgentResult):
51 |             return [self.result]
52 |         else:
53 |             # Flatten nested results from MultiAgentResult
54 |             flattened = []
55 |             for nested_node_result in self.result.results.values():
56 |                 flattened.extend(nested_node_result.get_agent_results())
57 |             return flattened
58 | 
59 | 
60 | @dataclass
61 | class MultiAgentResult:
62 |     """Result from multi-agent execution with accumulated metrics.
63 | 
64 |     The status field represents the outcome of the MultiAgentBase execution:
65 |     - COMPLETED: The execution was successfully accomplished
66 |     - FAILED: The execution failed or produced an error
67 |     """
68 | 
69 |     status: Status = Status.PENDING
70 |     results: dict[str, NodeResult] = field(default_factory=lambda: {})
71 |     accumulated_usage: Usage = field(default_factory=lambda: Usage(inputTokens=0, outputTokens=0, totalTokens=0))
72 |     accumulated_metrics: Metrics = field(default_factory=lambda: Metrics(latencyMs=0))
73 |     execution_count: int = 0
74 |     execution_time: int = 0
75 | 
76 | 
77 | class MultiAgentBase(ABC):
78 |     """Base class for multi-agent helpers.
79 | 
80 |     This class integrates with existing Strands Agent instances and provides
81 |     multi-agent orchestration capabilities.
82 |     """
83 | 
84 |     @abstractmethod
85 |     async def invoke_async(self, task: str | list[ContentBlock], **kwargs: Any) -> MultiAgentResult:
86 |         """Invoke asynchronously."""
87 |         raise NotImplementedError("invoke_async not implemented")
88 | 
89 |     @abstractmethod
90 |     def __call__(self, task: str | list[ContentBlock], **kwargs: Any) -> MultiAgentResult:
91 |         """Invoke synchronously."""
92 |         raise NotImplementedError("__call__ not implemented")
93 | 


--------------------------------------------------------------------------------
/src/strands/py.typed:
--------------------------------------------------------------------------------
1 | # Marker file that indicates this package supports typing
2 | 


--------------------------------------------------------------------------------
/src/strands/session/__init__.py:
--------------------------------------------------------------------------------
 1 | """Session module.
 2 | 
 3 | This module provides session management functionality.
 4 | """
 5 | 
 6 | from .file_session_manager import FileSessionManager
 7 | from .repository_session_manager import RepositorySessionManager
 8 | from .s3_session_manager import S3SessionManager
 9 | from .session_manager import SessionManager
10 | from .session_repository import SessionRepository
11 | 
12 | __all__ = [
13 |     "FileSessionManager",
14 |     "RepositorySessionManager",
15 |     "S3SessionManager",
16 |     "SessionManager",
17 |     "SessionRepository",
18 | ]
19 | 


--------------------------------------------------------------------------------
/src/strands/session/session_manager.py:
--------------------------------------------------------------------------------
 1 | """Session manager interface for agent session management."""
 2 | 
 3 | from abc import ABC, abstractmethod
 4 | from typing import TYPE_CHECKING, Any
 5 | 
 6 | from ..hooks.events import AfterInvocationEvent, AgentInitializedEvent, MessageAddedEvent
 7 | from ..hooks.registry import HookProvider, HookRegistry
 8 | from ..types.content import Message
 9 | 
10 | if TYPE_CHECKING:
11 |     from ..agent.agent import Agent
12 | 
13 | 
14 | class SessionManager(HookProvider, ABC):
15 |     """Abstract interface for managing sessions.
16 | 
17 |     A session manager is in charge of persisting the conversation and state of an agent across its interactions.
18 |     Changes made to the agent's conversation, state, or other attributes should be persisted immediately after
19 |     they are made. The methods declared in this class are called at important lifecycle events of an agent,
20 |     and implementations should persist the corresponding changes to the session.
21 |     """
22 | 
23 |     def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None:
24 |         """Register hooks for persisting the agent to the session."""
25 |         # After the normal Agent initialization behavior, call the session initialize function to restore the agent
26 |         registry.add_callback(AgentInitializedEvent, lambda event: self.initialize(event.agent))
27 | 
28 |         # For each message appended to the Agent's messages, store that message in the session
29 |         registry.add_callback(MessageAddedEvent, lambda event: self.append_message(event.message, event.agent))
30 | 
31 |         # Sync the agent into the session for each message in case the agent state was updated
32 |         registry.add_callback(MessageAddedEvent, lambda event: self.sync_agent(event.agent))
33 | 
34 |         # After an agent was invoked, sync it with the session to capture any conversation manager state updates
35 |         registry.add_callback(AfterInvocationEvent, lambda event: self.sync_agent(event.agent))
36 | 
37 |     @abstractmethod
38 |     def redact_latest_message(self, redact_message: Message, agent: "Agent", **kwargs: Any) -> None:
39 |         """Redact the message most recently appended to the agent in the session.
40 | 
41 |         Args:
42 |             redact_message: Replacement message that contains the redacted content
43 |             agent: Agent to apply the message redaction to
44 |             **kwargs: Additional keyword arguments for future extensibility.
45 |         """
46 | 
47 |     @abstractmethod
48 |     def append_message(self, message: Message, agent: "Agent", **kwargs: Any) -> None:
49 |         """Append a message to the agent's session.
50 | 
51 |         Args:
52 |             message: Message to add to the agent in the session
53 |             agent: Agent to append the message to
54 |             **kwargs: Additional keyword arguments for future extensibility.
55 |         """
56 | 
57 |     @abstractmethod
58 |     def sync_agent(self, agent: "Agent", **kwargs: Any) -> None:
59 |         """Serialize and sync the agent with the session storage.
60 | 
61 |         Args:
62 |             agent: Agent who should be synchronized with the session storage
63 |             **kwargs: Additional keyword arguments for future extensibility.
64 |         """
65 | 
66 |     @abstractmethod
67 |     def initialize(self, agent: "Agent", **kwargs: Any) -> None:
68 |         """Initialize an agent with a session.
69 | 
70 |         Args:
71 |             agent: Agent to initialize
72 |             **kwargs: Additional keyword arguments for future extensibility.
73 |         """
74 | 


--------------------------------------------------------------------------------
/src/strands/session/session_repository.py:
--------------------------------------------------------------------------------
 1 | """Session repository interface for agent session management."""
 2 | 
 3 | from abc import ABC, abstractmethod
 4 | from typing import Any, Optional
 5 | 
 6 | from ..types.session import Session, SessionAgent, SessionMessage
 7 | 
 8 | 
 9 | class SessionRepository(ABC):
10 |     """Abstract repository for creating, reading, and updating Sessions, SessionAgents, and SessionMessages."""
11 | 
12 |     @abstractmethod
13 |     def create_session(self, session: Session, **kwargs: Any) -> Session:
14 |         """Create a new Session."""
15 | 
16 |     @abstractmethod
17 |     def read_session(self, session_id: str, **kwargs: Any) -> Optional[Session]:
18 |         """Read a Session."""
19 | 
20 |     @abstractmethod
21 |     def create_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: Any) -> None:
22 |         """Create a new Agent in a Session."""
23 | 
24 |     @abstractmethod
25 |     def read_agent(self, session_id: str, agent_id: str, **kwargs: Any) -> Optional[SessionAgent]:
26 |         """Read an Agent."""
27 | 
28 |     @abstractmethod
29 |     def update_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: Any) -> None:
30 |         """Update an Agent."""
31 | 
32 |     @abstractmethod
33 |     def create_message(self, session_id: str, agent_id: str, session_message: SessionMessage, **kwargs: Any) -> None:
34 |         """Create a new Message for the Agent."""
35 | 
36 |     @abstractmethod
37 |     def read_message(self, session_id: str, agent_id: str, message_id: int, **kwargs: Any) -> Optional[SessionMessage]:
38 |         """Read a Message."""
39 | 
40 |     @abstractmethod
41 |     def update_message(self, session_id: str, agent_id: str, session_message: SessionMessage, **kwargs: Any) -> None:
42 |         """Update a Message.
43 | 
44 |         A message is usually only updated when some content is redacted due to a guardrail.
45 |         """
46 | 
47 |     @abstractmethod
48 |     def list_messages(
49 |         self, session_id: str, agent_id: str, limit: Optional[int] = None, offset: int = 0, **kwargs: Any
50 |     ) -> list[SessionMessage]:
51 |         """List Messages from an Agent with pagination."""
52 | 


--------------------------------------------------------------------------------
/src/strands/telemetry/__init__.py:
--------------------------------------------------------------------------------
 1 | """Telemetry module.
 2 | 
 3 | This module provides metrics and tracing functionality.
 4 | """
 5 | 
 6 | from .config import StrandsTelemetry
 7 | from .metrics import EventLoopMetrics, MetricsClient, Trace, metrics_to_string
 8 | from .tracer import Tracer, get_tracer
 9 | 
10 | __all__ = [
11 |     # Metrics
12 |     "EventLoopMetrics",
13 |     "Trace",
14 |     "metrics_to_string",
15 |     "MetricsClient",
16 |     # Tracer
17 |     "Tracer",
18 |     "get_tracer",
19 |     # Telemetry Setup
20 |     "StrandsTelemetry",
21 | ]
22 | 


--------------------------------------------------------------------------------
/src/strands/telemetry/metrics_constants.py:
--------------------------------------------------------------------------------
 1 | """Metrics that are emitted in Strands-Agents."""
 2 | 
 3 | STRANDS_EVENT_LOOP_CYCLE_COUNT = "strands.event_loop.cycle_count"
 4 | STRANDS_EVENT_LOOP_START_CYCLE = "strands.event_loop.start_cycle"
 5 | STRANDS_EVENT_LOOP_END_CYCLE = "strands.event_loop.end_cycle"
 6 | STRANDS_TOOL_CALL_COUNT = "strands.tool.call_count"
 7 | STRANDS_TOOL_SUCCESS_COUNT = "strands.tool.success_count"
 8 | STRANDS_TOOL_ERROR_COUNT = "strands.tool.error_count"
 9 | 
10 | # Histograms
11 | STRANDS_EVENT_LOOP_LATENCY = "strands.event_loop.latency"
12 | STRANDS_TOOL_DURATION = "strands.tool.duration"
13 | STRANDS_EVENT_LOOP_CYCLE_DURATION = "strands.event_loop.cycle_duration"
14 | STRANDS_EVENT_LOOP_INPUT_TOKENS = "strands.event_loop.input.tokens"
15 | STRANDS_EVENT_LOOP_OUTPUT_TOKENS = "strands.event_loop.output.tokens"
16 | 


--------------------------------------------------------------------------------
/src/strands/tools/__init__.py:
--------------------------------------------------------------------------------
 1 | """Agent tool interfaces and utilities.
 2 | 
 3 | This module provides the core functionality for creating, managing, and executing tools through agents.
 4 | """
 5 | 
 6 | from .decorator import tool
 7 | from .structured_output import convert_pydantic_to_tool_spec
 8 | from .tools import InvalidToolUseNameException, PythonAgentTool, normalize_schema, normalize_tool_spec
 9 | 
10 | __all__ = [
11 |     "tool",
12 |     "PythonAgentTool",
13 |     "InvalidToolUseNameException",
14 |     "normalize_schema",
15 |     "normalize_tool_spec",
16 |     "convert_pydantic_to_tool_spec",
17 | ]
18 | 


--------------------------------------------------------------------------------
/src/strands/tools/executor.py:
--------------------------------------------------------------------------------
  1 | """Tool execution functionality for the event loop."""
  2 | 
  3 | import asyncio
  4 | import logging
  5 | import time
  6 | from typing import Any, Optional, cast
  7 | 
  8 | from opentelemetry import trace
  9 | 
 10 | from ..telemetry.metrics import EventLoopMetrics, Trace
 11 | from ..telemetry.tracer import get_tracer
 12 | from ..tools.tools import InvalidToolUseNameException, validate_tool_use
 13 | from ..types.content import Message
 14 | from ..types.tools import RunToolHandler, ToolGenerator, ToolResult, ToolUse
 15 | 
 16 | logger = logging.getLogger(__name__)
 17 | 
 18 | 
 19 | async def run_tools(
 20 |     handler: RunToolHandler,
 21 |     tool_uses: list[ToolUse],
 22 |     event_loop_metrics: EventLoopMetrics,
 23 |     invalid_tool_use_ids: list[str],
 24 |     tool_results: list[ToolResult],
 25 |     cycle_trace: Trace,
 26 |     parent_span: Optional[trace.Span] = None,
 27 | ) -> ToolGenerator:
 28 |     """Execute tools concurrently.
 29 | 
 30 |     Args:
 31 |         handler: Tool handler processing function.
 32 |         tool_uses: List of tool uses to execute.
 33 |         event_loop_metrics: Metrics collection object.
 34 |         invalid_tool_use_ids: List of invalid tool use IDs.
 35 |         tool_results: List to populate with tool results.
 36 |         cycle_trace: Parent trace for the current cycle.
 37 |         parent_span: Parent span for the current cycle.
 38 | 
 39 |     Yields:
 40 |         Events of the tool stream. Tool results are appended to `tool_results`.
 41 |     """
 42 | 
 43 |     async def work(
 44 |         tool_use: ToolUse,
 45 |         worker_id: int,
 46 |         worker_queue: asyncio.Queue,
 47 |         worker_event: asyncio.Event,
 48 |         stop_event: object,
 49 |     ) -> ToolResult:
 50 |         tracer = get_tracer()
 51 |         tool_call_span = tracer.start_tool_call_span(tool_use, parent_span)
 52 | 
 53 |         tool_name = tool_use["name"]
 54 |         tool_trace = Trace(f"Tool: {tool_name}", parent_id=cycle_trace.id, raw_name=tool_name)
 55 |         tool_start_time = time.time()
 56 | 
 57 |         try:
 58 |             async for event in handler(tool_use):
 59 |                 worker_queue.put_nowait((worker_id, event))
 60 |                 await worker_event.wait()
 61 |                 worker_event.clear()
 62 | 
 63 |             result = cast(ToolResult, event)
 64 |         finally:
 65 |             worker_queue.put_nowait((worker_id, stop_event))
 66 | 
 67 |         tool_success = result.get("status") == "success"
 68 |         tool_duration = time.time() - tool_start_time
 69 |         message = Message(role="user", content=[{"toolResult": result}])
 70 |         event_loop_metrics.add_tool_usage(tool_use, tool_duration, tool_trace, tool_success, message)
 71 |         cycle_trace.add_child(tool_trace)
 72 | 
 73 |         if tool_call_span:
 74 |             tracer.end_tool_call_span(tool_call_span, result)
 75 | 
 76 |         return result
 77 | 
 78 |     tool_uses = [tool_use for tool_use in tool_uses if tool_use.get("toolUseId") not in invalid_tool_use_ids]
 79 |     worker_queue: asyncio.Queue[tuple[int, Any]] = asyncio.Queue()
 80 |     worker_events = [asyncio.Event() for _ in tool_uses]
 81 |     stop_event = object()
 82 | 
 83 |     workers = [
 84 |         asyncio.create_task(work(tool_use, worker_id, worker_queue, worker_events[worker_id], stop_event))
 85 |         for worker_id, tool_use in enumerate(tool_uses)
 86 |     ]
 87 | 
 88 |     worker_count = len(workers)
 89 |     while worker_count:
 90 |         worker_id, event = await worker_queue.get()
 91 |         if event is stop_event:
 92 |             worker_count -= 1
 93 |             continue
 94 | 
 95 |         yield event
 96 |         worker_events[worker_id].set()
 97 | 
 98 |     tool_results.extend([worker.result() for worker in workers])
 99 | 
100 | 
101 | def validate_and_prepare_tools(
102 |     message: Message,
103 |     tool_uses: list[ToolUse],
104 |     tool_results: list[ToolResult],
105 |     invalid_tool_use_ids: list[str],
106 | ) -> None:
107 |     """Validate tool uses and prepare them for execution.
108 | 
109 |     Args:
110 |         message: Current message.
111 |         tool_uses: List to populate with tool uses.
112 |         tool_results: List to populate with tool results for invalid tools.
113 |         invalid_tool_use_ids: List to populate with invalid tool use IDs.
114 |     """
115 |     # Extract tool uses from message
116 |     for content in message["content"]:
117 |         if isinstance(content, dict) and "toolUse" in content:
118 |             tool_uses.append(content["toolUse"])
119 | 
120 |     # Validate tool uses
121 |     # Avoid modifying original `tool_uses` variable during iteration
122 |     tool_uses_copy = tool_uses.copy()
123 |     for tool in tool_uses_copy:
124 |         try:
125 |             validate_tool_use(tool)
126 |         except InvalidToolUseNameException as e:
127 |             # Replace the invalid toolUse name and return invalid name error as ToolResult to the LLM as context
128 |             tool_uses.remove(tool)
129 |             tool["name"] = "INVALID_TOOL_NAME"
130 |             invalid_tool_use_ids.append(tool["toolUseId"])
131 |             tool_uses.append(tool)
132 |             tool_results.append(
133 |                 {
134 |                     "toolUseId": tool["toolUseId"],
135 |                     "status": "error",
136 |                     "content": [{"text": f"Error: {str(e)}"}],
137 |                 }
138 |             )
139 | 


--------------------------------------------------------------------------------
/src/strands/tools/mcp/__init__.py:
--------------------------------------------------------------------------------
 1 | """Model Context Protocol (MCP) integration.
 2 | 
 3 | This package provides integration with the Model Context Protocol (MCP), allowing agents to use tools provided by MCP
 4 | servers.
 5 | 
 6 | - Docs: https://www.anthropic.com/news/model-context-protocol
 7 | """
 8 | 
 9 | from .mcp_agent_tool import MCPAgentTool
10 | from .mcp_client import MCPClient
11 | from .mcp_types import MCPTransport
12 | 
13 | __all__ = ["MCPAgentTool", "MCPClient", "MCPTransport"]
14 | 


--------------------------------------------------------------------------------
/src/strands/tools/mcp/mcp_agent_tool.py:
--------------------------------------------------------------------------------
  1 | """MCP Agent Tool module for adapting Model Context Protocol tools to the agent framework.
  2 | 
  3 | This module provides the MCPAgentTool class which serves as an adapter between
  4 | MCP (Model Context Protocol) tools and the agent framework's tool interface.
  5 | It allows MCP tools to be seamlessly integrated and used within the agent ecosystem.
  6 | """
  7 | 
  8 | import logging
  9 | from typing import TYPE_CHECKING, Any
 10 | 
 11 | from mcp.types import Tool as MCPTool
 12 | from typing_extensions import override
 13 | 
 14 | from ...types.tools import AgentTool, ToolGenerator, ToolSpec, ToolUse
 15 | 
 16 | if TYPE_CHECKING:
 17 |     from .mcp_client import MCPClient
 18 | 
 19 | logger = logging.getLogger(__name__)
 20 | 
 21 | 
 22 | class MCPAgentTool(AgentTool):
 23 |     """Adapter class that wraps an MCP tool and exposes it as an AgentTool.
 24 | 
 25 |     This class bridges the gap between the MCP protocol's tool representation
 26 |     and the agent framework's tool interface, allowing MCP tools to be used
 27 |     seamlessly within the agent framework.
 28 |     """
 29 | 
 30 |     def __init__(self, mcp_tool: MCPTool, mcp_client: "MCPClient") -> None:
 31 |         """Initialize a new MCPAgentTool instance.
 32 | 
 33 |         Args:
 34 |             mcp_tool: The MCP tool to adapt
 35 |             mcp_client: The MCP server connection to use for tool invocation
 36 |         """
 37 |         super().__init__()
 38 |         logger.debug("tool_name=<%s> | creating mcp agent tool", mcp_tool.name)
 39 |         self.mcp_tool = mcp_tool
 40 |         self.mcp_client = mcp_client
 41 | 
 42 |     @property
 43 |     def tool_name(self) -> str:
 44 |         """Get the name of the tool.
 45 | 
 46 |         Returns:
 47 |             str: The name of the MCP tool
 48 |         """
 49 |         return self.mcp_tool.name
 50 | 
 51 |     @property
 52 |     def tool_spec(self) -> ToolSpec:
 53 |         """Get the specification of the tool.
 54 | 
 55 |         This method converts the MCP tool specification to the agent framework's
 56 |         ToolSpec format, including the input schema and description.
 57 | 
 58 |         Returns:
 59 |             ToolSpec: The tool specification in the agent framework format
 60 |         """
 61 |         description: str = self.mcp_tool.description or f"Tool which performs {self.mcp_tool.name}"
 62 |         return {
 63 |             "inputSchema": {"json": self.mcp_tool.inputSchema},
 64 |             "name": self.mcp_tool.name,
 65 |             "description": description,
 66 |         }
 67 | 
 68 |     @property
 69 |     def tool_type(self) -> str:
 70 |         """Get the type of the tool.
 71 | 
 72 |         Returns:
 73 |             str: The type of the tool, always "python" for MCP tools
 74 |         """
 75 |         return "python"
 76 | 
 77 |     @override
 78 |     async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kwargs: Any) -> ToolGenerator:
 79 |         """Stream the MCP tool.
 80 | 
 81 |         This method delegates the tool stream to the MCP server connection, passing the tool use ID, tool name, and
 82 |         input arguments.
 83 | 
 84 |         Args:
 85 |             tool_use: The tool use request containing tool ID and parameters.
 86 |             invocation_state: Context for the tool invocation, including agent state.
 87 |             **kwargs: Additional keyword arguments for future extensibility.
 88 | 
 89 |         Yields:
 90 |             Tool events with the last being the tool result.
 91 |         """
 92 |         logger.debug("tool_name=<%s>, tool_use_id=<%s> | streaming", self.tool_name, tool_use["toolUseId"])
 93 | 
 94 |         result = await self.mcp_client.call_tool_async(
 95 |             tool_use_id=tool_use["toolUseId"],
 96 |             name=self.tool_name,
 97 |             arguments=tool_use["input"],
 98 |         )
 99 |         yield result
100 | 


--------------------------------------------------------------------------------
/src/strands/tools/mcp/mcp_types.py:
--------------------------------------------------------------------------------
 1 | """Type definitions for MCP integration."""
 2 | 
 3 | from contextlib import AbstractAsyncContextManager
 4 | 
 5 | from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
 6 | from mcp.client.streamable_http import GetSessionIdCallback
 7 | from mcp.shared.memory import MessageStream
 8 | from mcp.shared.message import SessionMessage
 9 | 
10 | """
11 | MCPTransport defines the interface for MCP transport implementations. This abstracts
12 | communication with an MCP server, hiding details of the underlying transport mechanism (WebSocket, stdio, etc.).
13 | 
14 | It represents an async context manager that yields a tuple of read and write streams for MCP communication.
15 | When used with `async with`, it should establish the connection and yield the streams, then clean up
16 | when the context is exited.
17 | 
18 | The read stream receives messages from the server (or exceptions if parsing fails), while the write
19 | stream sends messages to the server.
20 | 
21 | Example implementation (simplified):
22 | ```python
23 | @contextlib.asynccontextmanager
24 | async def my_transport_implementation():
25 |     # Set up connection
26 |     read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
27 |     write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
28 |     
29 |     # Start background tasks to handle actual I/O
30 |     async with anyio.create_task_group() as tg:
31 |         tg.start_soon(reader_task, read_stream_writer)
32 |         tg.start_soon(writer_task, write_stream_reader)
33 |         
34 |         # Yield the streams to the caller
35 |         yield (read_stream, write_stream)
36 | ```
37 | """
38 | # GetSessionIdCallback was added for HTTP Streaming but was not applied to the MessageStream type
39 | # https://github.com/modelcontextprotocol/python-sdk/blob/ed25167fa5d715733437996682e20c24470e8177/src/mcp/client/streamable_http.py#L418
40 | _MessageStreamWithGetSessionIdCallback = tuple[
41 |     MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage], GetSessionIdCallback
42 | ]
43 | MCPTransport = AbstractAsyncContextManager[MessageStream | _MessageStreamWithGetSessionIdCallback]
44 | 


--------------------------------------------------------------------------------
/src/strands/tools/watcher.py:
--------------------------------------------------------------------------------
  1 | """Tool watcher for hot reloading tools during development.
  2 | 
  3 | This module provides functionality to watch tool directories for changes and automatically reload tools when they are
  4 | modified.
  5 | """
  6 | 
  7 | import logging
  8 | from pathlib import Path
  9 | from typing import Any, Dict, Set
 10 | 
 11 | from watchdog.events import FileSystemEventHandler
 12 | from watchdog.observers import Observer
 13 | 
 14 | from .registry import ToolRegistry
 15 | 
 16 | logger = logging.getLogger(__name__)
 17 | 
 18 | 
 19 | class ToolWatcher:
 20 |     """Watches tool directories for changes and reloads tools when they are modified."""
 21 | 
 22 |     # This class uses class variables for the observer and handlers because watchdog allows only one Observer instance
 23 |     # per directory. Using class variables ensures that all ToolWatcher instances share a single Observer, with the
 24 |     # MasterChangeHandler routing file system events to the appropriate individual handlers for each registry. This
 25 |     # design pattern avoids conflicts when multiple tool registries are watching the same directories.
 26 | 
 27 |     _shared_observer = None
 28 |     _watched_dirs: Set[str] = set()
 29 |     _observer_started = False
 30 |     _registry_handlers: Dict[str, Dict[int, "ToolWatcher.ToolChangeHandler"]] = {}
 31 | 
 32 |     def __init__(self, tool_registry: ToolRegistry) -> None:
 33 |         """Initialize a tool watcher for the given tool registry.
 34 | 
 35 |         Args:
 36 |             tool_registry: The tool registry whose tools are reloaded when their files change.
 37 |         """
 38 |         self.tool_registry = tool_registry
 39 |         self.start()
 40 | 
 41 |     class ToolChangeHandler(FileSystemEventHandler):
 42 |         """Handler for tool file changes."""
 43 | 
 44 |         def __init__(self, tool_registry: ToolRegistry) -> None:
 45 |             """Initialize a tool change handler.
 46 | 
 47 |             Args:
 48 |                 tool_registry: The tool registry to update when tools change.
 49 |             """
 50 |             self.tool_registry = tool_registry
 51 | 
 52 |         def on_modified(self, event: Any) -> None:
 53 |             """Reload tool if file modification detected.
 54 | 
 55 |             Args:
 56 |                 event: The file system event that triggered this handler.
 57 |             """
 58 |             if event.src_path.endswith(".py"):
 59 |                 tool_path = Path(event.src_path)
 60 |                 tool_name = tool_path.stem
 61 | 
 62 |                 if tool_name not in ["__init__"]:
 63 |                     logger.debug("tool_name=<%s> | tool change detected", tool_name)
 64 |                     try:
 65 |                         self.tool_registry.reload_tool(tool_name)
 66 |                     except Exception as e:
 67 |                         logger.error("tool_name=<%s>, exception=<%s> | failed to reload tool", tool_name, str(e))
 68 | 
 69 |     class MasterChangeHandler(FileSystemEventHandler):
 70 |         """Master handler that delegates to all registered handlers."""
 71 | 
 72 |         def __init__(self, dir_path: str) -> None:
 73 |             """Initialize a master change handler for a specific directory.
 74 | 
 75 |             Args:
 76 |                 dir_path: The directory path to watch.
 77 |             """
 78 |             self.dir_path = dir_path
 79 | 
 80 |         def on_modified(self, event: Any) -> None:
 81 |             """Delegate file modification events to all registered handlers.
 82 | 
 83 |             Args:
 84 |                 event: The file system event that triggered this handler.
 85 |             """
 86 |             if event.src_path.endswith(".py"):
 87 |                 tool_path = Path(event.src_path)
 88 |                 tool_name = tool_path.stem
 89 | 
 90 |                 if tool_name not in ["__init__"]:
 91 |                     # Delegate to all registered handlers for this directory
 92 |                     for handler in ToolWatcher._registry_handlers.get(self.dir_path, {}).values():
 93 |                         try:
 94 |                             handler.on_modified(event)
 95 |                         except Exception as e:
 96 |                             logger.error("exception=<%s> | handler error", str(e))
 97 | 
 98 |     def start(self) -> None:
 99 |         """Start watching all tools directories for changes."""
100 |         # Initialize shared observer if not already done
101 |         if ToolWatcher._shared_observer is None:
102 |             ToolWatcher._shared_observer = Observer()
103 | 
104 |         # Create handler for this instance
105 |         self.tool_change_handler = self.ToolChangeHandler(self.tool_registry)
106 |         registry_id = id(self.tool_registry)
107 | 
108 |         # Get tools directories to watch
109 |         tools_dirs = self.tool_registry.get_tools_dirs()
110 | 
111 |         for tools_dir in tools_dirs:
112 |             dir_str = str(tools_dir)
113 | 
114 |             # Initialize the registry handlers dict for this directory if needed
115 |             if dir_str not in ToolWatcher._registry_handlers:
116 |                 ToolWatcher._registry_handlers[dir_str] = {}
117 | 
118 |             # Store this handler with its registry id
119 |             ToolWatcher._registry_handlers[dir_str][registry_id] = self.tool_change_handler
120 | 
121 |             # Schedule or update the master handler for this directory
122 |             if dir_str not in ToolWatcher._watched_dirs:
123 |                 # First time seeing this directory, create a master handler
124 |                 master_handler = self.MasterChangeHandler(dir_str)
125 |                 ToolWatcher._shared_observer.schedule(master_handler, dir_str, recursive=False)
126 |                 ToolWatcher._watched_dirs.add(dir_str)
127 |                 logger.debug("tools_dir=<%s> | started watching tools directory", tools_dir)
128 |             else:
129 |                 # Directory already being watched, just log it
130 |                 logger.debug("tools_dir=<%s> | directory already being watched", tools_dir)
131 | 
132 |         # Start the observer if not already started
133 |         if not ToolWatcher._observer_started:
134 |             ToolWatcher._shared_observer.start()
135 |             ToolWatcher._observer_started = True
136 |             logger.debug("tool directory watching initialized")
137 | 


--------------------------------------------------------------------------------
/src/strands/types/__init__.py:
--------------------------------------------------------------------------------
1 | """SDK type definitions."""
2 | 
3 | from .collections import PaginatedList
4 | 
5 | __all__ = ["PaginatedList"]
6 | 


--------------------------------------------------------------------------------
/src/strands/types/collections.py:
--------------------------------------------------------------------------------
 1 | """Generic collection types for the Strands SDK."""
 2 | 
 3 | from typing import Generic, List, Optional, TypeVar
 4 | 
 5 | T = TypeVar("T")
 6 | 
 7 | 
 8 | class PaginatedList(list, Generic[T]):
 9 |     """A generic list-like object that includes a pagination token.
10 | 
11 |     This maintains backwards compatibility by inheriting from list,
12 |     so existing code that expects List[T] will continue to work.
13 |     """
14 | 
15 |     def __init__(self, data: List[T], token: Optional[str] = None):
16 |         """Initialize a PaginatedList with data and an optional pagination token.
17 | 
18 |         Args:
19 |             data: The list of items to store.
20 |             token: Optional pagination token for retrieving additional items.
21 |         """
22 |         super().__init__(data)
23 |         self.pagination_token = token
24 | 


--------------------------------------------------------------------------------
/src/strands/types/content.py:
--------------------------------------------------------------------------------
  1 | """Content-related type definitions for the SDK.
  2 | 
  3 | This module defines the types used to represent messages, content blocks, and other content-related structures in the
  4 | SDK. These types are modeled after the Bedrock API.
  5 | 
  6 | - Bedrock docs: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_Types_Amazon_Bedrock_Runtime.html
  7 | """
  8 | 
  9 | from typing import Dict, List, Literal, Optional
 10 | 
 11 | from typing_extensions import TypedDict
 12 | 
 13 | from .media import DocumentContent, ImageContent, VideoContent
 14 | from .tools import ToolResult, ToolUse
 15 | 
 16 | 
 17 | class GuardContentText(TypedDict):
 18 |     """Text content to be evaluated by guardrails.
 19 | 
 20 |     Attributes:
 21 |         qualifiers: The qualifiers describing the text block.
 22 |         text: The input text details to be evaluated by the guardrail.
 23 |     """
 24 | 
 25 |     qualifiers: List[Literal["grounding_source", "query", "guard_content"]]
 26 |     text: str
 27 | 
 28 | 
 29 | class GuardContent(TypedDict):
 30 |     """Content block to be evaluated by guardrails.
 31 | 
 32 |     Attributes:
 33 |         text: Text within content block to be evaluated by the guardrail.
 34 |     """
 35 | 
 36 |     text: GuardContentText
 37 | 
 38 | 
 39 | class ReasoningTextBlock(TypedDict, total=False):
 40 |     """Contains the reasoning that the model used to return the output.
 41 | 
 42 |     Attributes:
 43 |         signature: A token that verifies that the reasoning text was generated by the model.
 44 |         text: The reasoning that the model used to return the output.
 45 |     """
 46 | 
 47 |     signature: Optional[str]
 48 |     text: str
 49 | 
 50 | 
 51 | class ReasoningContentBlock(TypedDict, total=False):
 52 |     """Contains content regarding the reasoning that is carried out by the model.
 53 | 
 54 |     Attributes:
 55 |         reasoningText: The reasoning that the model used to return the output.
 56 |         redactedContent: The content in the reasoning that was encrypted by the model provider for safety reasons.
 57 |     """
 58 | 
 59 |     reasoningText: ReasoningTextBlock
 60 |     redactedContent: bytes
 61 | 
 62 | 
 63 | class CachePoint(TypedDict):
 64 |     """A cache point configuration for optimizing conversation history.
 65 | 
 66 |     Attributes:
 67 |         type: The type of cache point, typically "default".
 68 |     """
 69 | 
 70 |     type: str
 71 | 
 72 | 
 73 | class ContentBlock(TypedDict, total=False):
 74 |     """A block of content for a message that you pass to, or receive from, a model.
 75 | 
 76 |     Attributes:
 77 |         cachePoint: A cache point configuration to optimize conversation history.
 78 |         document: A document to include in the message.
 79 |         guardContent: Contains the content to assess with the guardrail.
 80 |         image: Image to include in the message.
 81 |         reasoningContent: Contains content regarding the reasoning that is carried out by the model.
 82 |         text: Text to include in the message.
 83 |         toolResult: The result for a tool request that a model makes.
 84 |         toolUse: Information about a tool use request from a model.
 85 |         video: Video to include in the message.
 86 |     """
 87 | 
 88 |     cachePoint: CachePoint
 89 |     document: DocumentContent
 90 |     guardContent: GuardContent
 91 |     image: ImageContent
 92 |     reasoningContent: ReasoningContentBlock
 93 |     text: str
 94 |     toolResult: ToolResult
 95 |     toolUse: ToolUse
 96 |     video: VideoContent
 97 | 
 98 | 
 99 | class SystemContentBlock(TypedDict, total=False):
100 |     """Contains configurations for instructions to provide the model for how to handle input.
101 | 
102 |     Attributes:
103 |         guardContent: A content block to assess with the guardrail.
104 |         text: A system prompt for the model.
105 |     """
106 | 
107 |     guardContent: GuardContent
108 |     text: str
109 | 
110 | 
111 | class DeltaContent(TypedDict, total=False):
112 |     """A block of content in a streaming response.
113 | 
114 |     Attributes:
115 |         text: The content text.
116 |         toolUse: Information about a tool that the model is requesting to use.
117 |     """
118 | 
119 |     text: str
120 |     toolUse: Dict[Literal["input"], str]
121 | 
122 | 
123 | class ContentBlockStartToolUse(TypedDict):
124 |     """The start of a tool use block.
125 | 
126 |     Attributes:
127 |         name: The name of the tool that the model is requesting to use.
128 |         toolUseId: The ID for the tool request.
129 |     """
130 | 
131 |     name: str
132 |     toolUseId: str
133 | 
134 | 
135 | class ContentBlockStart(TypedDict, total=False):
136 |     """Content block start information.
137 | 
138 |     Attributes:
139 |         toolUse: Information about a tool that the model is requesting to use.
140 |     """
141 | 
142 |     toolUse: Optional[ContentBlockStartToolUse]
143 | 
144 | 
145 | class ContentBlockDelta(TypedDict):
146 |     """The content block delta event.
147 | 
148 |     Attributes:
149 |         contentBlockIndex: The block index for a content block delta event.
150 |         delta: The delta for a content block delta event.
151 |     """
152 | 
153 |     contentBlockIndex: int
154 |     delta: DeltaContent
155 | 
156 | 
157 | class ContentBlockStop(TypedDict):
158 |     """A content block stop event.
159 | 
160 |     Attributes:
161 |         contentBlockIndex: The index for a content block.
162 |     """
163 | 
164 |     contentBlockIndex: int
165 | 
166 | 
167 | Role = Literal["user", "assistant"]
168 | """Role of a message sender.
169 | 
170 | - "user": Messages from the user to the assistant
171 | - "assistant": Messages from the assistant to the user
172 | """
173 | 
174 | 
175 | class Message(TypedDict):
176 |     """A message in a conversation with the agent.
177 | 
178 |     Attributes:
179 |         content: The message content.
180 |         role: The role of the message sender.
181 |     """
182 | 
183 |     content: List[ContentBlock]
184 |     role: Role
185 | 
186 | 
187 | Messages = List[Message]
188 | """A list of messages representing a conversation."""
189 | 


--------------------------------------------------------------------------------
/src/strands/types/event_loop.py:
--------------------------------------------------------------------------------
 1 | """Event loop-related type definitions for the SDK."""
 2 | 
 3 | from typing import Literal
 4 | 
 5 | from typing_extensions import TypedDict
 6 | 
 7 | 
 8 | class Usage(TypedDict):
 9 |     """Token usage information for model interactions.
10 | 
11 |     Attributes:
12 |         inputTokens: Number of tokens sent in the request to the model.
13 |         outputTokens: Number of tokens that the model generated for the request.
14 |         totalTokens: Total number of tokens (input + output).
15 |     """
16 | 
17 |     inputTokens: int
18 |     outputTokens: int
19 |     totalTokens: int
20 | 
21 | 
22 | class Metrics(TypedDict):
23 |     """Performance metrics for model interactions.
24 | 
25 |     Attributes:
26 |         latencyMs (int): Latency of the model request in milliseconds.
27 |     """
28 | 
29 |     latencyMs: int
30 | 
31 | 
32 | StopReason = Literal[
33 |     "content_filtered",
34 |     "end_turn",
35 |     "guardrail_intervened",
36 |     "max_tokens",
37 |     "stop_sequence",
38 |     "tool_use",
39 | ]
40 | """Reason for the model ending its response generation.
41 | 
42 | - "content_filtered": Content was filtered due to policy violation
43 | - "end_turn": Normal completion of the response
44 | - "guardrail_intervened": Guardrail system intervened
45 | - "max_tokens": Maximum token limit reached
46 | - "stop_sequence": Stop sequence encountered
47 | - "tool_use": Model requested to use a tool
48 | """
49 | 


--------------------------------------------------------------------------------
/src/strands/types/exceptions.py:
--------------------------------------------------------------------------------
 1 | """Exception-related type definitions for the SDK."""
 2 | 
 3 | from typing import Any
 4 | 
 5 | 
 6 | class EventLoopException(Exception):
 7 |     """Exception raised by the event loop."""
 8 | 
 9 |     def __init__(self, original_exception: Exception, request_state: Any = None) -> None:
10 |         """Initialize exception.
11 | 
12 |         Args:
13 |             original_exception: The original exception that was raised.
14 |             request_state: The state of the request at the time of the exception.
15 |         """
16 |         self.original_exception = original_exception
17 |         self.request_state = request_state if request_state is not None else {}
18 |         super().__init__(str(original_exception))
19 | 
20 | 
21 | class ContextWindowOverflowException(Exception):
22 |     """Exception raised when the context window is exceeded.
23 | 
24 |     This exception is raised when the input to a model exceeds the maximum context window size that the model can
25 |     handle. This typically occurs when the combined length of the conversation history, system prompt, and current
26 |     message is too large for the model to process.
27 |     """
28 | 
29 |     pass
30 | 
31 | 
32 | class MCPClientInitializationError(Exception):
33 |     """Raised when the MCP server fails to initialize properly."""
34 | 
35 |     pass
36 | 
37 | 
38 | class ModelThrottledException(Exception):
39 |     """Exception raised when the model is throttled.
40 | 
41 |     This exception is raised when the model is throttled by the service. This typically occurs when the service is
42 |     throttling the requests from the client.
43 |     """
44 | 
45 |     def __init__(self, message: str) -> None:
46 |         """Initialize exception.
47 | 
48 |         Args:
49 |             message: The message from the service that describes the throttling.
50 |         """
51 |         self.message = message
52 |         super().__init__(message)
53 | 
54 |     pass
55 | 
56 | 
57 | class SessionException(Exception):
58 |     """Exception raised when session operations fail."""
59 | 
60 |     pass
61 | 


--------------------------------------------------------------------------------
/src/strands/types/media.py:
--------------------------------------------------------------------------------
 1 | """Media-related type definitions for the SDK.
 2 | 
 3 | These types are modeled after the Bedrock API.
 4 | 
 5 | - Bedrock docs: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_Types_Amazon_Bedrock_Runtime.html
 6 | """
 7 | 
 8 | from typing import Literal
 9 | 
10 | from typing_extensions import TypedDict
11 | 
12 | DocumentFormat = Literal["pdf", "csv", "doc", "docx", "xls", "xlsx", "html", "txt", "md"]
13 | """Supported document formats."""
14 | 
15 | 
16 | class DocumentSource(TypedDict):
17 |     """Contains the content of a document.
18 | 
19 |     Attributes:
20 |         bytes: The binary content of the document.
21 |     """
22 | 
23 |     bytes: bytes
24 | 
25 | 
26 | class DocumentContent(TypedDict):
27 |     """A document to include in a message.
28 | 
29 |     Attributes:
30 |         format: The format of the document (e.g., "pdf", "txt").
31 |         name: The name of the document.
32 |         source: The source containing the document's binary content.
33 |     """
34 | 
35 |     format: Literal["pdf", "csv", "doc", "docx", "xls", "xlsx", "html", "txt", "md"]
36 |     name: str
37 |     source: DocumentSource
38 | 
39 | 
40 | ImageFormat = Literal["png", "jpeg", "gif", "webp"]
41 | """Supported image formats."""
42 | 
43 | 
44 | class ImageSource(TypedDict):
45 |     """Contains the content of an image.
46 | 
47 |     Attributes:
48 |         bytes: The binary content of the image.
49 |     """
50 | 
51 |     bytes: bytes
52 | 
53 | 
54 | class ImageContent(TypedDict):
55 |     """An image to include in a message.
56 | 
57 |     Attributes:
58 |         format: The format of the image (e.g., "png", "jpeg").
59 |         source: The source containing the image's binary content.
60 |     """
61 | 
62 |     format: ImageFormat
63 |     source: ImageSource
64 | 
65 | 
66 | VideoFormat = Literal["flv", "mkv", "mov", "mpeg", "mpg", "mp4", "three_gp", "webm", "wmv"]
67 | """Supported video formats."""
68 | 
69 | 
70 | class VideoSource(TypedDict):
71 |     """Contains the content of a video.
72 | 
73 |     Attributes:
74 |         bytes: The binary content of the video.
75 |     """
76 | 
77 |     bytes: bytes
78 | 
79 | 
80 | class VideoContent(TypedDict):
81 |     """A video to include in a message.
82 | 
83 |     Attributes:
84 |         format: The format of the video (e.g., "mp4", "mov").
85 |         source: The source containing the video's binary content.
86 |     """
87 | 
88 |     format: VideoFormat
89 |     source: VideoSource
90 | 


--------------------------------------------------------------------------------
/src/strands/types/session.py:
--------------------------------------------------------------------------------
  1 | """Data models for session management."""
  2 | 
  3 | import base64
  4 | import inspect
  5 | from dataclasses import asdict, dataclass, field
  6 | from datetime import datetime, timezone
  7 | from enum import Enum
  8 | from typing import TYPE_CHECKING, Any, Dict, Optional
  9 | 
 10 | from .content import Message
 11 | 
 12 | if TYPE_CHECKING:
 13 |     from ..agent.agent import Agent
 14 | 
 15 | 
 16 | class SessionType(str, Enum):
 17 |     """Enumeration of session types.
 18 | 
 19 |     As sessions are expanded to support new use cases like multi-agent patterns,
 20 |     new types will be added here.
 21 |     """
 22 | 
 23 |     AGENT = "AGENT"
 24 | 
 25 | 
 26 | def encode_bytes_values(obj: Any) -> Any:
 27 |     """Recursively encode any bytes values in an object to base64.
 28 | 
 29 |     Handles dictionaries, lists, and nested structures.
 30 |     """
 31 |     if isinstance(obj, bytes):
 32 |         return {"__bytes_encoded__": True, "data": base64.b64encode(obj).decode()}
 33 |     elif isinstance(obj, dict):
 34 |         return {k: encode_bytes_values(v) for k, v in obj.items()}
 35 |     elif isinstance(obj, list):
 36 |         return [encode_bytes_values(item) for item in obj]
 37 |     else:
 38 |         return obj
 39 | 
 40 | 
 41 | def decode_bytes_values(obj: Any) -> Any:
 42 |     """Recursively decode any base64-encoded bytes values in an object.
 43 | 
 44 |     Handles dictionaries, lists, and nested structures.
 45 |     """
 46 |     if isinstance(obj, dict):
 47 |         if obj.get("__bytes_encoded__") is True and "data" in obj:
 48 |             return base64.b64decode(obj["data"])
 49 |         return {k: decode_bytes_values(v) for k, v in obj.items()}
 50 |     elif isinstance(obj, list):
 51 |         return [decode_bytes_values(item) for item in obj]
 52 |     else:
 53 |         return obj
 54 | 
 55 | 
 56 | @dataclass
 57 | class SessionMessage:
 58 |     """Message within a SessionAgent.
 59 | 
 60 |     Attributes:
 61 |         message: Message content
 62 |         message_id: Index of the message in the conversation history
 63 |         redact_message: If the original message is redacted, this is the new content to use
 64 |         created_at: ISO format timestamp for when this message was created
 65 |         updated_at: ISO format timestamp for when this message was last updated
 66 |     """
 67 | 
 68 |     message: Message
 69 |     message_id: int
 70 |     redact_message: Optional[Message] = None
 71 |     created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
 72 |     updated_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
 73 | 
 74 |     @classmethod
 75 |     def from_message(cls, message: Message, index: int) -> "SessionMessage":
 76 |         """Wrap a Message as a SessionMessage; bytes values are base64 encoded later, in `to_dict`."""
 77 |         return cls(
 78 |             message=message,
 79 |             message_id=index,
 80 |             created_at=datetime.now(timezone.utc).isoformat(),
 81 |             updated_at=datetime.now(timezone.utc).isoformat(),
 82 |         )
 83 | 
 84 |     def to_message(self) -> Message:
 85 |         """Convert SessionMessage back to a Message (bytes values are decoded earlier, in `from_dict`).
 86 | 
 87 |         If the message was redacted, return the redacted content (`redact_message`) instead.
 88 |         """
 89 |         if self.redact_message is not None:
 90 |             return self.redact_message
 91 |         else:
 92 |             return self.message
 93 | 
 94 |     @classmethod
 95 |     def from_dict(cls, env: dict[str, Any]) -> "SessionMessage":
 96 |         """Initialize a SessionMessage from a dictionary, ignoring keys that are not class parameters."""
 97 |         extracted_relevant_parameters = {k: v for k, v in env.items() if k in inspect.signature(cls).parameters}
 98 |         return cls(**decode_bytes_values(extracted_relevant_parameters))
 99 | 
100 |     def to_dict(self) -> dict[str, Any]:
101 |         """Convert the SessionMessage to a dictionary representation."""
102 |         return encode_bytes_values(asdict(self))  # type: ignore
103 | 
104 | 
105 | @dataclass
106 | class SessionAgent:
107 |     """Agent that belongs to a Session."""
108 | 
109 |     agent_id: str
110 |     state: Dict[str, Any]
111 |     conversation_manager_state: Dict[str, Any]
112 |     created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
113 |     updated_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
114 | 
115 |     @classmethod
116 |     def from_agent(cls, agent: "Agent") -> "SessionAgent":
117 |         """Convert an Agent to a SessionAgent."""
118 |         if agent.agent_id is None:
119 |             raise ValueError("agent_id needs to be defined.")
120 |         return cls(
121 |             agent_id=agent.agent_id,
122 |             conversation_manager_state=agent.conversation_manager.get_state(),
123 |             state=agent.state.get(),
124 |         )
125 | 
126 |     @classmethod
127 |     def from_dict(cls, env: dict[str, Any]) -> "SessionAgent":
128 |         """Initialize a SessionAgent from a dictionary, ignoring keys that are not class parameters."""
129 |         return cls(**{k: v for k, v in env.items() if k in inspect.signature(cls).parameters})
130 | 
131 |     def to_dict(self) -> dict[str, Any]:
132 |         """Convert the SessionAgent to a dictionary representation."""
133 |         return asdict(self)
134 | 
135 | 
136 | @dataclass
137 | class Session:
138 |     """Session data model."""
139 | 
140 |     session_id: str
141 |     session_type: SessionType
142 |     created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
143 |     updated_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
144 | 
145 |     @classmethod
146 |     def from_dict(cls, env: dict[str, Any]) -> "Session":
147 |         """Initialize a Session from a dictionary, ignoring keys that are not class parameters."""
148 |         return cls(**{k: v for k, v in env.items() if k in inspect.signature(cls).parameters})
149 | 
150 |     def to_dict(self) -> dict[str, Any]:
151 |         """Convert the Session to a dictionary representation."""
152 |         return asdict(self)
153 | 


--------------------------------------------------------------------------------
/src/strands/types/traces.py:
--------------------------------------------------------------------------------
1 | """Tracing type definitions for the SDK."""
2 | 
3 | from typing import List, Union
4 | 
# Value types accepted for a trace attribute: a scalar, or a homogeneous list of one scalar type.
AttributeValue = Union[
    str,
    bool,
    float,
    int,
    List[str],
    List[bool],
    List[float],
    List[int],
]
6 | 


--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/strands-agents/sdk-python/f5e24d402bb22dc1a54c94dd030b4be0f7e73261/tests/__init__.py


--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
  1 | import configparser
  2 | import logging
  3 | import os
  4 | import sys
  5 | 
  6 | import boto3
  7 | import moto
  8 | import pytest
  9 | 
 10 | ## Moto
 11 | 
# Read the desired log level from the LOG_LEVEL environment variable (default "INFO"), upper-cased
log_level = os.environ.get("LOG_LEVEL", "INFO").upper()

# The level is applied to the "strands" logger only; basicConfig just installs a stdout handler and format
logging.getLogger("strands").setLevel(log_level)
logging.basicConfig(
    format="%(levelname)s | %(name)s | %(message)s", handlers=[logging.StreamHandler(stream=sys.stdout)]
)
 19 | 
 20 | 
 21 | @pytest.fixture
 22 | def moto_env(monkeypatch):
 23 |     monkeypatch.setenv("AWS_ACCESS_KEY_ID", "test")
 24 |     monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "test")
 25 |     monkeypatch.setenv("AWS_SECURITY_TOKEN", "test")
 26 |     monkeypatch.setenv("AWS_DEFAULT_REGION", "us-west-2")
 27 |     monkeypatch.delenv("OTEL_EXPORTER_OTLP_ENDPOINT", raising=False)
 28 |     monkeypatch.delenv("OTEL_EXPORTER_OTLP_HEADERS", raising=False)
 29 | 
 30 | 
 31 | @pytest.fixture
 32 | def moto_mock_aws():
 33 |     with moto.mock_aws():
 34 |         yield
 35 | 
 36 | 
 37 | @pytest.fixture
 38 | def moto_cloudwatch_client():
 39 |     return boto3.client("cloudwatch")
 40 | 
 41 | 
 42 | ## Boto3
 43 | 
 44 | 
 45 | @pytest.fixture
 46 | def boto3_profile_name():
 47 |     return "test-profile"
 48 | 
 49 | 
 50 | @pytest.fixture
 51 | def boto3_profile(boto3_profile_name):
 52 |     config = configparser.ConfigParser()
 53 |     config[boto3_profile_name] = {
 54 |         "aws_access_key_id": "test",
 55 |         "aws_secret_access_key": "test",
 56 |     }
 57 | 
 58 |     return config
 59 | 
 60 | 
 61 | @pytest.fixture
 62 | def boto3_profile_path(boto3_profile, tmp_path, monkeypatch):
 63 |     path = tmp_path / ".aws/credentials"
 64 |     path.parent.mkdir(exist_ok=True)
 65 |     with path.open("w") as fp:
 66 |         boto3_profile.write(fp)
 67 | 
 68 |     monkeypatch.setenv("AWS_SHARED_CREDENTIALS_FILE", str(path))
 69 | 
 70 |     return path
 71 | 
 72 | 
 73 | ## Async
 74 | 
 75 | 
 76 | @pytest.fixture(scope="session")
 77 | def agenerator():
 78 |     async def agenerator(items):
 79 |         for item in items:
 80 |             yield item
 81 | 
 82 |     return agenerator
 83 | 
 84 | 
 85 | @pytest.fixture(scope="session")
 86 | def alist():
 87 |     async def alist(items):
 88 |         return [item async for item in items]
 89 | 
 90 |     return alist
 91 | 
 92 | 
 93 | ## Itertools
 94 | 
 95 | 
 96 | @pytest.fixture(scope="session")
 97 | def generate():
 98 |     def generate(generator):
 99 |         events = []
100 | 
101 |         try:
102 |             while True:
103 |                 event = next(generator)
104 |                 events.append(event)
105 | 
106 |         except StopIteration as stop:
107 |             return events, stop.value
108 | 
109 |     return generate
110 | 


--------------------------------------------------------------------------------
/tests/fixtures/mock_hook_provider.py:
--------------------------------------------------------------------------------
 1 | from typing import Iterator, Tuple, Type
 2 | 
 3 | from strands.hooks import HookEvent, HookProvider, HookRegistry
 4 | 
 5 | 
 6 | class MockHookProvider(HookProvider):
 7 |     def __init__(self, event_types: list[Type]):
 8 |         self.events_received = []
 9 |         self.events_types = event_types
10 | 
11 |     def get_events(self) -> Tuple[int, Iterator[HookEvent]]:
12 |         return len(self.events_received), iter(self.events_received)
13 | 
14 |     def register_hooks(self, registry: HookRegistry) -> None:
15 |         for event_type in self.events_types:
16 |             registry.add_callback(event_type, self.add_event)
17 | 
18 |     def add_event(self, event: HookEvent) -> None:
19 |         self.events_received.append(event)
20 | 


--------------------------------------------------------------------------------
/tests/fixtures/mock_session_repository.py:
--------------------------------------------------------------------------------
 1 | from strands.session.session_repository import SessionRepository
 2 | from strands.types.exceptions import SessionException
 3 | from strands.types.session import SessionAgent, SessionMessage
 4 | 
 5 | 
 6 | class MockedSessionRepository(SessionRepository):
 7 |     """Mock repository for testing."""
 8 | 
 9 |     def __init__(self):
10 |         """Initialize with empty storage."""
11 |         self.sessions = {}
12 |         self.agents = {}
13 |         self.messages = {}
14 | 
15 |     def create_session(self, session) -> None:
16 |         """Create a session."""
17 |         session_id = session.session_id
18 |         if session_id in self.sessions:
19 |             raise SessionException(f"Session {session_id} already exists")
20 |         self.sessions[session_id] = session
21 |         self.agents[session_id] = {}
22 |         self.messages[session_id] = {}
23 | 
24 |     def read_session(self, session_id) -> SessionAgent:
25 |         """Read a session."""
26 |         return self.sessions.get(session_id)
27 | 
28 |     def create_agent(self, session_id, session_agent) -> None:
29 |         """Create an agent."""
30 |         agent_id = session_agent.agent_id
31 |         if session_id not in self.sessions:
32 |             raise SessionException(f"Session {session_id} does not exist")
33 |         if agent_id in self.agents.get(session_id, {}):
34 |             raise SessionException(f"Agent {agent_id} already exists in session {session_id}")
35 |         self.agents.setdefault(session_id, {})[agent_id] = session_agent
36 |         self.messages.setdefault(session_id, {}).setdefault(agent_id, {})
37 |         return session_agent
38 | 
39 |     def read_agent(self, session_id, agent_id) -> SessionAgent:
40 |         """Read an agent."""
41 |         if session_id not in self.sessions:
42 |             return None
43 |         return self.agents.get(session_id, {}).get(agent_id)
44 | 
45 |     def update_agent(self, session_id, session_agent) -> None:
46 |         """Update an agent."""
47 |         agent_id = session_agent.agent_id
48 |         if session_id not in self.sessions:
49 |             raise SessionException(f"Session {session_id} does not exist")
50 |         if agent_id not in self.agents.get(session_id, {}):
51 |             raise SessionException(f"Agent {agent_id} does not exist in session {session_id}")
52 |         self.agents[session_id][agent_id] = session_agent
53 | 
54 |     def create_message(self, session_id, agent_id, session_message) -> None:
55 |         """Create a message."""
56 |         message_id = session_message.message_id
57 |         if session_id not in self.sessions:
58 |             raise SessionException(f"Session {session_id} does not exist")
59 |         if agent_id not in self.agents.get(session_id, {}):
60 |             raise SessionException(f"Agent {agent_id} does not exists in session {session_id}")
61 |         if message_id in self.messages.get(session_id, {}).get(agent_id, {}):
62 |             raise SessionException(f"Message {message_id} already exists in agent {agent_id} in session {session_id}")
63 |         self.messages.setdefault(session_id, {}).setdefault(agent_id, {})[message_id] = session_message
64 | 
65 |     def read_message(self, session_id, agent_id, message_id) -> SessionMessage:
66 |         """Read a message."""
67 |         if session_id not in self.sessions:
68 |             return None
69 |         if agent_id not in self.agents.get(session_id, {}):
70 |             return None
71 |         return self.messages.get(session_id, {}).get(agent_id, {}).get(message_id)
72 | 
73 |     def update_message(self, session_id, agent_id, session_message) -> None:
74 |         """Update a message."""
75 | 
76 |         message_id = session_message.message_id
77 |         if session_id not in self.sessions:
78 |             raise SessionException(f"Session {session_id} does not exist")
79 |         if agent_id not in self.agents.get(session_id, {}):
80 |             raise SessionException(f"Agent {agent_id} does not exist in session {session_id}")
81 |         if message_id not in self.messages.get(session_id, {}).get(agent_id, {}):
82 |             raise SessionException(f"Message {message_id} does not exist in session {session_id}")
83 |         self.messages[session_id][agent_id][message_id] = session_message
84 | 
85 |     def list_messages(self, session_id, agent_id, limit=None, offset=0) -> list[SessionMessage]:
86 |         """List messages."""
87 |         if session_id not in self.sessions:
88 |             return []
89 |         if agent_id not in self.agents.get(session_id, {}):
90 |             return []
91 | 
92 |         messages = self.messages.get(session_id, {}).get(agent_id, {})
93 |         sorted_messages = [messages[key] for key in sorted(messages.keys())]
94 | 
95 |         if limit is not None:
96 |             return sorted_messages[offset : offset + limit]
97 |         return sorted_messages[offset:]
98 | 


--------------------------------------------------------------------------------
/tests/fixtures/mocked_model_provider.py:
--------------------------------------------------------------------------------
 1 | import json
 2 | from typing import Any, AsyncGenerator, Iterable, Optional, Type, TypedDict, TypeVar, Union
 3 | 
 4 | from pydantic import BaseModel
 5 | 
 6 | from strands.models import Model
 7 | from strands.types.content import Message, Messages
 8 | from strands.types.event_loop import StopReason
 9 | from strands.types.streaming import StreamEvent
10 | from strands.types.tools import ToolSpec
11 | 
12 | T = TypeVar("T", bound=BaseModel)
13 | 
14 | 
class RedactionMessage(TypedDict):
    """Mock response describing a redaction: replacement text for the user and assistant content."""

    # Text emitted in the "redactContent" event as the user-content replacement.
    redactedUserContent: str
    # Text streamed as the assistant's reply; a truthy value selects the redaction path.
    redactedAssistantContent: str
18 | 
19 | 
class MockedModelProvider(Model):
    """Model test double that replays canned agent responses.

    Each call to stream() converts the next entry of agent_responses into a sequence
    of stream events and, once all of them have been yielded, advances to the
    following entry.
    """

    def __init__(self, agent_responses: list[Union[Message, RedactionMessage]]):
        """Store the canned responses and start at the first one."""
        self.index = 0
        self.agent_responses = agent_responses

    def format_chunk(self, event: Any) -> StreamEvent:
        """Return the event unchanged."""
        return event

    def format_request(
        self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None
    ) -> Any:
        """Return None; the mock builds no request."""
        return None

    def get_config(self) -> Any:
        """No-op; returns None."""
        return None

    def update_config(self, **model_config: Any) -> None:
        """No-op; the supplied configuration is ignored."""
        return None

    async def structured_output(
        self,
        output_model: Type[T],
        prompt: Messages,
        system_prompt: Optional[str] = None,
        **kwargs: Any,
    ) -> AsyncGenerator[Any, None]:
        """No-op; produces nothing."""
        return None

    async def stream(
        self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None
    ) -> AsyncGenerator[Any, None]:
        """Yield the events for the current canned response, then advance to the next one."""
        current_response = self.agent_responses[self.index]
        for stream_event in self.map_agent_message_to_events(current_response):
            yield stream_event

        # Only advanced after every event has been consumed.
        self.index += 1

    def map_agent_message_to_events(self, agent_message: Union[Message, RedactionMessage]) -> Iterable[dict[str, Any]]:
        """Translate one canned response into the stream events that represent it.

        A response with a truthy "redactedAssistantContent" produces a redaction event
        plus one text block and stops with "guardrail_intervened". Otherwise every
        content item yields a text block and/or a tool-use block; the stop reason is
        "tool_use" if any tool use was present, else "end_turn".
        """
        stop_reason: StopReason = "end_turn"
        yield {"messageStart": {"role": "assistant"}}

        if not agent_message.get("redactedAssistantContent"):
            for block in agent_message["content"]:
                if "text" in block:
                    yield {"contentBlockStart": {"start": {}}}
                    yield {"contentBlockDelta": {"delta": {"text": block["text"]}}}
                    yield {"contentBlockStop": {}}
                if "toolUse" in block:
                    stop_reason = "tool_use"
                    tool_header = {
                        "name": block["toolUse"]["name"],
                        "toolUseId": block["toolUse"]["toolUseId"],
                    }
                    yield {"contentBlockStart": {"start": {"toolUse": tool_header}}}
                    # Tool input is streamed as a JSON-encoded string.
                    encoded_input = json.dumps(block["toolUse"]["input"])
                    yield {"contentBlockDelta": {"delta": {"toolUse": {"input": encoded_input}}}}
                    yield {"contentBlockStop": {}}
        else:
            yield {"redactContent": {"redactUserContentMessage": agent_message["redactedUserContent"]}}
            yield {"contentBlockStart": {"start": {}}}
            yield {"contentBlockDelta": {"delta": {"text": agent_message["redactedAssistantContent"]}}}
            yield {"contentBlockStop": {}}
            stop_reason = "guardrail_intervened"

        yield {"messageStop": {"stopReason": stop_reason}}
97 | 


--------------------------------------------------------------------------------
/tests/strands/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/strands-agents/sdk-python/f5e24d402bb22dc1a54c94dd030b4be0f7e73261/tests/strands/__init__.py


--------------------------------------------------------------------------------
/tests/strands/agent/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/strands-agents/sdk-python/f5e24d402bb22dc1a54c94dd030b4be0f7e73261/tests/strands/agent/__init__.py


--------------------------------------------------------------------------------
/tests/strands/agent/test_agent_result.py:
--------------------------------------------------------------------------------
 1 | import unittest.mock
 2 | from typing import cast
 3 | 
 4 | import pytest
 5 | 
 6 | from strands.agent.agent_result import AgentResult
 7 | from strands.telemetry.metrics import EventLoopMetrics
 8 | from strands.types.content import Message
 9 | from strands.types.streaming import StopReason
10 | 
11 | 
@pytest.fixture
def mock_metrics():
    """Mock constrained to the EventLoopMetrics interface."""
    metrics = unittest.mock.Mock(spec=EventLoopMetrics)
    return metrics
15 | 
16 | 
@pytest.fixture
def simple_message():
    """Assistant message with a single text block."""
    content = [{"text": "Hello world!"}]
    return {"role": "assistant", "content": content}
20 | 
21 | 
@pytest.fixture
def complex_message():
    """Assistant message mixing three text blocks with one non-text block."""
    content = [
        {"text": "First paragraph"},
        {"text": "Second paragraph"},
        {"non_text_content": "This should be ignored"},
        {"text": "Third paragraph"},
    ]
    return {"role": "assistant", "content": content}
33 | 
34 | 
@pytest.fixture
def empty_message():
    """Assistant message whose content list is empty."""
    content = []
    return {"role": "assistant", "content": content}
38 | 
39 | 
def test__init__(mock_metrics, simple_message: Message):
    """AgentResult exposes each constructor argument on the attribute of the same name."""
    expected_stop_reason: StopReason = "end_turn"
    expected_state = {"key": "value"}

    result = AgentResult(
        stop_reason=expected_stop_reason,
        message=simple_message,
        metrics=mock_metrics,
        state=expected_state,
    )

    assert result.stop_reason == expected_stop_reason
    assert result.message == simple_message
    assert result.metrics == mock_metrics
    assert result.state == expected_state
51 | 
52 | 
def test__str__simple(mock_metrics, simple_message: Message):
    """str() of a single-text-block result is that text followed by a newline."""
    result = AgentResult(stop_reason="end_turn", message=simple_message, metrics=mock_metrics, state={})

    assert str(result) == "Hello world!\n"
59 | 
60 | 
def test__str__complex(mock_metrics, complex_message: Message):
    """str() joins every text block, one per line, and skips the non-text block."""
    result = AgentResult(stop_reason="end_turn", message=complex_message, metrics=mock_metrics, state={})

    expected = "First paragraph\nSecond paragraph\nThird paragraph\n"
    assert str(result) == expected
67 | 
68 | 
def test__str__empty(mock_metrics, empty_message: Message):
    """str() of a result whose message has an empty content list is the empty string."""
    result = AgentResult(stop_reason="end_turn", message=empty_message, metrics=mock_metrics, state={})

    assert str(result) == ""
75 | 
76 | 
def test__str__no_content(mock_metrics):
    """str() of a result whose message lacks a content field is the empty string."""
    contentless = cast(Message, {"role": "assistant"})

    result = AgentResult(stop_reason="end_turn", message=contentless, metrics=mock_metrics, state={})

    assert str(result) == ""
85 | 
86 | 
def test__str__non_dict_content(mock_metrics):
    """str() skips content items that are not dictionaries and keeps the text blocks."""
    mixed_content = [{"text": "Valid text"}, "Not a dictionary", {"text": "More valid text"}]
    message = cast(Message, {"role": "assistant", "content": mixed_content})

    result = AgentResult(stop_reason="end_turn", message=message, metrics=mock_metrics, state={})

    assert str(result) == "Valid text\nMore valid text\n"
98 | 


--------------------------------------------------------------------------------
/tests/strands/agent/test_agent_state.py:
--------------------------------------------------------------------------------
  1 | """Tests for AgentState class."""
  2 | 
  3 | import pytest
  4 | 
  5 | from strands import Agent, tool
  6 | from strands.agent.state import AgentState
  7 | from strands.types.content import Messages
  8 | 
  9 | from ...fixtures.mocked_model_provider import MockedModelProvider
 10 | 
 11 | 
 12 | def test_set_and_get():
 13 |     """Test basic set and get operations."""
 14 |     state = AgentState()
 15 |     state.set("key", "value")
 16 |     assert state.get("key") == "value"
 17 | 
 18 | 
 19 | def test_get_nonexistent_key():
 20 |     """Test getting nonexistent key returns None."""
 21 |     state = AgentState()
 22 |     assert state.get("nonexistent") is None
 23 | 
 24 | 
 25 | def test_get_entire_state():
 26 |     """Test getting entire state when no key specified."""
 27 |     state = AgentState()
 28 |     state.set("key1", "value1")
 29 |     state.set("key2", "value2")
 30 | 
 31 |     result = state.get()
 32 |     assert result == {"key1": "value1", "key2": "value2"}
 33 | 
 34 | 
 35 | def test_initialize_and_get_entire_state():
 36 |     """Test getting entire state when no key specified."""
 37 |     state = AgentState({"key1": "value1", "key2": "value2"})
 38 | 
 39 |     result = state.get()
 40 |     assert result == {"key1": "value1", "key2": "value2"}
 41 | 
 42 | 
 43 | def test_initialize_with_error():
 44 |     with pytest.raises(ValueError, match="not JSON serializable"):
 45 |         AgentState({"object", object()})
 46 | 
 47 | 
 48 | def test_delete():
 49 |     """Test deleting keys."""
 50 |     state = AgentState()
 51 |     state.set("key1", "value1")
 52 |     state.set("key2", "value2")
 53 | 
 54 |     state.delete("key1")
 55 | 
 56 |     assert state.get("key1") is None
 57 |     assert state.get("key2") == "value2"
 58 | 
 59 | 
 60 | def test_delete_nonexistent_key():
 61 |     """Test deleting nonexistent key doesn't raise error."""
 62 |     state = AgentState()
 63 |     state.delete("nonexistent")  # Should not raise
 64 | 
 65 | 
 66 | def test_json_serializable_values():
 67 |     """Test that only JSON-serializable values are accepted."""
 68 |     state = AgentState()
 69 | 
 70 |     # Valid JSON types
 71 |     state.set("string", "test")
 72 |     state.set("int", 42)
 73 |     state.set("bool", True)
 74 |     state.set("list", [1, 2, 3])
 75 |     state.set("dict", {"nested": "value"})
 76 |     state.set("null", None)
 77 | 
 78 |     # Invalid JSON types should raise ValueError
 79 |     with pytest.raises(ValueError, match="not JSON serializable"):
 80 |         state.set("function", lambda x: x)
 81 | 
 82 |     with pytest.raises(ValueError, match="not JSON serializable"):
 83 |         state.set("object", object())
 84 | 
 85 | 
 86 | def test_key_validation():
 87 |     """Test key validation for set and delete operations."""
 88 |     state = AgentState()
 89 | 
 90 |     # Invalid keys for set
 91 |     with pytest.raises(ValueError, match="Key cannot be None"):
 92 |         state.set(None, "value")
 93 | 
 94 |     with pytest.raises(ValueError, match="Key cannot be empty"):
 95 |         state.set("", "value")
 96 | 
 97 |     with pytest.raises(ValueError, match="Key must be a string"):
 98 |         state.set(123, "value")
 99 | 
100 |     # Invalid keys for delete
101 |     with pytest.raises(ValueError, match="Key cannot be None"):
102 |         state.delete(None)
103 | 
104 |     with pytest.raises(ValueError, match="Key cannot be empty"):
105 |         state.delete("")
106 | 
107 | 
def test_initial_state():
    """State passed via initial_state is readable per key and as a whole."""
    seed = {"key1": "value1", "key2": "value2"}

    agent_state = AgentState(initial_state=seed)

    assert agent_state.get("key1") == "value1"
    assert agent_state.get("key2") == "value2"
    assert agent_state.get() == seed
116 | 
117 | 
def test_agent_state_update_from_tool():
    """A tool that receives the agent can add to and overwrite the agent's state."""

    @tool
    def update_state(agent: Agent):
        agent.state.set("hello", "world")
        agent.state.set("foo", "baz")

    # First canned response requests the tool; the second ends the turn with text.
    tool_request = {
        "role": "assistant",
        "content": [{"toolUse": {"name": "update_state", "toolUseId": "123", "input": {}}}],
    }
    final_reply = {"role": "assistant", "content": [{"text": "I invoked a tool!"}]}
    agent_messages: Messages = [tool_request, final_reply]

    agent = Agent(
        model=MockedModelProvider(agent_messages),
        tools=[update_state],
        state={"foo": "bar"},
    )

    # Before the invocation only the initial state is present.
    assert agent.state.get("hello") is None
    assert agent.state.get("foo") == "bar"

    agent("Invoke Mocked!")

    assert agent.state.get("hello") == "world"
    assert agent.state.get("foo") == "baz"
146 | 


--------------------------------------------------------------------------------
/tests/strands/event_loop/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/strands-agents/sdk-python/f5e24d402bb22dc1a54c94dd030b4be0f7e73261/tests/strands/event_loop/__init__.py


--------------------------------------------------------------------------------
/tests/strands/experimental/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/strands-agents/sdk-python/f5e24d402bb22dc1a54c94dd030b4be0f7e73261/tests/strands/experimental/__init__.py


--------------------------------------------------------------------------------
/tests/strands/experimental/hooks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/strands-agents/sdk-python/f5e24d402bb22dc1a54c94dd030b4be0f7e73261/tests/strands/experimental/hooks/__init__.py


--------------------------------------------------------------------------------
/tests/strands/experimental/hooks/test_events.py:
--------------------------------------------------------------------------------
  1 | from unittest.mock import Mock
  2 | 
  3 | import pytest
  4 | 
  5 | from strands.experimental.hooks import AfterToolInvocationEvent, BeforeToolInvocationEvent
  6 | from strands.hooks import (
  7 |     AfterInvocationEvent,
  8 |     AgentInitializedEvent,
  9 |     BeforeInvocationEvent,
 10 |     MessageAddedEvent,
 11 | )
 12 | from strands.types.tools import ToolResult, ToolUse
 13 | 
 14 | 
 15 | @pytest.fixture
 16 | def agent():
 17 |     return Mock()
 18 | 
 19 | 
 20 | @pytest.fixture
 21 | def tool():
 22 |     tool = Mock()
 23 |     tool.tool_name = "test_tool"
 24 |     return tool
 25 | 
 26 | 
 27 | @pytest.fixture
 28 | def tool_use():
 29 |     return ToolUse(name="test_tool", toolUseId="123", input={"param": "value"})
 30 | 
 31 | 
 32 | @pytest.fixture
 33 | def tool_invocation_state():
 34 |     return {"param": "value"}
 35 | 
 36 | 
 37 | @pytest.fixture
 38 | def tool_result():
 39 |     return ToolResult(content=[{"text": "result"}], status="success", toolUseId="123")
 40 | 
 41 | 
 42 | @pytest.fixture
 43 | def initialized_event(agent):
 44 |     return AgentInitializedEvent(agent=agent)
 45 | 
 46 | 
 47 | @pytest.fixture
 48 | def start_request_event(agent):
 49 |     return BeforeInvocationEvent(agent=agent)
 50 | 
 51 | 
 52 | @pytest.fixture
 53 | def messaged_added_event(agent):
 54 |     return MessageAddedEvent(agent=agent, message=Mock())
 55 | 
 56 | 
 57 | @pytest.fixture
 58 | def end_request_event(agent):
 59 |     return AfterInvocationEvent(agent=agent)
 60 | 
 61 | 
 62 | @pytest.fixture
 63 | def before_tool_event(agent, tool, tool_use, tool_invocation_state):
 64 |     return BeforeToolInvocationEvent(
 65 |         agent=agent,
 66 |         selected_tool=tool,
 67 |         tool_use=tool_use,
 68 |         invocation_state=tool_invocation_state,
 69 |     )
 70 | 
 71 | 
 72 | @pytest.fixture
 73 | def after_tool_event(agent, tool, tool_use, tool_invocation_state, tool_result):
 74 |     return AfterToolInvocationEvent(
 75 |         agent=agent,
 76 |         selected_tool=tool,
 77 |         tool_use=tool_use,
 78 |         invocation_state=tool_invocation_state,
 79 |         result=tool_result,
 80 |     )
 81 | 
 82 | 
 83 | def test_event_should_reverse_callbacks(
 84 |     initialized_event,
 85 |     start_request_event,
 86 |     messaged_added_event,
 87 |     end_request_event,
 88 |     before_tool_event,
 89 |     after_tool_event,
 90 | ):
 91 |     # note that we ignore E712 (explicit booleans) for consistency/readability purposes
 92 | 
 93 |     assert initialized_event.should_reverse_callbacks == False  # noqa: E712
 94 | 
 95 |     assert messaged_added_event.should_reverse_callbacks == False  # noqa: E712
 96 | 
 97 |     assert start_request_event.should_reverse_callbacks == False  # noqa: E712
 98 |     assert end_request_event.should_reverse_callbacks == True  # noqa: E712
 99 | 
100 |     assert before_tool_event.should_reverse_callbacks == False  # noqa: E712
101 |     assert after_tool_event.should_reverse_callbacks == True  # noqa: E712
102 | 
103 | 
def test_message_added_event_cannot_write_properties(messaged_added_event):
    """Assigning to agent or message on a MessageAddedEvent raises AttributeError."""
    replacement_agent = Mock()
    with pytest.raises(AttributeError, match="Property agent is not writable"):
        messaged_added_event.agent = replacement_agent

    with pytest.raises(AttributeError, match="Property message is not writable"):
        messaged_added_event.message = {}
109 | 
110 | 
def test_before_tool_invocation_event_can_write_properties(before_tool_event):
    """selected_tool and tool_use on a BeforeToolInvocationEvent accept new values."""
    new_tool_use = ToolUse(name="new_tool", toolUseId="456", input={})

    before_tool_event.selected_tool = None  # Should not raise
    before_tool_event.tool_use = new_tool_use  # Should not raise

    # Reading the values back guards against writes that are silently ignored.
    assert before_tool_event.selected_tool is None
    assert before_tool_event.tool_use == new_tool_use
115 | 
116 | 
def test_before_tool_invocation_event_cannot_write_properties(before_tool_event):
    """Assigning to agent or invocation_state on a BeforeToolInvocationEvent raises AttributeError."""
    replacement_agent = Mock()
    with pytest.raises(AttributeError, match="Property agent is not writable"):
        before_tool_event.agent = replacement_agent

    with pytest.raises(AttributeError, match="Property invocation_state is not writable"):
        before_tool_event.invocation_state = {}
122 | 
123 | 
def test_after_tool_invocation_event_can_write_properties(after_tool_event):
    """result on an AfterToolInvocationEvent accepts a new value."""
    new_result = ToolResult(content=[{"text": "new result"}], status="success", toolUseId="456")

    after_tool_event.result = new_result  # Should not raise

    # Reading the value back guards against a write that is silently ignored.
    assert after_tool_event.result == new_result
127 | 
128 | 
def test_after_tool_invocation_event_cannot_write_properties(after_tool_event):
    """Every property other than result on an AfterToolInvocationEvent rejects assignment."""
    rejected_writes = [
        ("agent", Mock()),
        ("selected_tool", None),
        ("tool_use", ToolUse(name="new", toolUseId="456", input={})),
        ("invocation_state", {}),
        ("exception", Exception("test")),
    ]

    for property_name, new_value in rejected_writes:
        with pytest.raises(AttributeError, match=f"Property {property_name} is not writable"):
            setattr(after_tool_event, property_name, new_value)
140 | 


--------------------------------------------------------------------------------
/tests/strands/experimental/hooks/test_hook_registry.py:
--------------------------------------------------------------------------------
  1 | import unittest.mock
  2 | from dataclasses import dataclass
  3 | from typing import List
  4 | from unittest.mock import MagicMock, Mock
  5 | 
  6 | import pytest
  7 | 
  8 | from strands.hooks import HookEvent, HookProvider, HookRegistry
  9 | 
 10 | 
 11 | @dataclass
 12 | class TestEvent(HookEvent):
 13 |     @property
 14 |     def should_reverse_callbacks(self) -> bool:
 15 |         return False
 16 | 
 17 | 
 18 | @dataclass
 19 | class TestAfterEvent(HookEvent):
 20 |     @property
 21 |     def should_reverse_callbacks(self) -> bool:
 22 |         return True
 23 | 
 24 | 
 25 | class TestHookProvider(HookProvider):
 26 |     """Test hook provider for testing hook registry."""
 27 | 
 28 |     def __init__(self):
 29 |         self.registered = False
 30 | 
 31 |     def register_hooks(self, registry: HookRegistry) -> None:
 32 |         self.registered = True
 33 | 
 34 | 
 35 | @pytest.fixture
 36 | def hook_registry():
 37 |     return HookRegistry()
 38 | 
 39 | 
 40 | @pytest.fixture
 41 | def test_event():
 42 |     return TestEvent(agent=Mock())
 43 | 
 44 | 
 45 | @pytest.fixture
 46 | def test_after_event():
 47 |     return TestAfterEvent(agent=Mock())
 48 | 
 49 | 
 50 | def test_hook_registry_init():
 51 |     """Test that HookRegistry initializes with an empty callbacks dictionary."""
 52 |     registry = HookRegistry()
 53 |     assert registry._registered_callbacks == {}
 54 | 
 55 | 
 56 | def test_add_callback(hook_registry, test_event):
 57 |     """Test that callbacks can be added to the registry."""
 58 |     callback = unittest.mock.Mock()
 59 |     hook_registry.add_callback(TestEvent, callback)
 60 | 
 61 |     assert TestEvent in hook_registry._registered_callbacks
 62 |     assert callback in hook_registry._registered_callbacks[TestEvent]
 63 | 
 64 | 
 65 | def test_add_multiple_callbacks_same_event(hook_registry, test_event):
 66 |     """Test that multiple callbacks can be added for the same event type."""
 67 |     callback1 = unittest.mock.Mock()
 68 |     callback2 = unittest.mock.Mock()
 69 | 
 70 |     hook_registry.add_callback(TestEvent, callback1)
 71 |     hook_registry.add_callback(TestEvent, callback2)
 72 | 
 73 |     assert len(hook_registry._registered_callbacks[TestEvent]) == 2
 74 |     assert callback1 in hook_registry._registered_callbacks[TestEvent]
 75 |     assert callback2 in hook_registry._registered_callbacks[TestEvent]
 76 | 
 77 | 
 78 | def test_add_hook(hook_registry):
 79 |     """Test that hooks can be added to the registry."""
 80 |     hook_provider = MagicMock()
 81 |     hook_registry.add_hook(hook_provider)
 82 | 
 83 |     assert hook_provider.register_hooks.call_count == 1
 84 | 
 85 | 
 86 | def test_get_callbacks_for_normal_event(hook_registry, test_event):
 87 |     """Test that get_callbacks_for returns callbacks in the correct order for normal events."""
 88 |     callback1 = unittest.mock.Mock()
 89 |     callback2 = unittest.mock.Mock()
 90 | 
 91 |     hook_registry.add_callback(TestEvent, callback1)
 92 |     hook_registry.add_callback(TestEvent, callback2)
 93 | 
 94 |     callbacks = list(hook_registry.get_callbacks_for(test_event))
 95 | 
 96 |     assert len(callbacks) == 2
 97 |     assert callbacks[0] == callback1
 98 |     assert callbacks[1] == callback2
 99 | 
100 | 
101 | def test_get_callbacks_for_after_event(hook_registry, test_after_event):
102 |     """Test that get_callbacks_for returns callbacks in reverse order for after events."""
103 |     callback1 = Mock()
104 |     callback2 = Mock()
105 | 
106 |     hook_registry.add_callback(TestAfterEvent, callback1)
107 |     hook_registry.add_callback(TestAfterEvent, callback2)
108 | 
109 |     callbacks = list(hook_registry.get_callbacks_for(test_after_event))
110 | 
111 |     assert len(callbacks) == 2
112 |     assert callbacks[0] == callback2  # Reverse order
113 |     assert callbacks[1] == callback1  # Reverse order
114 | 
115 | 
116 | def test_invoke_callbacks(hook_registry, test_event):
117 |     """Test that invoke_callbacks calls all registered callbacks for an event."""
118 |     callback1 = Mock()
119 |     callback2 = Mock()
120 | 
121 |     hook_registry.add_callback(TestEvent, callback1)
122 |     hook_registry.add_callback(TestEvent, callback2)
123 | 
124 |     hook_registry.invoke_callbacks(test_event)
125 | 
126 |     callback1.assert_called_once_with(test_event)
127 |     callback2.assert_called_once_with(test_event)
128 | 
129 | 
130 | def test_invoke_callbacks_no_registered_callbacks(hook_registry, test_event):
131 |     """Test that invoke_callbacks doesn't fail when there are no registered callbacks."""
132 |     # No callbacks registered
133 |     hook_registry.invoke_callbacks(test_event)
134 |     # Test passes if no exception is raised
135 | 
136 | 
137 | def test_invoke_callbacks_after_event(hook_registry, test_after_event):
138 |     """Test that invoke_callbacks calls callbacks in reverse order for after events."""
139 |     call_order: List[str] = []
140 | 
141 |     def callback1(_event):
142 |         call_order.append("callback1")
143 | 
144 |     def callback2(_event):
145 |         call_order.append("callback2")
146 | 
147 |     hook_registry.add_callback(TestAfterEvent, callback1)
148 |     hook_registry.add_callback(TestAfterEvent, callback2)
149 | 
150 |     hook_registry.invoke_callbacks(test_after_event)
151 | 
152 |     assert call_order == ["callback2", "callback1"]  # Reverse order
153 | 
154 | 
155 | def test_has_callbacks(hook_registry, test_event):
156 |     """Test that has_callbacks returns correct boolean values."""
157 |     # Empty registry should return False
158 |     assert not hook_registry.has_callbacks()
159 | 
160 |     # Registry with callbacks should return True
161 |     callback = Mock()
162 |     hook_registry.add_callback(TestEvent, callback)
163 |     assert hook_registry.has_callbacks()
164 | 
165 |     # Test with multiple event types
166 |     hook_registry.add_callback(TestAfterEvent, Mock())
167 |     assert hook_registry.has_callbacks()
168 | 


--------------------------------------------------------------------------------
/tests/strands/handlers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/strands-agents/sdk-python/f5e24d402bb22dc1a54c94dd030b4be0f7e73261/tests/strands/handlers/__init__.py


--------------------------------------------------------------------------------
/tests/strands/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/strands-agents/sdk-python/f5e24d402bb22dc1a54c94dd030b4be0f7e73261/tests/strands/models/__init__.py


--------------------------------------------------------------------------------
/tests/strands/models/test_model.py:
--------------------------------------------------------------------------------
  1 | import pytest
  2 | from pydantic import BaseModel
  3 | 
  4 | from strands.models import Model as SAModel
  5 | 
  6 | 
  7 | class Person(BaseModel):
  8 |     name: str
  9 |     age: int
 10 | 
 11 | 
 12 | class TestModel(SAModel):
 13 |     def update_config(self, **model_config):
 14 |         return model_config
 15 | 
 16 |     def get_config(self):
 17 |         return
 18 | 
 19 |     async def structured_output(self, output_model, prompt=None, system_prompt=None, **kwargs):
 20 |         yield {"output": output_model(name="test", age=20)}
 21 | 
 22 |     async def stream(self, messages, tool_specs=None, system_prompt=None):
 23 |         yield {"messageStart": {"role": "assistant"}}
 24 |         yield {"contentBlockStart": {"start": {}}}
 25 |         yield {"contentBlockDelta": {"delta": {"text": f"Processed {len(messages)} messages"}}}
 26 |         yield {"contentBlockStop": {}}
 27 |         yield {"messageStop": {"stopReason": "end_turn"}}
 28 |         yield {
 29 |             "metadata": {
 30 |                 "usage": {"inputTokens": 10, "outputTokens": 15, "totalTokens": 25},
 31 |                 "metrics": {"latencyMs": 100},
 32 |             }
 33 |         }
 34 | 
 35 | 
 36 | @pytest.fixture
 37 | def model():
 38 |     return TestModel()
 39 | 
 40 | 
 41 | @pytest.fixture
 42 | def messages():
 43 |     return [
 44 |         {
 45 |             "role": "user",
 46 |             "content": [{"text": "hello"}],
 47 |         },
 48 |     ]
 49 | 
 50 | 
 51 | @pytest.fixture
 52 | def tool_specs():
 53 |     return [
 54 |         {
 55 |             "name": "test_tool",
 56 |             "description": "A test tool",
 57 |             "inputSchema": {
 58 |                 "json": {
 59 |                     "type": "object",
 60 |                     "properties": {
 61 |                         "input": {"type": "string"},
 62 |                     },
 63 |                     "required": ["input"],
 64 |                 },
 65 |             },
 66 |         },
 67 |     ]
 68 | 
 69 | 
 70 | @pytest.fixture
 71 | def system_prompt():
 72 |     return "s1"
 73 | 
 74 | 
 75 | @pytest.mark.asyncio
 76 | async def test_stream(model, messages, tool_specs, system_prompt, alist):
 77 |     response = model.stream(messages, tool_specs, system_prompt)
 78 | 
 79 |     tru_events = await alist(response)
 80 |     exp_events = [
 81 |         {"messageStart": {"role": "assistant"}},
 82 |         {"contentBlockStart": {"start": {}}},
 83 |         {"contentBlockDelta": {"delta": {"text": "Processed 1 messages"}}},
 84 |         {"contentBlockStop": {}},
 85 |         {"messageStop": {"stopReason": "end_turn"}},
 86 |         {
 87 |             "metadata": {
 88 |                 "usage": {"inputTokens": 10, "outputTokens": 15, "totalTokens": 25},
 89 |                 "metrics": {"latencyMs": 100},
 90 |             }
 91 |         },
 92 |     ]
 93 |     assert tru_events == exp_events
 94 | 
 95 | 
 96 | @pytest.mark.asyncio
 97 | async def test_structured_output(model, alist):
 98 |     response = model.structured_output(Person, prompt=messages, system_prompt=system_prompt)
 99 |     events = await alist(response)
100 | 
101 |     tru_output = events[-1]["output"]
102 |     exp_output = Person(name="test", age=20)
103 |     assert tru_output == exp_output
104 | 


--------------------------------------------------------------------------------
/tests/strands/multiagent/__init__.py:
--------------------------------------------------------------------------------
1 | """Tests for the multiagent module."""
2 | 


--------------------------------------------------------------------------------
/tests/strands/multiagent/a2a/__init__.py:
--------------------------------------------------------------------------------
1 | """Tests for the A2A module."""
2 | 


--------------------------------------------------------------------------------
/tests/strands/multiagent/a2a/conftest.py:
--------------------------------------------------------------------------------
 1 | """Common fixtures for A2A module tests."""
 2 | 
 3 | from unittest.mock import AsyncMock, MagicMock
 4 | 
 5 | import pytest
 6 | from a2a.server.agent_execution import RequestContext
 7 | from a2a.server.events import EventQueue
 8 | 
 9 | from strands.agent.agent import Agent as SAAgent
10 | from strands.agent.agent_result import AgentResult as SAAgentResult
11 | 
12 | 
13 | @pytest.fixture
14 | def mock_strands_agent():
15 |     """Create a mock Strands Agent for testing."""
16 |     agent = MagicMock(spec=SAAgent)
17 |     agent.name = "Test Agent"
18 |     agent.description = "A test agent for unit testing"
19 | 
20 |     # Setup default response
21 |     mock_result = MagicMock(spec=SAAgentResult)
22 |     mock_result.message = {"content": [{"text": "Test response"}]}
23 |     agent.return_value = mock_result
24 | 
25 |     # Setup async methods
26 |     agent.invoke_async = AsyncMock(return_value=mock_result)
27 |     agent.stream_async = AsyncMock(return_value=iter([]))
28 | 
29 |     # Setup mock tool registry
30 |     mock_tool_registry = MagicMock()
31 |     mock_tool_registry.get_all_tools_config.return_value = {}
32 |     agent.tool_registry = mock_tool_registry
33 | 
34 |     return agent
35 | 
36 | 
37 | @pytest.fixture
38 | def mock_request_context():
39 |     """Create a mock RequestContext for testing."""
40 |     context = MagicMock(spec=RequestContext)
41 |     context.get_user_input.return_value = "Test input"
42 |     return context
43 | 
44 | 
45 | @pytest.fixture
46 | def mock_event_queue():
47 |     """Create a mock EventQueue for testing."""
48 |     queue = MagicMock(spec=EventQueue)
49 |     queue.enqueue_event = AsyncMock()
50 |     return queue
51 | 


--------------------------------------------------------------------------------
/tests/strands/session/__init__.py:
--------------------------------------------------------------------------------
1 | """Tests for session management."""
2 | 


--------------------------------------------------------------------------------
/tests/strands/tools/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/strands-agents/sdk-python/f5e24d402bb22dc1a54c94dd030b4be0f7e73261/tests/strands/tools/__init__.py


--------------------------------------------------------------------------------
/tests/strands/tools/mcp/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/strands-agents/sdk-python/f5e24d402bb22dc1a54c94dd030b4be0f7e73261/tests/strands/tools/mcp/__init__.py


--------------------------------------------------------------------------------
/tests/strands/tools/mcp/test_mcp_agent_tool.py:
--------------------------------------------------------------------------------
 1 | from unittest.mock import MagicMock
 2 | 
 3 | import pytest
 4 | from mcp.types import Tool as MCPTool
 5 | 
 6 | from strands.tools.mcp import MCPAgentTool, MCPClient
 7 | 
 8 | 
 9 | @pytest.fixture
10 | def mock_mcp_tool():
11 |     mock_tool = MagicMock(spec=MCPTool)
12 |     mock_tool.name = "test_tool"
13 |     mock_tool.description = "A test tool"
14 |     mock_tool.inputSchema = {"type": "object", "properties": {}}
15 |     return mock_tool
16 | 
17 | 
18 | @pytest.fixture
19 | def mock_mcp_client():
20 |     mock_server = MagicMock(spec=MCPClient)
21 |     mock_server.call_tool_sync.return_value = {
22 |         "status": "success",
23 |         "toolUseId": "test-123",
24 |         "content": [{"text": "Success result"}],
25 |     }
26 |     return mock_server
27 | 
28 | 
29 | @pytest.fixture
30 | def mcp_agent_tool(mock_mcp_tool, mock_mcp_client):
31 |     return MCPAgentTool(mock_mcp_tool, mock_mcp_client)
32 | 
33 | 
34 | def test_tool_name(mcp_agent_tool, mock_mcp_tool):
35 |     assert mcp_agent_tool.tool_name == "test_tool"
36 |     assert mcp_agent_tool.tool_name == mock_mcp_tool.name
37 | 
38 | 
39 | def test_tool_type(mcp_agent_tool):
40 |     assert mcp_agent_tool.tool_type == "python"
41 | 
42 | 
43 | def test_tool_spec_with_description(mcp_agent_tool, mock_mcp_tool):
44 |     tool_spec = mcp_agent_tool.tool_spec
45 | 
46 |     assert tool_spec["name"] == "test_tool"
47 |     assert tool_spec["description"] == "A test tool"
48 |     assert tool_spec["inputSchema"]["json"] == {"type": "object", "properties": {}}
49 | 
50 | 
51 | def test_tool_spec_without_description(mock_mcp_tool, mock_mcp_client):
52 |     mock_mcp_tool.description = None
53 | 
54 |     agent_tool = MCPAgentTool(mock_mcp_tool, mock_mcp_client)
55 |     tool_spec = agent_tool.tool_spec
56 | 
57 |     assert tool_spec["description"] == "Tool which performs test_tool"
58 | 
59 | 
60 | @pytest.mark.asyncio
61 | async def test_stream(mcp_agent_tool, mock_mcp_client, alist):
62 |     tool_use = {"toolUseId": "test-123", "name": "test_tool", "input": {"param": "value"}}
63 | 
64 |     tru_events = await alist(mcp_agent_tool.stream(tool_use, {}))
65 |     exp_events = [mock_mcp_client.call_tool_async.return_value]
66 | 
67 |     assert tru_events == exp_events
68 |     mock_mcp_client.call_tool_async.assert_called_once_with(
69 |         tool_use_id="test-123", name="test_tool", arguments={"param": "value"}
70 |     )
71 | 


--------------------------------------------------------------------------------
/tests/strands/tools/test_registry.py:
--------------------------------------------------------------------------------
 1 | """
 2 | Tests for the SDK tool registry module.
 3 | """
 4 | 
 5 | from unittest.mock import MagicMock
 6 | 
 7 | import pytest
 8 | 
 9 | import strands
10 | from strands.tools import PythonAgentTool
11 | from strands.tools.decorator import DecoratedFunctionTool, tool
12 | from strands.tools.registry import ToolRegistry
13 | 
14 | 
15 | def test_load_tool_from_filepath_failure():
16 |     """Test error handling when load_tool fails."""
17 |     tool_registry = ToolRegistry()
18 |     error_message = "Failed to load tool failing_tool: Tool file not found: /path/to/failing_tool.py"
19 | 
20 |     with pytest.raises(ValueError, match=error_message):
21 |         tool_registry.load_tool_from_filepath("failing_tool", "/path/to/failing_tool.py")
22 | 
23 | 
24 | def test_process_tools_with_invalid_path():
25 |     """Test that process_tools raises an exception when a non-path string is passed."""
26 |     tool_registry = ToolRegistry()
27 |     invalid_path = "not a filepath"
28 | 
29 |     with pytest.raises(ValueError, match=f"Failed to load tool {invalid_path.split('.')[0]}: Tool file not found:.*"):
30 |         tool_registry.process_tools([invalid_path])
31 | 
32 | 
33 | def test_register_tool_with_similar_name_raises():
34 |     tool_1 = PythonAgentTool(tool_name="tool-like-this", tool_spec=MagicMock(), tool_func=lambda: None)
35 |     tool_2 = PythonAgentTool(tool_name="tool_like_this", tool_spec=MagicMock(), tool_func=lambda: None)
36 | 
37 |     tool_registry = ToolRegistry()
38 | 
39 |     tool_registry.register_tool(tool_1)
40 | 
41 |     with pytest.raises(ValueError) as err:
42 |         tool_registry.register_tool(tool_2)
43 | 
44 |     assert (
45 |         str(err.value) == "Tool name 'tool_like_this' already exists as 'tool-like-this'. "
46 |         "Cannot add a duplicate tool which differs by a '-' or '_'"
47 |     )
48 | 
49 | 
50 | def test_get_all_tool_specs_returns_right_tool_specs():
51 |     tool_1 = strands.tool(lambda a: a, name="tool_1")
52 |     tool_2 = strands.tool(lambda b: b, name="tool_2")
53 | 
54 |     tool_registry = ToolRegistry()
55 | 
56 |     tool_registry.register_tool(tool_1)
57 |     tool_registry.register_tool(tool_2)
58 | 
59 |     tool_specs = tool_registry.get_all_tool_specs()
60 | 
61 |     assert tool_specs == [
62 |         tool_1.tool_spec,
63 |         tool_2.tool_spec,
64 |     ]
65 | 
66 | 
67 | def test_scan_module_for_tools():
68 |     @tool
69 |     def tool_function_1(a):
70 |         return a
71 | 
72 |     @tool
73 |     def tool_function_2(b):
74 |         return b
75 | 
76 |     def tool_function_3(c):
77 |         return c
78 | 
79 |     def tool_function_4(d):
80 |         return d
81 | 
82 |     tool_function_4.tool_spec = "invalid"
83 | 
84 |     mock_module = MagicMock()
85 |     mock_module.tool_function_1 = tool_function_1
86 |     mock_module.tool_function_2 = tool_function_2
87 |     mock_module.tool_function_3 = tool_function_3
88 |     mock_module.tool_function_4 = tool_function_4
89 | 
90 |     tool_registry = ToolRegistry()
91 | 
92 |     tools = tool_registry._scan_module_for_tools(mock_module)
93 | 
94 |     assert len(tools) == 2
95 |     assert all(isinstance(tool, DecoratedFunctionTool) for tool in tools)
96 | 


--------------------------------------------------------------------------------
/tests/strands/tools/test_watcher.py:
--------------------------------------------------------------------------------
 1 | """
 2 | Tests for the SDK tool watcher module.
 3 | """
 4 | 
 5 | from unittest.mock import MagicMock, patch
 6 | 
 7 | import pytest
 8 | 
 9 | from strands.tools.registry import ToolRegistry
10 | from strands.tools.watcher import ToolWatcher
11 | 
12 | 
13 | def test_tool_watcher_initialization():
14 |     """Test that the handler initializes with the correct tool registry."""
15 |     tool_registry = ToolRegistry()
16 |     watcher = ToolWatcher(tool_registry)
17 |     assert watcher.tool_registry == tool_registry
18 | 
19 | 
20 | @pytest.mark.parametrize(
21 |     "test_case",
22 |     [
23 |         # Regular Python file - should reload
24 |         {
25 |             "description": "Python file",
26 |             "src_path": "/path/to/test_tool.py",
27 |             "is_directory": False,
28 |             "should_reload": True,
29 |             "expected_tool_name": "test_tool",
30 |         },
31 |         # Non-Python file - should not reload
32 |         {
33 |             "description": "Non-Python file",
34 |             "src_path": "/path/to/test_tool.txt",
35 |             "is_directory": False,
36 |             "should_reload": False,
37 |         },
38 |         # __init__.py file - should not reload
39 |         {
40 |             "description": "Init file",
41 |             "src_path": "/path/to/__init__.py",
42 |             "is_directory": False,
43 |             "should_reload": False,
44 |         },
45 |         # Directory path - should not reload
46 |         {
47 |             "description": "Directory path",
48 |             "src_path": "/path/to/tools_directory",
49 |             "is_directory": True,
50 |             "should_reload": False,
51 |         },
52 |         # Python file marked as directory - should still reload
53 |         {
54 |             "description": "Python file marked as directory",
55 |             "src_path": "/path/to/test_tool2.py",
56 |             "is_directory": True,
57 |             "should_reload": True,
58 |             "expected_tool_name": "test_tool2",
59 |         },
60 |     ],
61 | )
62 | @patch.object(ToolRegistry, "reload_tool")
63 | def test_on_modified_cases(mock_reload_tool, test_case):
64 |     """Test various cases for the on_modified method."""
65 |     tool_registry = ToolRegistry()
66 |     watcher = ToolWatcher(tool_registry)
67 | 
68 |     # Create a mock event with the specified properties
69 |     event = MagicMock()
70 |     event.src_path = test_case["src_path"]
71 |     if "is_directory" in test_case:
72 |         event.is_directory = test_case["is_directory"]
73 | 
74 |     # Call the on_modified method
75 |     watcher.tool_change_handler.on_modified(event)
76 | 
77 |     # Verify the expected behavior
78 |     if test_case["should_reload"]:
79 |         mock_reload_tool.assert_called_once_with(test_case["expected_tool_name"])
80 |     else:
81 |         mock_reload_tool.assert_not_called()
82 | 
83 | 
84 | @patch.object(ToolRegistry, "reload_tool", side_effect=Exception("Test error"))
85 | def test_on_modified_error_handling(mock_reload_tool):
86 |     """Test that on_modified handles errors during tool reloading."""
87 |     tool_registry = ToolRegistry()
88 |     watcher = ToolWatcher(tool_registry)
89 | 
90 |     # Create a mock event with a Python file path
91 |     event = MagicMock()
92 |     event.src_path = "/path/to/test_tool.py"
93 | 
94 |     # Call the on_modified method - should not raise an exception
95 |     watcher.tool_change_handler.on_modified(event)
96 | 
97 |     # Verify that reload_tool was called
98 |     mock_reload_tool.assert_called_once_with("test_tool")
99 | 


--------------------------------------------------------------------------------
/tests/strands/types/test_session.py:
--------------------------------------------------------------------------------
 1 | import json
 2 | from uuid import uuid4
 3 | 
 4 | from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager
 5 | from strands.types.session import (
 6 |     Session,
 7 |     SessionAgent,
 8 |     SessionMessage,
 9 |     SessionType,
10 |     decode_bytes_values,
11 |     encode_bytes_values,
12 | )
13 | 
14 | 
15 | def test_session_json_serializable():
16 |     session = Session(session_id=str(uuid4()), session_type=SessionType.AGENT)
17 |     # json.dumps raises if the dict produced by to_dict() is not JSON serializable
18 |     session_json_string = json.dumps(session.to_dict())
19 |     loaded_session = Session.from_dict(json.loads(session_json_string))
20 |     assert loaded_session is not None
21 | 
22 | 
23 | def test_agent_json_serializable():
24 |     agent = SessionAgent(
25 |         agent_id=str(uuid4()), state={"foo": "bar"}, conversation_manager_state=NullConversationManager().get_state()
26 |     )
27 |     # json.dumps raises if the dict produced by to_dict() is not JSON serializable
28 |     agent_json_string = json.dumps(agent.to_dict())
29 |     loaded_agent = SessionAgent.from_dict(json.loads(agent_json_string))
30 |     assert loaded_agent is not None
31 | 
32 | 
33 | def test_message_json_serializable():
34 |     message = SessionMessage(message={"role": "user", "content": [{"text": "Hello!"}]}, message_id=0)
35 |     # json.dumps raises if the dict produced by to_dict() is not JSON serializable
36 |     message_json_string = json.dumps(message.to_dict())
37 |     loaded_message = SessionMessage.from_dict(json.loads(message_json_string))
38 |     assert loaded_message is not None
39 | 
40 | 
41 | def test_bytes_encoding_decoding():
42 |     # Test simple bytes
43 |     test_bytes = b"Hello, world!"
44 |     encoded = encode_bytes_values(test_bytes)
45 |     assert isinstance(encoded, dict)
46 |     assert encoded["__bytes_encoded__"] is True
47 |     decoded = decode_bytes_values(encoded)
48 |     assert decoded == test_bytes
49 | 
50 |     # Test nested structure with bytes
51 |     test_data = {
52 |         "text": "Hello",
53 |         "binary": b"Binary data",
54 |         "nested": {"more_binary": b"More binary data", "list_with_binary": [b"Item 1", "Text item", b"Item 3"]},
55 |     }
56 | 
57 |     encoded = encode_bytes_values(test_data)
58 |     # Verify it's JSON serializable
59 |     json_str = json.dumps(encoded)
60 |     # Deserialize and decode
61 |     decoded = decode_bytes_values(json.loads(json_str))
62 | 
63 |     # Verify the decoded data matches the original
64 |     assert decoded["text"] == test_data["text"]
65 |     assert decoded["binary"] == test_data["binary"]
66 |     assert decoded["nested"]["more_binary"] == test_data["nested"]["more_binary"]
67 |     assert decoded["nested"]["list_with_binary"][0] == test_data["nested"]["list_with_binary"][0]
68 |     assert decoded["nested"]["list_with_binary"][1] == test_data["nested"]["list_with_binary"][1]
69 |     assert decoded["nested"]["list_with_binary"][2] == test_data["nested"]["list_with_binary"][2]
70 | 
71 | 
72 | def test_session_message_with_bytes():
73 |     # Create a message with bytes content
74 |     message = {
75 |         "role": "user",
76 |         "content": [{"text": "Here is some binary data"}, {"binary_data": b"This is binary data"}],
77 |     }
78 | 
79 |     # Create a SessionMessage
80 |     session_message = SessionMessage.from_message(message, 0)
81 | 
82 |     # Verify it's JSON serializable
83 |     message_json_string = json.dumps(session_message.to_dict())
84 | 
85 |     # Load it back
86 |     loaded_message = SessionMessage.from_dict(json.loads(message_json_string))
87 | 
88 |     # Convert back to original message and verify
89 |     original_message = loaded_message.to_message()
90 | 
91 |     assert original_message["role"] == message["role"]
92 |     assert original_message["content"][0]["text"] == message["content"][0]["text"]
93 |     assert original_message["content"][1]["binary_data"] == message["content"][1]["binary_data"]
94 | 


--------------------------------------------------------------------------------
/tests_integ/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/strands-agents/sdk-python/f5e24d402bb22dc1a54c94dd030b4be0f7e73261/tests_integ/__init__.py


--------------------------------------------------------------------------------
/tests_integ/conftest.py:
--------------------------------------------------------------------------------
 1 | import pytest
 2 | 
 3 | ## Data
 4 | 
 5 | 
 6 | @pytest.fixture
 7 | def yellow_img(pytestconfig):
 8 |     path = pytestconfig.rootdir / "tests_integ/yellow.png"
 9 |     with open(path, "rb") as fp:
10 |         return fp.read()
11 | 
12 | 
13 | ## Async
14 | 
15 | 
16 | @pytest.fixture(scope="session")
17 | def agenerator():
18 |     async def agenerator(items):
19 |         for item in items:
20 |             yield item
21 | 
22 |     return agenerator
23 | 
24 | 
25 | @pytest.fixture(scope="session")
26 | def alist():
27 |     async def alist(items):
28 |         return [item async for item in items]
29 | 
30 |     return alist
31 | 


--------------------------------------------------------------------------------
/tests_integ/echo_server.py:
--------------------------------------------------------------------------------
 1 | """
 2 | Echo Server for MCP Integration Testing
 3 | 
 4 | This module implements a simple echo server using the Model Context Protocol (MCP).
 5 | It provides a basic tool that echoes back any input string, which is useful for
 6 | testing the MCP communication flow and validating that messages are properly
 7 | transmitted between the client and server.
 8 | 
 9 | The server runs with stdio transport, making it suitable for integration tests
10 | where the client can spawn this process and communicate with it through standard
11 | input/output streams.
12 | 
13 | Usage:
14 |     Run this file directly to start the echo server:
15 |     $ python echo_server.py
16 | """
17 | 
18 | from mcp.server import FastMCP
19 | 
20 | 
21 | def start_echo_server():
22 |     """
23 |     Initialize and start the MCP echo server.
24 | 
25 |     Creates a FastMCP server instance with a single 'echo' tool that returns
26 |     any input string back to the caller. The server uses stdio transport
27 |     for communication.
28 |     """
29 |     mcp = FastMCP("Echo Server")
30 | 
31 |     @mcp.tool(description="Echos response back to the user")
32 |     def echo(to_echo: str) -> str:
33 |         return to_echo
34 | 
35 |     mcp.run(transport="stdio")
36 | 
37 | 
38 | if __name__ == "__main__":
39 |     start_echo_server()
40 | 


--------------------------------------------------------------------------------
/tests_integ/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/strands-agents/sdk-python/f5e24d402bb22dc1a54c94dd030b4be0f7e73261/tests_integ/models/__init__.py


--------------------------------------------------------------------------------
/tests_integ/models/conformance.py:
--------------------------------------------------------------------------------
 1 | import pytest
 2 | 
 3 | from strands.types.models import Model
 4 | from tests_integ.models.providers import ProviderInfo, all_providers
 5 | 
 6 | 
 7 | def get_models():
 8 |     return [
 9 |         pytest.param(
10 |             provider_info,
11 |             id=provider_info.id,  # Adds the provider name to the test name
12 |             marks=[provider_info.mark],  # ignores tests that don't have the requirements
13 |         )
14 |         for provider_info in all_providers
15 |     ]
16 | 
17 | 
18 | @pytest.fixture(params=get_models())
19 | def provider_info(request) -> ProviderInfo:
20 |     return request.param
21 | 
22 | 
23 | @pytest.fixture()
24 | def model(provider_info):
25 |     return provider_info.create_model()
26 | 
27 | 
28 | def test_model_can_be_constructed(model: Model):
29 |     assert model is not None
30 |     pass
31 | 


--------------------------------------------------------------------------------
/tests_integ/models/providers.py:
--------------------------------------------------------------------------------
  1 | """
  2 | Aggregates all providers for testing all providers in one go.
  3 | """
  4 | 
  5 | import os
  6 | from typing import Callable, Optional
  7 | 
  8 | import requests
  9 | from pytest import mark
 10 | 
 11 | from strands.models import BedrockModel, Model
 12 | from strands.models.anthropic import AnthropicModel
 13 | from strands.models.litellm import LiteLLMModel
 14 | from strands.models.llamaapi import LlamaAPIModel
 15 | from strands.models.mistral import MistralModel
 16 | from strands.models.ollama import OllamaModel
 17 | from strands.models.openai import OpenAIModel
 18 | from strands.models.writer import WriterModel
 19 | 
 20 | 
 21 | class ProviderInfo:
 22 |     """Provider-based info for providers that require an APIKey via environment variables."""
 23 | 
 24 |     def __init__(
 25 |         self,
 26 |         id: str,
 27 |         factory: Callable[[], Model],
 28 |         environment_variable: Optional[str] = None,
 29 |     ) -> None:
 30 |         self.id = id
 31 |         self.model_factory = factory
 32 |         self.mark = mark.skipif(
 33 |             environment_variable is not None and environment_variable not in os.environ,
 34 |             reason=f"{environment_variable} environment variable missing",
 35 |         )
 36 | 
 37 |     def create_model(self) -> Model:
 38 |         return self.model_factory()
 39 | 
 40 | 
 41 | class OllamaProviderInfo(ProviderInfo):
 42 |     """Special case ollama as it's dependent on the server being available."""
 43 | 
 44 |     def __init__(self):
 45 |         super().__init__(
 46 |             id="ollama", factory=lambda: OllamaModel(host="http://localhost:11434", model_id="llama3.3:70b")
 47 |         )
 48 | 
 49 |         is_server_available = False
 50 |         try:
 51 |             is_server_available = requests.get("http://localhost:11434").ok
 52 |         except requests.exceptions.ConnectionError:
 53 |             pass
 54 | 
 55 |         self.mark = mark.skipif(
 56 |             not is_server_available,
 57 |             reason="Local Ollama endpoint not available at localhost:11434",
 58 |         )
 59 | 
 60 | 
 61 | anthropic = ProviderInfo(
 62 |     id="anthropic",
 63 |     environment_variable="ANTHROPIC_API_KEY",
 64 |     factory=lambda: AnthropicModel(
 65 |         client_args={
 66 |             "api_key": os.getenv("ANTHROPIC_API_KEY"),
 67 |         },
 68 |         model_id="claude-3-7-sonnet-20250219",
 69 |         max_tokens=512,
 70 |     ),
 71 | )
 72 | bedrock = ProviderInfo(id="bedrock", factory=lambda: BedrockModel())
 73 | cohere = ProviderInfo(
 74 |     id="cohere",
 75 |     environment_variable="CO_API_KEY",
 76 |     factory=lambda: OpenAIModel(
 77 |         client_args={
 78 |             "base_url": "https://api.cohere.com/compatibility/v1",
 79 |             "api_key": os.getenv("CO_API_KEY"),
 80 |         },
 81 |         model_id="command-a-03-2025",
 82 |         params={"stream_options": None},
 83 |     ),
 84 | )
 85 | litellm = ProviderInfo(
 86 |     id="litellm", factory=lambda: LiteLLMModel(model_id="bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0")
 87 | )
 88 | llama = ProviderInfo(
 89 |     id="llama",
 90 |     environment_variable="LLAMA_API_KEY",
 91 |     factory=lambda: LlamaAPIModel(
 92 |         model_id="Llama-4-Maverick-17B-128E-Instruct-FP8",
 93 |         client_args={
 94 |             "api_key": os.getenv("LLAMA_API_KEY"),
 95 |         },
 96 |     ),
 97 | )
 98 | mistral = ProviderInfo(
 99 |     id="mistral",
100 |     environment_variable="MISTRAL_API_KEY",
101 |     factory=lambda: MistralModel(
102 |         model_id="mistral-medium-latest",
103 |         api_key=os.getenv("MISTRAL_API_KEY"),
104 |         stream=True,
105 |         temperature=0.7,
106 |         max_tokens=1000,
107 |         top_p=0.9,
108 |     ),
109 | )
110 | openai = ProviderInfo(
111 |     id="openai",
112 |     environment_variable="OPENAI_API_KEY",
113 |     factory=lambda: OpenAIModel(
114 |         model_id="gpt-4o",
115 |         client_args={
116 |             "api_key": os.getenv("OPENAI_API_KEY"),
117 |         },
118 |     ),
119 | )
120 | writer = ProviderInfo(
121 |     id="writer",
122 |     environment_variable="WRITER_API_KEY",
123 |     factory=lambda: WriterModel(
124 |         model_id="palmyra-x4",
125 |         client_args={"api_key": os.getenv("WRITER_API_KEY", "")},
126 |         stream_options={"include_usage": True},
127 |     ),
128 | )
129 | 
130 | ollama = OllamaProviderInfo()  # NOTE(review): not listed in all_providers below - confirm this is intentional
131 | 
132 | 
133 | all_providers = [
134 |     bedrock,
135 |     anthropic,
136 |     cohere,
137 |     llama,
138 |     litellm,
139 |     mistral,
140 |     openai,
141 |     writer,
142 | ]
143 | 


--------------------------------------------------------------------------------
/tests_integ/models/test_model_anthropic.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | 
  3 | import pydantic
  4 | import pytest
  5 | 
  6 | import strands
  7 | from strands import Agent
  8 | from strands.models.anthropic import AnthropicModel
  9 | from tests_integ.models import providers
 10 | 
 11 | # these tests only run if we have the anthropic api key
 12 | pytestmark = providers.anthropic.mark
 13 | 
 14 | 
 15 | @pytest.fixture
 16 | def model():
 17 |     return AnthropicModel(
 18 |         client_args={
 19 |             "api_key": os.getenv("ANTHROPIC_API_KEY"),
 20 |         },
 21 |         model_id="claude-3-7-sonnet-20250219",
 22 |         max_tokens=512,
 23 |     )
 24 | 
 25 | 
 26 | @pytest.fixture
 27 | def tools():
 28 |     @strands.tool
 29 |     def tool_time() -> str:
 30 |         return "12:00"
 31 | 
 32 |     @strands.tool
 33 |     def tool_weather() -> str:
 34 |         return "sunny"
 35 | 
 36 |     return [tool_time, tool_weather]
 37 | 
 38 | 
 39 | @pytest.fixture
 40 | def system_prompt():
 41 |     return "You are an AI assistant."
 42 | 
 43 | 
 44 | @pytest.fixture
 45 | def agent(model, tools, system_prompt):
 46 |     return Agent(model=model, tools=tools, system_prompt=system_prompt)
 47 | 
 48 | 
 49 | @pytest.fixture
 50 | def weather():
 51 |     class Weather(pydantic.BaseModel):
 52 |         """Extracts the time and weather from the user's message with the exact strings."""
 53 | 
 54 |         time: str
 55 |         weather: str
 56 | 
 57 |     return Weather(time="12:00", weather="sunny")
 58 | 
 59 | 
 60 | @pytest.fixture
 61 | def yellow_color():
 62 |     class Color(pydantic.BaseModel):
 63 |         """Describes a color."""
 64 | 
 65 |         name: str
 66 | 
 67 |         @pydantic.field_validator("name", mode="after")
 68 |         @classmethod
 69 |         def lower(_, value):
 70 |             return value.lower()
 71 | 
 72 |     return Color(name="yellow")
 73 | 
 74 | 
 75 | def test_agent_invoke(agent):
 76 |     result = agent("What is the time and weather in New York?")
 77 |     text = result.message["content"][0]["text"].lower()
 78 | 
 79 |     assert all(string in text for string in ["12:00", "sunny"])
 80 | 
 81 | 
 82 | @pytest.mark.asyncio
 83 | async def test_agent_invoke_async(agent):
 84 |     result = await agent.invoke_async("What is the time and weather in New York?")
 85 |     text = result.message["content"][0]["text"].lower()
 86 | 
 87 |     assert all(string in text for string in ["12:00", "sunny"])
 88 | 
 89 | 
 90 | @pytest.mark.asyncio
 91 | async def test_agent_stream_async(agent):
 92 |     stream = agent.stream_async("What is the time and weather in New York?")
 93 |     async for event in stream:
 94 |         _ = event
 95 | 
 96 |     result = event["result"]
 97 |     text = result.message["content"][0]["text"].lower()
 98 | 
 99 |     assert all(string in text for string in ["12:00", "sunny"])
100 | 
101 | 
102 | def test_structured_output(agent, weather):
103 |     tru_weather = agent.structured_output(type(weather), "The time is 12:00 and the weather is sunny")
104 |     exp_weather = weather
105 |     assert tru_weather == exp_weather
106 | 
107 | 
108 | @pytest.mark.asyncio
109 | async def test_agent_structured_output_async(agent, weather):
110 |     tru_weather = await agent.structured_output_async(type(weather), "The time is 12:00 and the weather is sunny")
111 |     exp_weather = weather
112 |     assert tru_weather == exp_weather
113 | 
114 | 
115 | def test_invoke_multi_modal_input(agent, yellow_img):
116 |     content = [
117 |         {"text": "what is in this image"},
118 |         {
119 |             "image": {
120 |                 "format": "png",
121 |                 "source": {
122 |                     "bytes": yellow_img,
123 |                 },
124 |             },
125 |         },
126 |     ]
127 |     result = agent(content)
128 |     text = result.message["content"][0]["text"].lower()
129 | 
130 |     assert "yellow" in text
131 | 
132 | 
133 | def test_structured_output_multi_modal_input(agent, yellow_img, yellow_color):
134 |     content = [
135 |         {"text": "Is this image red, blue, or yellow?"},
136 |         {
137 |             "image": {
138 |                 "format": "png",
139 |                 "source": {
140 |                     "bytes": yellow_img,
141 |                 },
142 |             },
143 |         },
144 |     ]
145 |     tru_color = agent.structured_output(type(yellow_color), content)
146 |     exp_color = yellow_color
147 |     assert tru_color == exp_color
148 | 


--------------------------------------------------------------------------------
/tests_integ/models/test_model_bedrock.py:
--------------------------------------------------------------------------------
  1 | import pydantic
  2 | import pytest
  3 | 
  4 | import strands
  5 | from strands import Agent
  6 | from strands.models import BedrockModel
  7 | 
  8 | 
  9 | @pytest.fixture
 10 | def system_prompt():
 11 |     return "You are an AI assistant that uses & instead of ."
 12 | 
 13 | 
 14 | @pytest.fixture
 15 | def streaming_model():
 16 |     return BedrockModel(
 17 |         streaming=True,
 18 |     )
 19 | 
 20 | 
 21 | @pytest.fixture
 22 | def non_streaming_model():
 23 |     return BedrockModel(
 24 |         streaming=False,
 25 |     )
 26 | 
 27 | 
 28 | @pytest.fixture
 29 | def streaming_agent(streaming_model, system_prompt):
 30 |     return Agent(model=streaming_model, system_prompt=system_prompt, load_tools_from_directory=False)
 31 | 
 32 | 
 33 | @pytest.fixture
 34 | def non_streaming_agent(non_streaming_model, system_prompt):
 35 |     return Agent(model=non_streaming_model, system_prompt=system_prompt, load_tools_from_directory=False)
 36 | 
 37 | 
 38 | @pytest.fixture
 39 | def yellow_color():
 40 |     class Color(pydantic.BaseModel):
 41 |         """Describes a color."""
 42 | 
 43 |         name: str
 44 | 
 45 |         @pydantic.field_validator("name", mode="after")
 46 |         @classmethod
 47 |         def lower(_, value):
 48 |             return value.lower()
 49 | 
 50 |     return Color(name="yellow")
 51 | 
 52 | 
 53 | def test_streaming_agent(streaming_agent):
 54 |     """Test agent with streaming model."""
 55 |     result = streaming_agent("Hello!")
 56 | 
 57 |     assert len(str(result)) > 0
 58 | 
 59 | 
 60 | def test_non_streaming_agent(non_streaming_agent):
 61 |     """Test agent with non-streaming model."""
 62 |     result = non_streaming_agent("Hello!")
 63 | 
 64 |     assert len(str(result)) > 0
 65 | 
 66 | 
 67 | @pytest.mark.asyncio
 68 | async def test_streaming_model_events(streaming_model, alist):
 69 |     """Test streaming model events."""
 70 |     messages = [{"role": "user", "content": [{"text": "Hello"}]}]
 71 | 
 72 |     # Call stream and collect events
 73 |     events = await alist(streaming_model.stream(messages))
 74 | 
 75 |     # Verify basic structure of events
 76 |     assert any("messageStart" in event for event in events)
 77 |     assert any("contentBlockDelta" in event for event in events)
 78 |     assert any("messageStop" in event for event in events)
 79 | 
 80 | 
 81 | @pytest.mark.asyncio
 82 | async def test_non_streaming_model_events(non_streaming_model, alist):
 83 |     """Test non-streaming model events."""
 84 |     messages = [{"role": "user", "content": [{"text": "Hello"}]}]
 85 | 
 86 |     # Call stream and collect events
 87 |     events = await alist(non_streaming_model.stream(messages))
 88 | 
 89 |     # Verify basic structure of events
 90 |     assert any("messageStart" in event for event in events)
 91 |     assert any("contentBlockDelta" in event for event in events)
 92 |     assert any("messageStop" in event for event in events)
 93 | 
 94 | 
 95 | def test_tool_use_streaming(streaming_model):
 96 |     """Test tool use with streaming model."""
 97 | 
 98 |     tool_was_called = False
 99 | 
100 |     @strands.tool
101 |     def calculator(expression: str) -> float:
102 |         """Calculate the result of a mathematical expression."""
103 | 
104 |         nonlocal tool_was_called
105 |         tool_was_called = True
106 |         return eval(expression)
107 | 
108 |     agent = Agent(model=streaming_model, tools=[calculator], load_tools_from_directory=False)
109 |     result = agent("What is 123 + 456?")
110 | 
111 |     # Print the full message content for debugging
112 |     print("\nFull message content:")
113 |     import json
114 | 
115 |     print(json.dumps(result.message["content"], indent=2))
116 | 
117 |     assert tool_was_called
118 | 
119 | 
120 | def test_tool_use_non_streaming(non_streaming_model):
121 |     """Test tool use with non-streaming model."""
122 | 
123 |     tool_was_called = False
124 | 
125 |     @strands.tool
126 |     def calculator(expression: str) -> float:
127 |         """Calculate the result of a mathematical expression."""
128 | 
129 |         nonlocal tool_was_called
130 |         tool_was_called = True
131 |         return eval(expression)  # NOTE(review): eval() of presumably model-supplied text; keep confined to tests
132 | 
133 |     agent = Agent(model=non_streaming_model, tools=[calculator], load_tools_from_directory=False)
134 |     agent("What is 123 + 456?")
135 | 
136 |     assert tool_was_called
137 | 
138 | 
139 | def test_structured_output_streaming(streaming_model):
140 |     """Test structured output with streaming model."""
141 | 
142 |     class Weather(pydantic.BaseModel):
143 |         time: str
144 |         weather: str
145 | 
146 |     agent = Agent(model=streaming_model)
147 | 
148 |     result = agent.structured_output(Weather, "The time is 12:00 and the weather is sunny")
149 |     assert isinstance(result, Weather)
150 |     assert result.time == "12:00"
151 |     assert result.weather == "sunny"
152 | 
153 | 
154 | def test_structured_output_non_streaming(non_streaming_model):
155 |     """Test structured output with non-streaming model."""
156 | 
157 |     class Weather(pydantic.BaseModel):
158 |         time: str
159 |         weather: str
160 | 
161 |     agent = Agent(model=non_streaming_model)
162 | 
163 |     result = agent.structured_output(Weather, "The time is 12:00 and the weather is sunny")
164 |     assert isinstance(result, Weather)
165 |     assert result.time == "12:00"
166 |     assert result.weather == "sunny"
167 | 
168 | 
169 | def test_invoke_multi_modal_input(streaming_agent, yellow_img):
170 |     content = [
171 |         {"text": "what is in this image"},
172 |         {
173 |             "image": {
174 |                 "format": "png",
175 |                 "source": {
176 |                     "bytes": yellow_img,
177 |                 },
178 |             },
179 |         },
180 |     ]
181 |     result = streaming_agent(content)
182 |     text = result.message["content"][0]["text"].lower()
183 | 
184 |     assert "yellow" in text
185 | 
186 | 
187 | def test_structured_output_multi_modal_input(streaming_agent, yellow_img, yellow_color):
188 |     content = [
189 |         {"text": "Is this image red, blue, or yellow?"},
190 |         {
191 |             "image": {
192 |                 "format": "png",
193 |                 "source": {
194 |                     "bytes": yellow_img,
195 |                 },
196 |             },
197 |         },
198 |     ]
199 |     tru_color = streaming_agent.structured_output(type(yellow_color), content)
200 |     exp_color = yellow_color
201 |     assert tru_color == exp_color
202 | 


--------------------------------------------------------------------------------
/tests_integ/models/test_model_cohere.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | 
 3 | import pytest
 4 | 
 5 | import strands
 6 | from strands import Agent
 7 | from strands.models.openai import OpenAIModel
 8 | from tests_integ.models import providers
 9 | 
10 | # these tests only run if we have the cohere api key
11 | pytestmark = providers.cohere.mark
12 | 
13 | 
14 | @pytest.fixture
15 | def model():
16 |     return OpenAIModel(
17 |         client_args={
18 |             "base_url": "https://api.cohere.com/compatibility/v1",
19 |             "api_key": os.getenv("CO_API_KEY"),
20 |         },
21 |         model_id="command-a-03-2025",
22 |         params={"stream_options": None},
23 |     )
24 | 
25 | 
26 | @pytest.fixture
27 | def tools():
28 |     @strands.tool
29 |     def tool_time() -> str:
30 |         return "12:00"
31 | 
32 |     @strands.tool
33 |     def tool_weather() -> str:
34 |         return "sunny"
35 | 
36 |     return [tool_time, tool_weather]
37 | 
38 | 
39 | @pytest.fixture
40 | def agent(model, tools):
41 |     return Agent(model=model, tools=tools)
42 | 
43 | 
44 | def test_agent(agent):
45 |     result = agent("What is the time and weather in New York?")
46 |     text = result.message["content"][0]["text"].lower()
47 |     assert all(string in text for string in ["12:00", "sunny"])
48 | 


--------------------------------------------------------------------------------
/tests_integ/models/test_model_litellm.py:
--------------------------------------------------------------------------------
  1 | import pydantic
  2 | import pytest
  3 | 
  4 | import strands
  5 | from strands import Agent
  6 | from strands.models.litellm import LiteLLMModel
  7 | 
  8 | 
  9 | @pytest.fixture
 10 | def model():
 11 |     return LiteLLMModel(model_id="bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0")
 12 | 
 13 | 
 14 | @pytest.fixture
 15 | def tools():
 16 |     @strands.tool
 17 |     def tool_time() -> str:
 18 |         return "12:00"
 19 | 
 20 |     @strands.tool
 21 |     def tool_weather() -> str:
 22 |         return "sunny"
 23 | 
 24 |     return [tool_time, tool_weather]
 25 | 
 26 | 
 27 | @pytest.fixture
 28 | def agent(model, tools):
 29 |     return Agent(model=model, tools=tools)
 30 | 
 31 | 
 32 | @pytest.fixture
 33 | def weather():
 34 |     class Weather(pydantic.BaseModel):
 35 |         """Extracts the time and weather from the user's message with the exact strings."""
 36 | 
 37 |         time: str
 38 |         weather: str
 39 | 
 40 |     return Weather(time="12:00", weather="sunny")
 41 | 
 42 | 
 43 | @pytest.fixture
 44 | def yellow_color():
 45 |     class Color(pydantic.BaseModel):
 46 |         """Describes a color."""
 47 | 
 48 |         name: str
 49 | 
 50 |         @pydantic.field_validator("name", mode="after")
 51 |         @classmethod
 52 |         def lower(_, value):
 53 |             return value.lower()
 54 | 
 55 |     return Color(name="yellow")
 56 | 
 57 | 
 58 | def test_agent_invoke(agent):
 59 |     result = agent("What is the time and weather in New York?")
 60 |     text = result.message["content"][0]["text"].lower()
 61 | 
 62 |     assert all(string in text for string in ["12:00", "sunny"])
 63 | 
 64 | 
 65 | @pytest.mark.asyncio
 66 | async def test_agent_invoke_async(agent):
 67 |     result = await agent.invoke_async("What is the time and weather in New York?")
 68 |     text = result.message["content"][0]["text"].lower()
 69 | 
 70 |     assert all(string in text for string in ["12:00", "sunny"])
 71 | 
 72 | 
 73 | @pytest.mark.asyncio
 74 | async def test_agent_stream_async(agent):
 75 |     stream = agent.stream_async("What is the time and weather in New York?")
 76 |     async for event in stream:
 77 |         _ = event
 78 | 
 79 |     result = event["result"]
 80 |     text = result.message["content"][0]["text"].lower()
 81 | 
 82 |     assert all(string in text for string in ["12:00", "sunny"])
 83 | 
 84 | 
 85 | def test_structured_output(agent, weather):
 86 |     tru_weather = agent.structured_output(type(weather), "The time is 12:00 and the weather is sunny")
 87 |     exp_weather = weather
 88 |     assert tru_weather == exp_weather
 89 | 
 90 | 
 91 | @pytest.mark.asyncio
 92 | async def test_agent_structured_output_async(agent, weather):
 93 |     tru_weather = await agent.structured_output_async(type(weather), "The time is 12:00 and the weather is sunny")
 94 |     exp_weather = weather
 95 |     assert tru_weather == exp_weather
 96 | 
 97 | 
 98 | def test_invoke_multi_modal_input(agent, yellow_img):
 99 |     content = [
100 |         {"text": "Is this image red, blue, or yellow?"},
101 |         {
102 |             "image": {
103 |                 "format": "png",
104 |                 "source": {
105 |                     "bytes": yellow_img,
106 |                 },
107 |             },
108 |         },
109 |     ]
110 |     result = agent(content)
111 |     text = result.message["content"][0]["text"].lower()
112 | 
113 |     assert "yellow" in text
114 | 
115 | 
116 | def test_structured_output_multi_modal_input(agent, yellow_img, yellow_color):
117 |     content = [
118 |         {"text": "what is in this image"},  # NOTE(review): other provider tests ask the red/blue/yellow question here - confirm
119 |         {
120 |             "image": {
121 |                 "format": "png",
122 |                 "source": {
123 |                     "bytes": yellow_img,
124 |                 },
125 |             },
126 |         },
127 |     ]
128 |     tru_color = agent.structured_output(type(yellow_color), content)
129 |     exp_color = yellow_color
130 |     assert tru_color == exp_color
131 | 


--------------------------------------------------------------------------------
/tests_integ/models/test_model_llamaapi.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Meta Platforms, Inc. and affiliates
 2 | import os
 3 | 
 4 | import pytest
 5 | 
 6 | import strands
 7 | from strands import Agent
 8 | from strands.models.llamaapi import LlamaAPIModel
 9 | from tests_integ.models import providers
10 | 
11 | # these tests only run if we have the llama api key
12 | pytestmark = providers.llama.mark
13 | 
14 | 
15 | @pytest.fixture
16 | def model():
17 |     return LlamaAPIModel(
18 |         model_id="Llama-4-Maverick-17B-128E-Instruct-FP8",
19 |         client_args={
20 |             "api_key": os.getenv("LLAMA_API_KEY"),
21 |         },
22 |     )
23 | 
24 | 
25 | @pytest.fixture
26 | def tools():
27 |     @strands.tool
28 |     def tool_time() -> str:
29 |         return "12:00"
30 | 
31 |     @strands.tool
32 |     def tool_weather() -> str:
33 |         return "sunny"
34 | 
35 |     return [tool_time, tool_weather]
36 | 
37 | 
38 | @pytest.fixture
39 | def agent(model, tools):
40 |     return Agent(model=model, tools=tools)
41 | 
42 | 
43 | def test_agent(agent):
44 |     result = agent("What is the time and weather in New York?")
45 |     text = result.message["content"][0]["text"].lower()
46 | 
47 |     assert all(string in text for string in ["12:00", "sunny"])
48 | 


--------------------------------------------------------------------------------
/tests_integ/models/test_model_mistral.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | 
  3 | import pytest
  4 | from pydantic import BaseModel
  5 | 
  6 | import strands
  7 | from strands import Agent
  8 | from strands.models.mistral import MistralModel
  9 | from tests_integ.models import providers
 10 | 
 11 | # these tests only run if we have the mistral api key
 12 | pytestmark = providers.mistral.mark
 13 | 
 14 | 
 15 | @pytest.fixture()
 16 | def streaming_model():
 17 |     return MistralModel(
 18 |         model_id="mistral-medium-latest",
 19 |         api_key=os.getenv("MISTRAL_API_KEY"),
 20 |         stream=True,
 21 |         temperature=0.7,
 22 |         max_tokens=1000,
 23 |         top_p=0.9,
 24 |     )
 25 | 
 26 | 
 27 | @pytest.fixture()
 28 | def non_streaming_model():
 29 |     return MistralModel(
 30 |         model_id="mistral-medium-latest",
 31 |         api_key=os.getenv("MISTRAL_API_KEY"),
 32 |         stream=False,
 33 |         temperature=0.7,
 34 |         max_tokens=1000,
 35 |         top_p=0.9,
 36 |     )
 37 | 
 38 | 
 39 | @pytest.fixture()
 40 | def system_prompt():
 41 |     return "You are an AI assistant that provides helpful and accurate information."  # NOTE(review): no fixture or test in this module requests system_prompt - confirm it is still needed
 42 | 
 43 | 
 44 | @pytest.fixture()
 45 | def tools():
 46 |     @strands.tool
 47 |     def tool_time() -> str:
 48 |         return "12:00"
 49 | 
 50 |     @strands.tool
 51 |     def tool_weather() -> str:
 52 |         return "sunny"
 53 | 
 54 |     return [tool_time, tool_weather]
 55 | 
 56 | 
 57 | @pytest.fixture()
 58 | def streaming_agent(streaming_model, tools):
 59 |     return Agent(model=streaming_model, tools=tools)
 60 | 
 61 | 
 62 | @pytest.fixture()
 63 | def non_streaming_agent(non_streaming_model, tools):
 64 |     return Agent(model=non_streaming_model, tools=tools)
 65 | 
 66 | 
 67 | @pytest.fixture(params=["streaming_agent", "non_streaming_agent"])
 68 | def agent(request):
 69 |     return request.getfixturevalue(request.param)  # resolve the agent fixture named by the current parameter
 70 | 
 71 | 
 72 | @pytest.fixture()
 73 | def weather():
 74 |     class Weather(BaseModel):
 75 |         """Extracts the time and weather from the user's message with the exact strings."""
 76 | 
 77 |         time: str
 78 |         weather: str
 79 | 
 80 |     return Weather(time="12:00", weather="sunny")
 81 | 
 82 | 
 83 | def test_agent_invoke(agent):
 84 |     result = agent("What is the time and weather in New York?")
 85 |     text = result.message["content"][0]["text"].lower()
 86 | 
 87 |     assert all(string in text for string in ["12:00", "sunny"])
 88 | 
 89 | 
 90 | @pytest.mark.asyncio
 91 | async def test_agent_invoke_async(agent):
 92 |     result = await agent.invoke_async("What is the time and weather in New York?")
 93 |     text = result.message["content"][0]["text"].lower()
 94 | 
 95 |     assert all(string in text for string in ["12:00", "sunny"])
 96 | 
 97 | 
 98 | @pytest.mark.asyncio
 99 | async def test_agent_stream_async(agent):
100 |     stream = agent.stream_async("What is the time and weather in New York?")
101 |     async for event in stream:
102 |         _ = event
103 | 
104 |     result = event["result"]
105 |     text = result.message["content"][0]["text"].lower()
106 | 
107 |     assert all(string in text for string in ["12:00", "sunny"])
108 | 
109 | 
110 | def test_agent_structured_output(non_streaming_agent, weather):
111 |     tru_weather = non_streaming_agent.structured_output(type(weather), "The time is 12:00 and the weather is sunny")
112 |     exp_weather = weather
113 |     assert tru_weather == exp_weather
114 | 
115 | 
116 | @pytest.mark.asyncio
117 | async def test_agent_structured_output_async(non_streaming_agent, weather):
118 |     tru_weather = await non_streaming_agent.structured_output_async(
119 |         type(weather), "The time is 12:00 and the weather is sunny"
120 |     )
121 |     exp_weather = weather
122 |     assert tru_weather == exp_weather
123 | 


--------------------------------------------------------------------------------
/tests_integ/models/test_model_ollama.py:
--------------------------------------------------------------------------------
 1 | import pytest
 2 | from pydantic import BaseModel
 3 | 
 4 | import strands
 5 | from strands import Agent
 6 | from strands.models.ollama import OllamaModel
 7 | from tests_integ.models import providers
 8 | 
 9 | # these tests only run if the local ollama server is running
10 | pytestmark = providers.ollama.mark
11 | 
12 | 
13 | @pytest.fixture
14 | def model():
15 |     return OllamaModel(host="http://localhost:11434", model_id="llama3.3:70b")
16 | 
17 | 
18 | @pytest.fixture
19 | def tools():
20 |     @strands.tool
21 |     def tool_time() -> str:
22 |         return "12:00"
23 | 
24 |     @strands.tool
25 |     def tool_weather() -> str:
26 |         return "sunny"
27 | 
28 |     return [tool_time, tool_weather]
29 | 
30 | 
31 | @pytest.fixture
32 | def agent(model, tools):
33 |     return Agent(model=model, tools=tools)
34 | 
35 | 
36 | @pytest.fixture
37 | def weather():
38 |     class Weather(BaseModel):
39 |         """Extracts the time and weather from the user's message with the exact strings."""
40 | 
41 |         time: str
42 |         weather: str
43 | 
44 |     return Weather(time="12:00", weather="sunny")
45 | 
46 | 
47 | def test_agent_invoke(agent):
48 |     result = agent("What is the time and weather in New York?")
49 |     text = result.message["content"][0]["text"].lower()
50 | 
51 |     assert all(string in text for string in ["12:00", "sunny"])
52 | 
53 | 
54 | @pytest.mark.asyncio
55 | async def test_agent_invoke_async(agent):
56 |     result = await agent.invoke_async("What is the time and weather in New York?")
57 |     text = result.message["content"][0]["text"].lower()
58 | 
59 |     assert all(string in text for string in ["12:00", "sunny"])
60 | 
61 | 
62 | @pytest.mark.asyncio
63 | async def test_agent_stream_async(agent):
64 |     stream = agent.stream_async("What is the time and weather in New York?")
65 |     async for event in stream:
66 |         _ = event
67 | 
68 |     result = event["result"]
69 |     text = result.message["content"][0]["text"].lower()
70 | 
71 |     assert all(string in text for string in ["12:00", "sunny"])
72 | 
73 | 
74 | def test_agent_structured_output(agent, weather):
75 |     tru_weather = agent.structured_output(type(weather), "The time is 12:00 and the weather is sunny")
76 |     exp_weather = weather
77 |     assert tru_weather == exp_weather
78 | 
79 | 
80 | @pytest.mark.asyncio
81 | async def test_agent_structured_output_async(agent, weather):
82 |     tru_weather = await agent.structured_output_async(type(weather), "The time is 12:00 and the weather is sunny")
83 |     exp_weather = weather
84 |     assert tru_weather == exp_weather
85 | 


--------------------------------------------------------------------------------
/tests_integ/models/test_model_openai.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | 
  3 | import pydantic
  4 | import pytest
  5 | 
  6 | import strands
  7 | from strands import Agent, tool
  8 | from strands.models.openai import OpenAIModel
  9 | from tests_integ.models import providers
 10 | 
 11 | # these tests only run if we have the openai api key
 12 | pytestmark = providers.openai.mark
 13 | 
 14 | 
 15 | @pytest.fixture
 16 | def model():
 17 |     return OpenAIModel(
 18 |         model_id="gpt-4o",
 19 |         client_args={
 20 |             "api_key": os.getenv("OPENAI_API_KEY"),
 21 |         },
 22 |     )
 23 | 
 24 | 
 25 | @pytest.fixture
 26 | def tools():
 27 |     @strands.tool
 28 |     def tool_time() -> str:
 29 |         return "12:00"
 30 | 
 31 |     @strands.tool
 32 |     def tool_weather() -> str:
 33 |         return "sunny"
 34 | 
 35 |     return [tool_time, tool_weather]
 36 | 
 37 | 
 38 | @pytest.fixture
 39 | def agent(model, tools):
 40 |     return Agent(model=model, tools=tools)
 41 | 
 42 | 
 43 | @pytest.fixture
 44 | def weather():
 45 |     class Weather(pydantic.BaseModel):
 46 |         """Extracts the time and weather from the user's message with the exact strings."""
 47 | 
 48 |         time: str
 49 |         weather: str
 50 | 
 51 |     return Weather(time="12:00", weather="sunny")
 52 | 
 53 | 
 54 | @pytest.fixture
 55 | def yellow_color():
 56 |     class Color(pydantic.BaseModel):
 57 |         """Describes a color."""
 58 | 
 59 |         name: str
 60 | 
 61 |         @pydantic.field_validator("name", mode="after")
 62 |         @classmethod
 63 |         def lower(_, value):
 64 |             return value.lower()
 65 | 
 66 |     return Color(name="yellow")
 67 | 
 68 | 
 69 | @pytest.fixture(scope="module")
 70 | def test_image_path(request):
 71 |     return request.config.rootpath / "tests_integ" / "test_image.png"
 72 | 
 73 | 
 74 | def test_agent_invoke(agent):
 75 |     result = agent("What is the time and weather in New York?")
 76 |     text = result.message["content"][0]["text"].lower()
 77 | 
 78 |     assert all(string in text for string in ["12:00", "sunny"])
 79 | 
 80 | 
 81 | @pytest.mark.asyncio
 82 | async def test_agent_invoke_async(agent):
 83 |     result = await agent.invoke_async("What is the time and weather in New York?")
 84 |     text = result.message["content"][0]["text"].lower()
 85 | 
 86 |     assert all(string in text for string in ["12:00", "sunny"])
 87 | 
 88 | 
 89 | @pytest.mark.asyncio
 90 | async def test_agent_stream_async(agent):
 91 |     stream = agent.stream_async("What is the time and weather in New York?")
 92 |     async for event in stream:
 93 |         _ = event
 94 | 
 95 |     result = event["result"]
 96 |     text = result.message["content"][0]["text"].lower()
 97 | 
 98 |     assert all(string in text for string in ["12:00", "sunny"])
 99 | 
100 | 
def test_agent_structured_output(agent, weather):
    """Structured output for the fixture's model class should equal the fixture instance."""
    output_model = type(weather)
    prompt = "The time is 12:00 and the weather is sunny"

    assert agent.structured_output(output_model, prompt) == weather
105 | 
106 | 
@pytest.mark.asyncio
async def test_agent_structured_output_async(agent, weather):
    """Async structured output for the fixture's model class should equal the fixture instance."""
    output_model = type(weather)
    prompt = "The time is 12:00 and the weather is sunny"

    actual = await agent.structured_output_async(output_model, prompt)
    assert actual == weather
112 | 
113 | 
def test_invoke_multi_modal_input(agent, yellow_img):
    """A text + PNG prompt should produce a reply that mentions the image's color."""
    question_block = {"text": "what is in this image"}
    image_block = {"image": {"format": "png", "source": {"bytes": yellow_img}}}

    response = agent([question_block, image_block])
    answer = response.message["content"][0]["text"].lower()

    assert "yellow" in answer
130 | 
131 | 
def test_structured_output_multi_modal_input(agent, yellow_img, yellow_color):
    """Structured output over a text + PNG prompt should equal the expected color instance."""
    question_block = {"text": "Is this image red, blue, or yellow?"}
    image_block = {"image": {"format": "png", "source": {"bytes": yellow_img}}}
    output_model = type(yellow_color)

    actual = agent.structured_output(output_model, [question_block, image_block])
    assert actual == yellow_color
147 | 
148 | 
@pytest.mark.skip("https://github.com/strands-agents/sdk-python/issues/320")
def test_tool_returning_images(model, yellow_img):
    """Invoke an agent whose only tool returns an image content block (currently skipped)."""
    image_block = {
        "image": {
            "format": "png",
            "source": {"bytes": yellow_img},
        }
    }

    @tool
    def tool_with_image_return():
        return {"status": "success", "content": [image_block]}

    # NOTE - per the linked issue this currently fails with: "Invalid 'messages[3]'. Image URLs are only
    # allowed for messages with role 'user', but this message with role 'tool' contains an image URL."
    # See https://github.com/strands-agents/sdk-python/issues/320 for additional details
    image_agent = Agent(model, tools=[tool_with_image_return])
    image_agent("Run the the tool and analyze the image")
170 | 


--------------------------------------------------------------------------------
/tests_integ/models/test_model_writer.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | 
 3 | import pytest
 4 | from pydantic import BaseModel
 5 | 
 6 | import strands
 7 | from strands import Agent
 8 | from strands.models.writer import WriterModel
 9 | from tests_integ.models import providers
10 | 
11 | # these tests only run if we have the writer api key
12 | pytestmark = providers.writer.mark
13 | 
14 | 
@pytest.fixture
def model():
    """Build a WriterModel for `palmyra-x4`, reading the API key from WRITER_API_KEY."""
    api_key = os.environ.get("WRITER_API_KEY", "")
    writer_model = WriterModel(
        model_id="palmyra-x4",
        client_args={"api_key": api_key},
        stream_options={"include_usage": True},
    )
    return writer_model
22 | 
23 | 
@pytest.fixture
def system_prompt():
    """Return the system prompt used by the Writer agent fixture."""
    prompt = "You are a smart assistant, that uses @ instead of all punctuation marks"
    return prompt
27 | 
28 | 
@pytest.fixture
def tools():
    # Two argument-free tools with fixed answers that the tests look for in the reply.
    @strands.tool
    def tool_time() -> str:
        return "12:00"

    @strands.tool
    def tool_weather() -> str:
        return "sunny"

    available = [tool_time, tool_weather]
    return available
40 | 
41 | 
@pytest.fixture
def agent(model, tools, system_prompt):
    """Build an Agent from the fixtures, with directory tool loading switched off."""
    return Agent(
        model=model,
        tools=tools,
        system_prompt=system_prompt,
        load_tools_from_directory=False,
    )
45 | 
46 | 
def test_agent(agent):
    """A synchronous call should answer with both tool outputs in the first text block."""
    response = agent("What is the time and weather in New York?")
    first_block = response.message["content"][0]
    answer = first_block["text"].lower()

    for expected in ("12:00", "sunny"):
        assert expected in answer
52 | 
53 | 
@pytest.mark.asyncio
async def test_agent_async(agent):
    """`invoke_async` should answer with both tool outputs in the first text block."""
    response = await agent.invoke_async("What is the time and weather in New York?")
    first_block = response.message["content"][0]
    answer = first_block["text"].lower()

    for expected in ("12:00", "sunny"):
        assert expected in answer
60 | 
61 | 
@pytest.mark.asyncio
async def test_agent_stream_async(agent):
    """Drain `stream_async` and check that the final event carries the expected result."""
    last_event = None
    async for event in agent.stream_async("What is the time and weather in New York?"):
        last_event = event

    # Fail with a clear message (not a NameError) if the stream yielded nothing.
    assert last_event is not None, "stream_async produced no events"

    result = last_event["result"]
    text = result.message["content"][0]["text"].lower()

    assert all(string in text for string in ["12:00", "sunny"])
72 | 
73 | 
def test_structured_output(agent):
    # Output schema requested from the model; intentionally declared without a docstring.
    class Weather(BaseModel):
        time: str
        weather: str

    prompt = "The time is 12:00 and the weather is sunny"
    parsed = agent.structured_output(Weather, prompt)

    assert isinstance(parsed, Weather)
    assert (parsed.time, parsed.weather) == ("12:00", "sunny")
84 | 
85 | 
@pytest.mark.asyncio
async def test_structured_output_async(agent):
    # Output schema requested from the model; intentionally declared without a docstring.
    class Weather(BaseModel):
        time: str
        weather: str

    prompt = "The time is 12:00 and the weather is sunny"
    parsed = await agent.structured_output_async(Weather, prompt)

    assert isinstance(parsed, Weather)
    assert (parsed.time, parsed.weather) == ("12:00", "sunny")
97 | 


--------------------------------------------------------------------------------
/tests_integ/test_agent_async.py:
--------------------------------------------------------------------------------
 1 | import pytest
 2 | 
 3 | import strands
 4 | 
 5 | 
 6 | @pytest.fixture
 7 | def agent():
 8 |     return strands.Agent()
 9 | 
10 | 
@pytest.mark.asyncio
async def test_stream_async(agent):
    """The concatenated streamed text deltas should equal the agent's last stored message."""
    chunks = []
    async for event in agent.stream_async("hello"):
        # Only raw model events that carry a content delta contribute text.
        if "event" not in event or "contentBlockDelta" not in event["event"]:
            continue
        chunks.append(event["event"]["contentBlockDelta"]["delta"]["text"])

    exp_message = "".join(chunks)
    tru_message = agent.messages[-1]["content"][0]["text"]

    assert tru_message == exp_message
23 | 


--------------------------------------------------------------------------------
/tests_integ/test_bedrock_cache_point.py:
--------------------------------------------------------------------------------
 1 | from strands import Agent
 2 | from strands.types.content import Messages
 3 | 
 4 | 
 5 | def test_bedrock_cache_point():
 6 |     messages: Messages = [
 7 |         {
 8 |             "role": "user",
 9 |             "content": [
10 |                 {
11 |                     "text": "Some really long text!" * 1000  # Minimum token count for cachePoint is 1024 tokens
12 |                 },
13 |                 {"cachePoint": {"type": "default"}},
14 |             ],
15 |         },
16 |         {"role": "assistant", "content": [{"text": "Blue!"}]},
17 |     ]
18 | 
19 |     cache_point_usage = 0
20 | 
21 |     def cache_point_callback_handler(**kwargs):
22 |         nonlocal cache_point_usage
23 |         if "event" in kwargs and kwargs["event"] and "metadata" in kwargs["event"] and kwargs["event"]["metadata"]:
24 |             metadata = kwargs["event"]["metadata"]
25 |             if "usage" in metadata and metadata["usage"]:
26 |                 if "cacheReadInputTokens" in metadata["usage"] or "cacheWriteInputTokens" in metadata["usage"]:
27 |                     cache_point_usage += 1
28 | 
29 |     agent = Agent(messages=messages, callback_handler=cache_point_callback_handler, load_tools_from_directory=False)
30 |     agent("What is favorite color?")
31 |     assert cache_point_usage > 0
32 | 


--------------------------------------------------------------------------------
/tests_integ/test_context_overflow.py:
--------------------------------------------------------------------------------
 1 | from strands import Agent
 2 | from strands.types.content import Messages
 3 | 
 4 | 
 5 | def test_context_window_overflow():
 6 |     messages: Messages = [
 7 |         {"role": "user", "content": [{"text": "Too much text!" * 100000}]},
 8 |         {"role": "assistant", "content": [{"text": "That was a lot of text!"}]},
 9 |     ]
10 | 
11 |     agent = Agent(messages=messages, load_tools_from_directory=False)
12 |     agent("Hi!")
13 |     assert len(agent.messages) == 2
14 | 


--------------------------------------------------------------------------------
/tests_integ/test_function_tools.py:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/env python3
 2 | """
 3 | Test script for function-based tools
 4 | """
 5 | 
 6 | import logging
 7 | from typing import Optional
 8 | 
 9 | from strands import Agent, tool
10 | 
11 | logging.getLogger("strands").setLevel(logging.DEBUG)
12 | logging.basicConfig(format="%(levelname)s | %(name)s | %(message)s", handlers=[logging.StreamHandler()])
13 | 
14 | 
@tool
def word_counter(text: str) -> str:
    """
    Count words in text.

    Args:
        text: Text to analyze
    """
    # str.split() with no argument splits on any run of whitespace.
    words = text.split()
    return f"Word count: {len(words)}"
25 | 
26 | 
@tool(name="count_chars", description="Count characters in text")
def count_chars(text: str, include_spaces: Optional[bool] = True) -> str:
    """
    Count characters in text.

    Args:
        text: Text to analyze
        include_spaces: Whether to include spaces in the count
    """
    # Any falsy value (False or None) drops the space characters before counting.
    counted = text if include_spaces else text.replace(" ", "")
    return f"Character count: {len(counted)}"
39 | 
40 | 
# Initialize agent with function tools.
# NOTE: everything below is module-level code, so it runs whenever this module is imported
# or executed; it is not wrapped in a test function.
agent = Agent(tools=[word_counter, count_chars])

print("\n===== Testing Direct Tool Access =====")
# Call the tools directly through the agent's `tool` accessor, bypassing the model
word_result = agent.tool.word_counter(text="Hello world, this is a test")
print(f"\nWord counter result: {word_result}")

char_result = agent.tool.count_chars(text="Hello world!", include_spaces=False)
print(f"\nCharacter counter result: {char_result}")

print("\n===== Testing Natural Language Access =====")
# Send a natural-language prompt to the agent instead of calling a tool directly
nl_result = agent("Count the words in this sentence: 'The quick brown fox jumps over the lazy dog'")
print(f"\nNL Result: {nl_result}")
56 | 


--------------------------------------------------------------------------------
/tests_integ/test_hot_tool_reload_decorator.py:
--------------------------------------------------------------------------------
  1 | """
  2 | Integration test for hot tool reloading functionality with the @tool decorator.
  3 | 
  4 | This test verifies that the Strands Agent can automatically detect and load
  5 | new tools created with the @tool decorator when they are added to a tools directory.
  6 | """
  7 | 
  8 | import logging
  9 | import os
 10 | import time
 11 | from pathlib import Path
 12 | 
 13 | from strands import Agent
 14 | 
 15 | logging.getLogger("strands").setLevel(logging.DEBUG)
 16 | logging.basicConfig(format="%(levelname)s | %(name)s | %(message)s", handlers=[logging.StreamHandler()])
 17 | 
 18 | 
 19 | def test_hot_reload_decorator():
 20 |     """
 21 |     Test that the Agent automatically loads tools created with @tool decorator
 22 |     when added to the current working directory's tools folder.
 23 |     """
 24 |     # Set up the tools directory in current working directory
 25 |     tools_dir = Path.cwd() / "tools"
 26 |     os.makedirs(tools_dir, exist_ok=True)
 27 | 
 28 |     # Tool path that will need cleanup
 29 |     test_tool_path = tools_dir / "uppercase.py"
 30 | 
 31 |     try:
 32 |         # Create an Agent instance without any tools
 33 |         agent = Agent(load_tools_from_directory=True)
 34 | 
 35 |         # Create a test tool using @tool decorator
 36 |         with open(test_tool_path, "w") as f:
 37 |             f.write("""
 38 | from strands import tool
 39 | 
 40 | @tool
 41 | def uppercase(text: str) -> str:
 42 |     \"\"\"Convert text to uppercase.\"\"\"
 43 |     return f"Input: {text}, Output: {text.upper()}"
 44 | """)
 45 | 
 46 |         # Wait for tool detection
 47 |         time.sleep(3)
 48 | 
 49 |         # Verify the tool was automatically loaded
 50 |         assert "uppercase" in agent.tool_names, "Agent should have detected and loaded the uppercase tool"
 51 | 
 52 |         # Test calling the dynamically loaded tool
 53 |         result = agent.tool.uppercase(text="hello world")
 54 | 
 55 |         # Check that the result is successful
 56 |         assert result.get("status") == "success", "Tool call should be successful"
 57 | 
 58 |         # Check the content of the response
 59 |         content_list = result.get("content", [])
 60 |         assert len(content_list) > 0, "Tool response should have content"
 61 | 
 62 |         # Check that the expected message is in the content
 63 |         text_content = next((item.get("text") for item in content_list if "text" in item), "")
 64 |         assert "Input: hello world, Output: HELLO WORLD" in text_content
 65 | 
 66 |     finally:
 67 |         # Clean up - remove the test file
 68 |         if test_tool_path.exists():
 69 |             os.remove(test_tool_path)
 70 | 
 71 | 
 72 | def test_hot_reload_decorator_update():
 73 |     """
 74 |     Test that the Agent detects updates to tools created with @tool decorator.
 75 |     """
 76 |     # Set up the tools directory in current working directory
 77 |     tools_dir = Path.cwd() / "tools"
 78 |     os.makedirs(tools_dir, exist_ok=True)
 79 | 
 80 |     # Tool path that will need cleanup - make sure filename matches function name
 81 |     test_tool_path = tools_dir / "greeting.py"
 82 | 
 83 |     try:
 84 |         # Create an Agent instance
 85 |         agent = Agent(load_tools_from_directory=True)
 86 | 
 87 |         # Create the initial version of the tool
 88 |         with open(test_tool_path, "w") as f:
 89 |             f.write("""
 90 | from strands import tool
 91 | 
 92 | @tool
 93 | def greeting(name: str) -> str:
 94 |     \"\"\"Generate a simple greeting.\"\"\"
 95 |     return f"Hello, {name}!"
 96 | """)
 97 | 
 98 |         # Wait for tool detection
 99 |         time.sleep(3)
100 | 
101 |         # Verify the tool was loaded
102 |         assert "greeting" in agent.tool_names, "Agent should have detected and loaded the greeting tool"
103 | 
104 |         # Test calling the tool
105 |         result1 = agent.tool.greeting(name="Strands")
106 |         text_content1 = next((item.get("text") for item in result1.get("content", []) if "text" in item), "")
107 |         assert "Hello, Strands!" in text_content1, "Tool should return simple greeting"
108 | 
109 |         # Update the tool with new functionality
110 |         with open(test_tool_path, "w") as f:
111 |             f.write("""
112 | from strands import tool
113 | import datetime
114 | 
115 | @tool
116 | def greeting(name: str, formal: bool = False) -> str:
117 |     \"\"\"Generate a greeting with optional formality.\"\"\"
118 |     current_hour = datetime.datetime.now().hour
119 |     time_of_day = "morning" if current_hour < 12 else "afternoon" if current_hour < 18 else "evening"
120 | 
121 |     if formal:
122 |         return f"Good {time_of_day}, {name}. It's a pleasure to meet you."
123 |     else:
124 |         return f"Hey {name}! How's your {time_of_day} going?"
125 | """)
126 | 
127 |         # Wait for hot reload to detect the change
128 |         time.sleep(3)
129 | 
130 |         # Test calling the updated tool
131 |         result2 = agent.tool.greeting(name="Strands", formal=True)
132 |         text_content2 = next((item.get("text") for item in result2.get("content", []) if "text" in item), "")
133 |         assert "Good" in text_content2 and "Strands" in text_content2 and "pleasure to meet you" in text_content2
134 | 
135 |         # Test with informal parameter
136 |         result3 = agent.tool.greeting(name="Strands", formal=False)
137 |         text_content3 = next((item.get("text") for item in result3.get("content", []) if "text" in item), "")
138 |         assert "Hey Strands!" in text_content3 and "going" in text_content3
139 | 
140 |     finally:
141 |         # Clean up - remove the test file
142 |         if test_tool_path.exists():
143 |             os.remove(test_tool_path)
144 | 


--------------------------------------------------------------------------------
/tests_integ/test_mcp_client.py:
--------------------------------------------------------------------------------
  1 | import base64
  2 | import os
  3 | import threading
  4 | import time
  5 | from typing import List, Literal
  6 | 
  7 | import pytest
  8 | from mcp import StdioServerParameters, stdio_client
  9 | from mcp.client.sse import sse_client
 10 | from mcp.client.streamable_http import streamablehttp_client
 11 | from mcp.types import ImageContent as MCPImageContent
 12 | 
 13 | from strands import Agent
 14 | from strands.tools.mcp.mcp_client import MCPClient
 15 | from strands.tools.mcp.mcp_types import MCPTransport
 16 | from strands.types.content import Message
 17 | from strands.types.tools import ToolUse
 18 | 
 19 | 
 20 | def start_calculator_server(transport: Literal["sse", "streamable-http"], port=int):
 21 |     """
 22 |     Initialize and start an MCP calculator server for integration testing.
 23 | 
 24 |     This function creates a FastMCP server instance that provides a simple
 25 |     calculator tool for performing addition operations. The server uses
 26 |     Server-Sent Events (SSE) transport for communication, making it accessible
 27 |     over HTTP.
 28 |     """
 29 |     from mcp.server import FastMCP
 30 | 
 31 |     mcp = FastMCP("Calculator Server", port=port)
 32 | 
 33 |     @mcp.tool(description="Calculator tool which performs calculations")
 34 |     def calculator(x: int, y: int) -> int:
 35 |         return x + y
 36 | 
 37 |     @mcp.tool(description="Generates a custom image")
 38 |     def generate_custom_image() -> MCPImageContent:
 39 |         try:
 40 |             with open("tests_integ/yellow.png", "rb") as image_file:
 41 |                 encoded_image = base64.b64encode(image_file.read())
 42 |                 return MCPImageContent(type="image", data=encoded_image, mimeType="image/png")
 43 |         except Exception as e:
 44 |             print("Error while generating custom image: {}".format(e))
 45 | 
 46 |     mcp.run(transport=transport)
 47 | 
 48 | 
 49 | def test_mcp_client():
 50 |     """
 51 |     Test should yield output similar to the following
 52 |     {'role': 'user', 'content': [{'text': 'add 1 and 2, then echo the result back to me'}]}
 53 |     {'role': 'assistant', 'content': [{'text': "I'll help you add 1 and 2 and then echo the result back to you.\n\nFirst, I'll calculate 1 + 2:"}, {'toolUse': {'toolUseId': 'tooluse_17ptaKUxQB20ySZxwgiI_w', 'name': 'calculator', 'input': {'x': 1, 'y': 2}}}]}
 54 |     {'role': 'user', 'content': [{'toolResult': {'status': 'success', 'toolUseId': 'tooluse_17ptaKUxQB20ySZxwgiI_w', 'content': [{'text': '3'}]}}]}
 55 |     {'role': 'assistant', 'content': [{'text': "\n\nNow I'll echo the result back to you:"}, {'toolUse': {'toolUseId': 'tooluse_GlOc5SN8TE6ti8jVZJMBOg', 'name': 'echo', 'input': {'to_echo': '3'}}}]}
 56 |     {'role': 'user', 'content': [{'toolResult': {'status': 'success', 'toolUseId': 'tooluse_GlOc5SN8TE6ti8jVZJMBOg', 'content': [{'text': '3'}]}}]}
 57 |     {'role': 'assistant', 'content': [{'text': '\n\nThe result of adding 1 and 2 is 3.'}]}
 58 |     """  # noqa: E501
 59 | 
 60 |     server_thread = threading.Thread(
 61 |         target=start_calculator_server, kwargs={"transport": "sse", "port": 8000}, daemon=True
 62 |     )
 63 |     server_thread.start()
 64 |     time.sleep(2)  # wait for server to startup completely
 65 | 
 66 |     sse_mcp_client = MCPClient(lambda: sse_client("http://127.0.0.1:8000/sse"))
 67 |     stdio_mcp_client = MCPClient(
 68 |         lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/echo_server.py"]))
 69 |     )
 70 |     with sse_mcp_client, stdio_mcp_client:
 71 |         agent = Agent(tools=sse_mcp_client.list_tools_sync() + stdio_mcp_client.list_tools_sync())
 72 |         agent("add 1 and 2, then echo the result back to me")
 73 | 
 74 |         tool_use_content_blocks = _messages_to_content_blocks(agent.messages)
 75 |         assert any([block["name"] == "echo" for block in tool_use_content_blocks])
 76 |         assert any([block["name"] == "calculator" for block in tool_use_content_blocks])
 77 | 
 78 |         image_prompt = """
 79 |         Generate a custom image, then tell me if the image is red, blue, yellow, pink, orange, or green. 
 80 |         RESPOND ONLY WITH THE COLOR
 81 |         """
 82 |         assert any(
 83 |             [
 84 |                 "yellow".casefold() in block["text"].casefold()
 85 |                 for block in agent(image_prompt).message["content"]
 86 |                 if "text" in block
 87 |             ]
 88 |         )
 89 | 
 90 | 
 91 | def test_can_reuse_mcp_client():
 92 |     stdio_mcp_client = MCPClient(
 93 |         lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/echo_server.py"]))
 94 |     )
 95 |     with stdio_mcp_client:
 96 |         stdio_mcp_client.list_tools_sync()
 97 |         pass
 98 |     with stdio_mcp_client:
 99 |         agent = Agent(tools=stdio_mcp_client.list_tools_sync())
100 |         agent("echo the following to me <to_echo>DOG</to_echo>")
101 | 
102 |         tool_use_content_blocks = _messages_to_content_blocks(agent.messages)
103 |         assert any([block["name"] == "echo" for block in tool_use_content_blocks])
104 | 
105 | 
@pytest.mark.skipif(
    condition=os.environ.get("GITHUB_ACTIONS") == "true",
    reason="streamable transport is failing in GitHub actions, debugging if linux compatibility issue",
)
def test_streamable_http_mcp_client():
    """An agent should be able to use the calculator tool over streamable HTTP."""
    calculator_thread = threading.Thread(
        target=start_calculator_server,
        kwargs={"transport": "streamable-http", "port": 8001},
        daemon=True,
    )
    calculator_thread.start()
    time.sleep(2)  # wait for server to startup completely

    def transport_callback() -> MCPTransport:
        return streamablehttp_client(url="http://127.0.0.1:8001/mcp")

    streamable_http_client = MCPClient(transport_callback)
    with streamable_http_client:
        agent = Agent(tools=streamable_http_client.list_tools_sync())
        agent("add 1 and 2 using a calculator")

        used_tool_names = [block["name"] for block in _messages_to_content_blocks(agent.messages)]
        assert "calculator" in used_tool_names
127 | 
128 | 
def _messages_to_content_blocks(messages: List[Message]) -> List[ToolUse]:
    """Collect the `toolUse` payload of every content block that has one, in message order."""
    tool_uses = []
    for message in messages:
        for block in message["content"]:
            if "toolUse" in block:
                tool_uses.append(block["toolUse"])
    return tool_uses
131 | 


--------------------------------------------------------------------------------
/tests_integ/test_multiagent_swarm.py:
--------------------------------------------------------------------------------
  1 | import pytest
  2 | 
  3 | from strands import Agent, tool
  4 | from strands.multiagent.swarm import Swarm
  5 | from strands.types.content import ContentBlock
  6 | 
  7 | 
  8 | @tool
  9 | def web_search(query: str) -> str:
 10 |     """Search the web for information."""
 11 |     # Mock implementation
 12 |     return f"Results for '{query}': 25% yearly growth assumption, reaching $1.81 trillion by 2030"
 13 | 
 14 | 
@tool
def calculate(expression: str) -> str:
    """Calculate the result of a mathematical expression."""
    # SECURITY NOTE(review): `expression` is passed to eval() unchanged. This tool is handed to
    # an agent, so the string presumably originates from model output - confirm that is
    # acceptable for this test environment before reusing the pattern anywhere else.
    try:
        return f"The result of {expression} is {eval(expression)}"
    except Exception as e:
        # Any evaluation failure is returned as text instead of being raised.
        return f"Error calculating {expression}: {str(e)}"
 22 | 
 23 | 
 24 | @pytest.fixture
 25 | def researcher_agent():
 26 |     """Create an agent specialized in research."""
 27 |     return Agent(
 28 |         name="researcher",
 29 |         system_prompt=(
 30 |             "You are a research specialist who excels at finding information. When you need to perform calculations or"
 31 |             " format documents, hand off to the appropriate specialist."
 32 |         ),
 33 |         tools=[web_search],
 34 |     )
 35 | 
 36 | 
 37 | @pytest.fixture
 38 | def analyst_agent():
 39 |     """Create an agent specialized in data analysis."""
 40 |     return Agent(
 41 |         name="analyst",
 42 |         system_prompt=(
 43 |             "You are a data analyst who excels at calculations and numerical analysis. When you need"
 44 |             " research or document formatting, hand off to the appropriate specialist."
 45 |         ),
 46 |         tools=[calculate],
 47 |     )
 48 | 
 49 | 
 50 | @pytest.fixture
 51 | def writer_agent():
 52 |     """Create an agent specialized in writing and formatting."""
 53 |     return Agent(
 54 |         name="writer",
 55 |         system_prompt=(
 56 |             "You are a professional writer who excels at formatting and presenting information. When you need research"
 57 |             " or calculations, hand off to the appropriate specialist."
 58 |         ),
 59 |     )
 60 | 
 61 | 
 62 | def test_swarm_execution_with_string(researcher_agent, analyst_agent, writer_agent):
 63 |     """Test swarm execution with string input."""
 64 |     # Create the swarm
 65 |     swarm = Swarm([researcher_agent, analyst_agent, writer_agent])
 66 | 
 67 |     # Define a task that requires collaboration
 68 |     task = (
 69 |         "Research the current AI agent market trends, calculate the growth rate assuming 25% yearly growth, "
 70 |         "and create a basic report"
 71 |     )
 72 | 
 73 |     # Execute the swarm
 74 |     result = swarm(task)
 75 | 
 76 |     # Verify results
 77 |     assert result.status.value == "completed"
 78 |     assert len(result.results) > 0
 79 |     assert result.execution_time > 0
 80 |     assert result.execution_count > 0
 81 | 
 82 |     # Verify agent history - at least one agent should have been used
 83 |     assert len(result.node_history) > 0
 84 | 
 85 | 
 86 | @pytest.mark.asyncio
 87 | async def test_swarm_execution_with_image(researcher_agent, analyst_agent, writer_agent, yellow_img):
 88 |     """Test swarm execution with image input."""
 89 |     # Create the swarm
 90 |     swarm = Swarm([researcher_agent, analyst_agent, writer_agent])
 91 | 
 92 |     # Create content blocks with text and image
 93 |     content_blocks: list[ContentBlock] = [
 94 |         {"text": "Analyze this image and create a report about what you see:"},
 95 |         {"image": {"format": "png", "source": {"bytes": yellow_img}}},
 96 |     ]
 97 | 
 98 |     # Execute the swarm with multi-modal input
 99 |     result = await swarm.invoke_async(content_blocks)
100 | 
101 |     # Verify results
102 |     assert result.status.value == "completed"
103 |     assert len(result.results) > 0
104 |     assert result.execution_time > 0
105 |     assert result.execution_count > 0
106 | 
107 |     # Verify agent history - at least one agent should have been used
108 |     assert len(result.node_history) > 0
109 | 


--------------------------------------------------------------------------------
/tests_integ/test_stream_agent.py:
--------------------------------------------------------------------------------
 1 | """
 2 | Test script for Strands' custom callback handler functionality.
 3 | Demonstrates different patterns of callback handling and processing.
 4 | """
 5 | 
 6 | import logging
 7 | 
 8 | from strands import Agent
 9 | 
10 | logging.getLogger("strands").setLevel(logging.DEBUG)
11 | logging.basicConfig(format="%(levelname)s | %(name)s | %(message)s", handlers=[logging.StreamHandler()])
12 | 
13 | 
class ToolCountingCallbackHandler:
    """Callback handler that prints callback events and counts tool-use and message events."""

    def __init__(self):
        # Number of callbacks seen with a truthy `current_tool_use` / `message` value.
        self.tool_count = 0
        self.message_count = 0

    def callback_handler(self, **kwargs) -> None:
        """
        Custom callback handler that processes and displays different types of events.

        Args:
            **kwargs: Callback event data including:
                - data: Regular output
                - complete: Completion status
                - message: Message processing
                - current_tool_use: Tool execution
        """
        # Extract event data
        data = kwargs.get("data", "")
        complete = kwargs.get("complete", False)
        message = kwargs.get("message", {})
        current_tool_use = kwargs.get("current_tool_use", {})

        # Handle regular data output
        if data:
            print(f"🔄 Data: {data}")

        # Handle tool execution events
        if current_tool_use:
            self.tool_count += 1
            tool_name = current_tool_use.get("name", "")
            tool_input = current_tool_use.get("input", {})
            print(f"🛠️ Tool Execution #{self.tool_count}\nTool: {tool_name}\nInput: {tool_input}")

        # Handle message processing
        if message:
            self.message_count += 1
            print(f"📝 Message #{self.message_count}")

        # Handle completion. This class never defines a `console` attribute, so use plain
        # print like the other branches instead of raising AttributeError.
        if complete:
            print("✨ Callback Complete")
55 | 
56 | 
def test_basic_interaction():
    """Run one prompt through an Agent that reports events to the counting callback handler."""
    print("\nTesting Basic Interaction")

    # Route the agent's callbacks through a fresh counting handler
    counting_handler = ToolCountingCallbackHandler()
    agent = Agent(
        callback_handler=counting_handler.callback_handler,
        load_tools_from_directory=False,
    )

    # A simple prompt is enough to exercise the callback path
    agent("Tell me a short joke from your general knowledge")

    print("\nBasic Interaction Complete")
71 | 


--------------------------------------------------------------------------------
/tests_integ/yellow.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/strands-agents/sdk-python/f5e24d402bb22dc1a54c94dd030b4be0f7e73261/tests_integ/yellow.png


--------------------------------------------------------------------------------