├── .github
├── CODEOWNERS
├── ISSUE_TEMPLATE
│ ├── bug_report.yml
│ ├── config.yml
│ └── feature_request.yml
├── PULL_REQUEST_TEMPLATE.md
├── dependabot.yml
└── workflows
│ ├── pr-and-push.yml
│ ├── pypi-publish-on-release.yml
│ └── test-lint.yml
├── .gitignore
├── .pre-commit-config.yaml
├── CONTRIBUTING.md
├── LICENSE
├── NOTICE
├── README.md
├── STYLE_GUIDE.md
├── pyproject.toml
├── src
└── strands
│ ├── __init__.py
│ ├── agent
│ ├── __init__.py
│ ├── agent.py
│ ├── agent_result.py
│ └── conversation_manager
│ │ ├── __init__.py
│ │ ├── conversation_manager.py
│ │ ├── null_conversation_manager.py
│ │ └── sliding_window_conversation_manager.py
│ ├── event_loop
│ ├── __init__.py
│ ├── error_handler.py
│ ├── event_loop.py
│ ├── message_processor.py
│ └── streaming.py
│ ├── handlers
│ ├── __init__.py
│ ├── callback_handler.py
│ └── tool_handler.py
│ ├── models
│ ├── __init__.py
│ ├── anthropic.py
│ ├── bedrock.py
│ ├── litellm.py
│ ├── llamaapi.py
│ ├── ollama.py
│ └── openai.py
│ ├── py.typed
│ ├── telemetry
│ ├── __init__.py
│ ├── metrics.py
│ └── tracer.py
│ ├── tools
│ ├── __init__.py
│ ├── decorator.py
│ ├── executor.py
│ ├── loader.py
│ ├── mcp
│ │ ├── __init__.py
│ │ ├── mcp_agent_tool.py
│ │ ├── mcp_client.py
│ │ └── mcp_types.py
│ ├── registry.py
│ ├── thread_pool_executor.py
│ ├── tools.py
│ └── watcher.py
│ └── types
│ ├── __init__.py
│ ├── content.py
│ ├── event_loop.py
│ ├── exceptions.py
│ ├── guardrails.py
│ ├── media.py
│ ├── models
│ ├── __init__.py
│ ├── model.py
│ └── openai.py
│ ├── streaming.py
│ ├── tools.py
│ └── traces.py
├── tests-integ
├── __init__.py
├── echo_server.py
├── test_bedrock_cache_point.py
├── test_bedrock_guardrails.py
├── test_context_overflow.py
├── test_function_tools.py
├── test_hot_tool_reload_decorator.py
├── test_image.png
├── test_mcp_client.py
├── test_model_anthropic.py
├── test_model_bedrock.py
├── test_model_litellm.py
├── test_model_llamaapi.py
├── test_model_openai.py
└── test_stream_agent.py
└── tests
├── __init__.py
├── conftest.py
└── strands
├── __init__.py
├── agent
├── __init__.py
├── test_agent.py
├── test_agent_result.py
└── test_conversation_manager.py
├── event_loop
├── __init__.py
├── test_error_handler.py
├── test_event_loop.py
├── test_message_processor.py
└── test_streaming.py
├── handlers
├── __init__.py
├── test_callback_handler.py
└── test_tool_handler.py
├── models
├── __init__.py
├── test_anthropic.py
├── test_bedrock.py
├── test_litellm.py
├── test_llamaapi.py
├── test_ollama.py
└── test_openai.py
├── telemetry
├── test_metrics.py
└── test_tracer.py
├── tools
├── __init__.py
├── mcp
│ ├── __init__.py
│ ├── test_mcp_agent_tool.py
│ └── test_mcp_client.py
├── test_decorator.py
├── test_executor.py
├── test_loader.py
├── test_registry.py
├── test_thread_pool_executor.py
├── test_tools.py
└── test_watcher.py
└── types
└── models
├── __init__.py
├── test_model.py
└── test_openai.py
/.github/CODEOWNERS:
--------------------------------------------------------------------------------
1 | # These owners will be the default owners for everything in
2 | # the repo. Unless a later match takes precedence,
3 | # @strands-agents/maintainers will be requested for
4 | # review when someone opens a pull request.
5 | * @strands-agents/maintainers
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.yml:
--------------------------------------------------------------------------------
1 | name: Bug Report
2 | description: Report a bug in the Strands Agents SDK
3 | title: "[BUG] "
4 | labels: ["bug", "triage"]
5 | assignees: []
6 | body:
7 | - type: markdown
8 | attributes:
9 | value: |
10 | Thanks for taking the time to fill out this bug report for Strands SDK!
11 | - type: checkboxes
12 | id: "checks"
13 | attributes:
14 | label: "Checks"
15 | options:
16 |         - label: "I have updated to the latest minor and patch version of Strands"
17 | required: true
18 | - label: "I have checked the documentation and this is not expected behavior"
19 | required: true
20 | - label: "I have searched [./issues](./issues?q=) and there are no duplicates of my issue"
21 | required: true
22 | - type: input
23 | id: strands-version
24 | attributes:
25 | label: Strands Version
26 | description: Which version of Strands are you using?
27 | placeholder: e.g., 0.5.2
28 | validations:
29 | required: true
30 | - type: input
31 | id: python-version
32 | attributes:
33 | label: Python Version
34 | description: Which version of Python are you using?
35 | placeholder: e.g., 3.10.5
36 | validations:
37 | required: true
38 | - type: input
39 | id: os
40 | attributes:
41 | label: Operating System
42 | description: Which operating system are you using?
43 | placeholder: e.g., macOS 12.6
44 | validations:
45 | required: true
46 | - type: dropdown
47 | id: installation-method
48 | attributes:
49 | label: Installation Method
50 | description: How did you install Strands?
51 | options:
52 | - pip
53 | - git clone
54 | - binary
55 | - other
56 | validations:
57 | required: true
58 | - type: textarea
59 | id: steps-to-reproduce
60 | attributes:
61 | label: Steps to Reproduce
62 | description: Detailed steps to reproduce the behavior
63 | placeholder: |
64 | 1. Install Strands using...
65 | 2. Run the command...
66 | 3. See error...
67 | validations:
68 | required: true
69 | - type: textarea
70 | id: expected-behavior
71 | attributes:
72 | label: Expected Behavior
73 | description: A clear description of what you expected to happen
74 | validations:
75 | required: true
76 | - type: textarea
77 | id: actual-behavior
78 | attributes:
79 | label: Actual Behavior
80 | description: What actually happened
81 | validations:
82 | required: true
83 | - type: textarea
84 | id: additional-context
85 | attributes:
86 | label: Additional Context
87 | description: Any other relevant information, logs, screenshots, etc.
88 | - type: textarea
89 | id: possible-solution
90 | attributes:
91 | label: Possible Solution
92 | description: Optional - If you have suggestions on how to fix the bug
93 | - type: input
94 | id: related-issues
95 | attributes:
96 | label: Related Issues
97 | description: Optional - Link to related issues if applicable
98 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/config.yml:
--------------------------------------------------------------------------------
1 | blank_issues_enabled: false
2 | contact_links:
3 | - name: Strands Agents SDK Support
4 | url: https://github.com/strands-agents/sdk-python/discussions
5 | about: Please ask and answer questions here
6 | - name: Strands Agents SDK Documentation
7 | url: https://github.com/strands-agents/docs
8 | about: Visit our documentation for help
9 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.yml:
--------------------------------------------------------------------------------
1 | name: Feature Request
2 | description: Suggest a new feature or enhancement for Strands Agents SDK
3 | title: "[FEATURE] "
4 | labels: ["enhancement", "triage"]
5 | assignees: []
6 | body:
7 | - type: markdown
8 | attributes:
9 | value: |
10 | Thanks for suggesting a new feature for Strands Agents SDK!
11 | - type: textarea
12 | id: problem-statement
13 | attributes:
14 | label: Problem Statement
15 | description: Describe the problem you're trying to solve. What is currently difficult or impossible to do?
16 | placeholder: I would like Strands to...
17 | validations:
18 | required: true
19 | - type: textarea
20 | id: proposed-solution
21 | attributes:
22 | label: Proposed Solution
23 | description: Optional - Describe your proposed solution in detail. How would this feature work?
24 | - type: textarea
25 | id: use-case
26 | attributes:
27 | label: Use Case
28 | description: Provide specific use cases for the feature. How would people use it?
29 | placeholder: This would help with...
30 | validations:
31 | required: true
32 | - type: textarea
33 | id: alternatives-solutions
34 | attributes:
35 |       label: Alternative Solutions
36 | description: Optional - Have you considered alternative approaches? What are their pros and cons?
37 | - type: textarea
38 | id: additional-context
39 | attributes:
40 | label: Additional Context
41 | description: Include any other context, screenshots, code examples, or references that might help understand the feature request.
42 |
--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
1 | ## Description
2 | [Provide a detailed description of the changes in this PR]
3 |
4 | ## Related Issues
5 | [Link to related issues using #issue-number format]
6 |
7 | ## Documentation PR
8 | [Link to the associated PR in the agent-docs repo]
9 |
10 | ## Type of Change
11 | - Bug fix
12 | - New feature
13 | - Breaking change
14 | - Documentation update
15 | - Other (please describe):
16 |
17 | [Choose one of the above types of changes]
18 |
19 |
20 | ## Testing
21 | [How have you tested the change?]
22 |
23 | * `hatch fmt --linter`
24 | * `hatch fmt --formatter`
25 | * `hatch test --all`
26 | * Verify that the changes do not break functionality or introduce warnings in consuming repositories: agents-docs, agents-tools, agents-cli
27 |
28 |
29 | ## Checklist
30 | - [ ] I have read the CONTRIBUTING document
31 | - [ ] I have added tests that prove my fix is effective or my feature works
32 | - [ ] I have updated the documentation accordingly
33 | - [ ] I have added an appropriate example to the documentation to outline the feature
34 | - [ ] My changes generate no new warnings
35 | - [ ] Any dependent changes have been merged and published
36 |
37 | By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.
38 |
--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 | updates:
3 | - package-ecosystem: "pip"
4 | directory: "/"
5 | schedule:
6 | interval: "daily"
7 | open-pull-requests-limit: 100
8 | commit-message:
9 | prefix: ci
10 | groups:
11 | dev-dependencies:
12 | patterns:
13 | - "pytest"
14 | - package-ecosystem: "github-actions"
15 | directory: "/"
16 | schedule:
17 | interval: "daily"
18 | open-pull-requests-limit: 100
19 | commit-message:
20 | prefix: ci
21 |
--------------------------------------------------------------------------------
/.github/workflows/pr-and-push.yml:
--------------------------------------------------------------------------------
1 | name: Pull Request and Push Action
2 |
3 | on:
4 | pull_request: # Safer than pull_request_target for untrusted code
5 | branches: [ main ]
6 | types: [opened, synchronize, reopened, ready_for_review, review_requested, review_request_removed]
7 | push:
8 | branches: [ main ] # Also run on direct pushes to main
9 | concurrency:
10 | group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
11 | cancel-in-progress: true
12 |
13 | jobs:
14 | call-test-lint:
15 | uses: ./.github/workflows/test-lint.yml
16 | permissions:
17 | contents: read
18 | with:
19 | ref: ${{ github.event.pull_request.head.sha }}
20 |
--------------------------------------------------------------------------------
/.github/workflows/pypi-publish-on-release.yml:
--------------------------------------------------------------------------------
1 | name: Publish Python Package
2 |
3 | on:
4 | release:
5 | types:
6 | - published
7 |
8 | jobs:
9 | call-test-lint:
10 | uses: ./.github/workflows/test-lint.yml
11 | permissions:
12 | contents: read
13 | with:
14 | ref: ${{ github.event.release.target_commitish }}
15 |
16 | build:
17 | name: Build distribution 📦
18 | permissions:
19 | contents: read
20 | needs:
21 | - call-test-lint
22 | runs-on: ubuntu-latest
23 |
24 | steps:
25 | - uses: actions/checkout@v4
26 | with:
27 | persist-credentials: false
28 |
29 | - name: Set up Python
30 | uses: actions/setup-python@v5
31 | with:
32 | python-version: '3.10'
33 |
34 | - name: Install dependencies
35 | run: |
36 | python -m pip install --upgrade pip
37 | pip install hatch twine
38 |
39 | - name: Validate version
40 | run: |
41 | version=$(hatch version)
42 | if [[ $version =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
43 | echo "Valid version format"
44 | exit 0
45 | else
46 | echo "Invalid version format"
47 | exit 1
48 | fi
49 |
50 | - name: Build
51 | run: |
52 | hatch build
53 |
54 | - name: Store the distribution packages
55 | uses: actions/upload-artifact@v4
56 | with:
57 | name: python-package-distributions
58 | path: dist/
59 |
60 | deploy:
61 | name: Upload release to PyPI
62 | needs:
63 | - build
64 | runs-on: ubuntu-latest
65 |
66 | # environment is used by PyPI Trusted Publisher and is strongly encouraged
67 | # https://docs.pypi.org/trusted-publishers/adding-a-publisher/
68 | environment:
69 | name: pypi
70 | url: https://pypi.org/p/strands-agents
71 | permissions:
72 | # IMPORTANT: this permission is mandatory for Trusted Publishing
73 | id-token: write
74 |
75 | steps:
76 | - name: Download all the dists
77 | uses: actions/download-artifact@v4
78 | with:
79 | name: python-package-distributions
80 | path: dist/
81 | - name: Publish distribution 📦 to PyPI
82 | uses: pypa/gh-action-pypi-publish@release/v1
83 |
--------------------------------------------------------------------------------
/.github/workflows/test-lint.yml:
--------------------------------------------------------------------------------
1 | name: Test and Lint
2 |
3 | on:
4 | workflow_call:
5 | inputs:
6 | ref:
7 | required: true
8 | type: string
9 |
10 | jobs:
11 | unit-test:
12 | name: Unit Tests - Python ${{ matrix.python-version }} - ${{ matrix.os-name }}
13 | permissions:
14 | contents: read
15 | strategy:
16 | matrix:
17 | include:
18 | # Linux
19 | - os: ubuntu-latest
20 | os-name: 'linux'
21 | python-version: "3.10"
22 | - os: ubuntu-latest
23 | os-name: 'linux'
24 | python-version: "3.11"
25 | - os: ubuntu-latest
26 | os-name: 'linux'
27 | python-version: "3.12"
28 | - os: ubuntu-latest
29 | os-name: 'linux'
30 | python-version: "3.13"
31 | # Windows
32 | - os: windows-latest
33 | os-name: 'windows'
34 | python-version: "3.10"
35 | - os: windows-latest
36 | os-name: 'windows'
37 | python-version: "3.11"
38 | - os: windows-latest
39 | os-name: 'windows'
40 | python-version: "3.12"
41 | - os: windows-latest
42 | os-name: 'windows'
43 | python-version: "3.13"
44 | # MacOS - latest only; not enough runners for macOS
45 | - os: macos-latest
46 | os-name: 'macOS'
47 | python-version: "3.13"
48 | fail-fast: true
49 | runs-on: ${{ matrix.os }}
50 | env:
51 | LOG_LEVEL: DEBUG
52 | steps:
53 | - name: Checkout code
54 | uses: actions/checkout@v4
55 | with:
56 | ref: ${{ inputs.ref }} # Explicitly define which commit to check out
57 | persist-credentials: false # Don't persist credentials for subsequent actions
58 | - name: Set up Python
59 | uses: actions/setup-python@v5
60 | with:
61 | python-version: ${{ matrix.python-version }}
62 | - name: Install dependencies
63 | run: |
64 | pip install --no-cache-dir hatch
65 | - name: Run Unit tests
66 | id: tests
67 | run: hatch test tests --cover
68 | continue-on-error: false
69 | lint:
70 | name: Lint
71 | runs-on: ubuntu-latest
72 | permissions:
73 | contents: read
74 | steps:
75 | - name: Checkout code
76 | uses: actions/checkout@v4
77 | with:
78 | ref: ${{ inputs.ref }}
79 | persist-credentials: false
80 |
81 | - name: Set up Python
82 | uses: actions/setup-python@v5
83 | with:
84 | python-version: '3.10'
85 | cache: 'pip'
86 |
87 | - name: Install dependencies
88 | run: |
89 | pip install --no-cache-dir hatch
90 |
91 | - name: Run lint
92 | id: lint
93 | run: hatch run test-lint
94 | continue-on-error: false
95 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | build
2 | __pycache__*
3 | .coverage*
4 | .env
5 | .venv
6 | .mypy_cache
7 | .pytest_cache
8 | .ruff_cache
9 | *.bak
10 | .vscode
11 | dist
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: local
3 | hooks:
4 | - id: hatch-format
5 | name: Format code
6 | entry: hatch fmt --formatter
7 | language: system
8 | pass_filenames: false
9 | types: [python]
10 | stages: [pre-commit]
11 | - id: hatch-lint
12 | name: Lint code
13 | entry: hatch run test-lint
14 | language: system
15 | pass_filenames: false
16 | types: [python]
17 | stages: [pre-commit]
18 | - id: hatch-test-lint
19 | name: Type linting
20 | entry: hatch run test-lint
21 | language: system
22 | pass_filenames: false
23 | types: [ python ]
24 | stages: [ pre-commit ]
25 | - id: hatch-test
26 | name: Unit tests
27 | entry: hatch test
28 | language: system
29 | pass_filenames: false
30 | types: [python]
31 | stages: [pre-commit]
32 | - id: commitizen-check
33 | name: Check commit message
34 | entry: hatch run cz check --commit-msg-file
35 | language: system
36 | stages: [commit-msg]
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing Guidelines
2 |
3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional
4 | documentation, we greatly value feedback and contributions from our community.
5 |
6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary
7 | information to effectively respond to your bug report or contribution.
8 |
9 |
10 | ## Reporting Bugs/Feature Requests
11 |
12 | We welcome you to use the [Bug Reports](../../issues/new?template=bug_report.yml) file to report bugs or [Feature Requests](../../issues/new?template=feature_request.yml) to suggest features.
13 |
14 | For a list of known bugs and feature requests:
15 | - Check [Bug Reports](../../issues?q=is%3Aissue%20state%3Aopen%20label%3Abug) for currently tracked issues
16 | - See [Feature Requests](../../issues?q=is%3Aissue%20state%3Aopen%20label%3Aenhancement) for requested enhancements
17 |
18 | When filing an issue, please check for already-tracked items to avoid duplicates.
19 |
20 | Please try to include as much information as you can. Details like these are incredibly useful:
21 |
22 | * A reproducible test case or series of steps
23 | * The version of our code being used (commit ID)
24 | * Any modifications you've made relevant to the bug
25 | * Anything unusual about your environment or deployment
26 |
27 |
28 | ## Development Environment
29 |
30 | This project uses [hatchling](https://hatch.pypa.io/latest/build/#hatchling) as the build backend and [hatch](https://hatch.pypa.io/latest/) for development workflow management.
31 |
32 | ### Setting Up Your Development Environment
33 |
34 | 1. Enter the virtual environment using `hatch` (recommended), then launch your IDE in the new shell.
35 | ```bash
36 | hatch shell dev
37 | ```
38 |
39 | Alternatively, install development dependencies in a manually created virtual environment:
40 | ```bash
41 | pip install -e ".[dev]" && pip install -e ".[litellm]"
42 | ```
43 |
44 |
45 | 2. Set up pre-commit hooks:
46 | ```bash
47 | pre-commit install -t pre-commit -t commit-msg
48 | ```
49 | This will automatically run formatters and conventional commit checks on your code before each commit.
50 |
51 | 3. Run code formatters manually:
52 | ```bash
53 | hatch fmt --formatter
54 | ```
55 |
56 | 4. Run linters:
57 | ```bash
58 | hatch fmt --linter
59 | ```
60 |
61 | 5. Run unit tests:
62 | ```bash
63 | hatch test
64 | ```
65 |
66 | 6. Run integration tests:
67 | ```bash
68 | hatch run test-integ
69 | ```
70 |
71 | ### Pre-commit Hooks
72 |
73 | We use [pre-commit](https://pre-commit.com/) to automatically run quality checks before each commit. The hooks run `hatch fmt --formatter`, `hatch run test-lint`, `hatch test`, and `cz check` when you make a commit, ensuring code consistency.
74 |
75 | The pre-commit hook is installed with:
76 |
77 | ```bash
78 | pre-commit install
79 | ```
80 |
81 | You can also run the hooks manually on all files:
82 |
83 | ```bash
84 | pre-commit run --all-files
85 | ```
86 |
87 | ### Code Formatting and Style Guidelines
88 |
89 | We use the following tools to ensure code quality:
90 | 1. **ruff** - For formatting and linting
91 | 2. **mypy** - For static type checking
92 |
93 | These tools are configured in the [pyproject.toml](./pyproject.toml) file. Please ensure your code passes all linting and type checks before submitting a pull request:
94 |
95 | ```bash
96 | # Run all checks
97 | hatch fmt --formatter
98 | hatch fmt --linter
99 | ```
100 |
101 | If you're using an IDE like VS Code or PyCharm, consider configuring it to use these tools automatically.
102 |
103 | For additional details on styling, please see our dedicated [Style Guide](./STYLE_GUIDE.md).
104 |
105 |
106 | ## Contributing via Pull Requests
107 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that:
108 |
109 | 1. You are working against the latest source on the *main* branch.
110 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already.
111 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted.
112 |
113 | To send us a pull request, please:
114 |
115 | 1. Create a branch.
116 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change.
117 | 3. Format your code using `hatch fmt --formatter`.
118 | 4. Run linting checks with `hatch fmt --linter`.
119 | 5. Ensure local tests pass with `hatch test` and `hatch run test-integ`.
120 | 6. Commit to your branch using clear commit messages following the [Conventional Commits](https://www.conventionalcommits.org) specification.
121 | 7. Send us a pull request, answering any default questions in the pull request interface.
122 | 8. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation.
123 |
124 |
125 | ## Finding contributions to work on
126 | Looking at the existing issues is a great way to find something to contribute to.
127 |
128 | You can check:
129 | - Our known bugs list in [Bug Reports](../../issues?q=is%3Aissue%20state%3Aopen%20label%3Abug) for issues that need fixing
130 | - Feature requests in [Feature Requests](../../issues?q=is%3Aissue%20state%3Aopen%20label%3Aenhancement) for new functionality to implement
131 |
132 |
133 | ## Code of Conduct
134 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
135 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
136 | opensource-codeofconduct@amazon.com with any additional questions or comments.
137 |
138 |
139 | ## Security issue notifications
140 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue.
141 |
142 |
143 | ## Licensing
144 |
145 | See the [LICENSE](./LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution.
146 |
--------------------------------------------------------------------------------
/NOTICE:
--------------------------------------------------------------------------------
1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
7 |
8 |
9 | Strands Agents
10 |
11 |
12 |
13 | A model-driven approach to building AI agents in just a few lines of code.
14 |
15 |
16 |
17 |

18 |

19 |

20 |

21 |

22 |

23 |
24 |
25 |
26 | Documentation
27 | ◆ Samples
28 | ◆ Python SDK
29 | ◆ Tools
30 | ◆ Agent Builder
31 | ◆ MCP Server
32 |
33 |
34 |
35 | Strands Agents is a simple yet powerful SDK that takes a model-driven approach to building and running AI agents. From simple conversational assistants to complex autonomous workflows, from local development to production deployment, Strands Agents scales with your needs.
36 |
37 | ## Feature Overview
38 |
39 | - **Lightweight & Flexible**: Simple agent loop that just works and is fully customizable
40 | - **Model Agnostic**: Support for Amazon Bedrock, Anthropic, LiteLLM, Llama, Ollama, OpenAI, and custom providers
41 | - **Advanced Capabilities**: Multi-agent systems, autonomous agents, and streaming support
42 | - **Built-in MCP**: Native support for Model Context Protocol (MCP) servers, enabling access to thousands of pre-built tools
43 |
44 | ## Quick Start
45 |
46 | ```bash
47 | # Install Strands Agents
48 | pip install strands-agents strands-agents-tools
49 | ```
50 |
51 | ```python
52 | from strands import Agent
53 | from strands_tools import calculator
54 | agent = Agent(tools=[calculator])
55 | agent("What is the square root of 1764")
56 | ```
57 |
58 | > **Note**: For the default Amazon Bedrock model provider, you'll need AWS credentials configured and model access enabled for Claude 3.7 Sonnet in the us-west-2 region. See the [Quickstart Guide](https://strandsagents.com/) for details on configuring other model providers.
59 |
60 | ## Installation
61 |
62 | Ensure you have Python 3.10+ installed, then:
63 |
64 | ```bash
65 | # Create and activate virtual environment
66 | python -m venv .venv
67 | source .venv/bin/activate # On Windows use: .venv\Scripts\activate
68 |
69 | # Install Strands and tools
70 | pip install strands-agents strands-agents-tools
71 | ```
72 |
73 | ## Features at a Glance
74 |
75 | ### Python-Based Tools
76 |
77 | Easily build tools using Python decorators:
78 |
79 | ```python
80 | from strands import Agent, tool
81 |
82 | @tool
83 | def word_count(text: str) -> int:
84 | """Count words in text.
85 |
86 | This docstring is used by the LLM to understand the tool's purpose.
87 | """
88 | return len(text.split())
89 |
90 | agent = Agent(tools=[word_count])
91 | response = agent("How many words are in this sentence?")
92 | ```
93 |
94 | ### MCP Support
95 |
96 | Seamlessly integrate Model Context Protocol (MCP) servers:
97 |
98 | ```python
99 | from strands import Agent
100 | from strands.tools.mcp import MCPClient
101 | from mcp import stdio_client, StdioServerParameters
102 |
103 | aws_docs_client = MCPClient(
104 | lambda: stdio_client(StdioServerParameters(command="uvx", args=["awslabs.aws-documentation-mcp-server@latest"]))
105 | )
106 |
107 | with aws_docs_client:
108 | agent = Agent(tools=aws_docs_client.list_tools_sync())
109 | response = agent("Tell me about Amazon Bedrock and how to use it with Python")
110 | ```
111 |
112 | ### Multiple Model Providers
113 |
114 | Support for various model providers:
115 |
116 | ```python
117 | from strands import Agent
118 | from strands.models import BedrockModel
119 | from strands.models.ollama import OllamaModel
120 | from strands.models.llamaapi import LlamaAPIModel
121 |
122 | # Bedrock
123 | bedrock_model = BedrockModel(
124 | model_id="us.amazon.nova-pro-v1:0",
125 | temperature=0.3,
126 | streaming=True, # Enable/disable streaming
127 | )
128 | agent = Agent(model=bedrock_model)
129 | agent("Tell me about Agentic AI")
130 |
131 | # Ollama
132 | ollama_model = OllamaModel(
133 | host="http://localhost:11434",
134 | model_id="llama3"
135 | )
136 | agent = Agent(model=ollama_model)
137 | agent("Tell me about Agentic AI")
138 |
139 | # Llama API
140 | llama_model = LlamaAPIModel(
141 | model_id="Llama-4-Maverick-17B-128E-Instruct-FP8",
142 | )
143 | agent = Agent(model=llama_model)
144 | response = agent("Tell me about Agentic AI")
145 | ```
146 |
147 | Built-in providers:
148 | - [Amazon Bedrock](https://strandsagents.com/latest/user-guide/concepts/model-providers/amazon-bedrock/)
149 | - [Anthropic](https://strandsagents.com/latest/user-guide/concepts/model-providers/anthropic/)
150 | - [LiteLLM](https://strandsagents.com/latest/user-guide/concepts/model-providers/litellm/)
151 | - [LlamaAPI](https://strandsagents.com/latest/user-guide/concepts/model-providers/llamaapi/)
152 | - [Ollama](https://strandsagents.com/latest/user-guide/concepts/model-providers/ollama/)
153 | - [OpenAI](https://strandsagents.com/latest/user-guide/concepts/model-providers/openai/)
154 |
155 | Custom providers can be implemented using [Custom Providers](https://strandsagents.com/latest/user-guide/concepts/model-providers/custom_model_provider/)
156 |
157 | ### Example tools
158 |
159 | Strands offers an optional strands-agents-tools package with pre-built tools for quick experimentation:
160 |
161 | ```python
162 | from strands import Agent
163 | from strands_tools import calculator
164 | agent = Agent(tools=[calculator])
165 | agent("What is the square root of 1764")
166 | ```
167 |
168 | It's also available on GitHub via [strands-agents/tools](https://github.com/strands-agents/tools).
169 |
170 | ## Documentation
171 |
172 | For detailed guidance & examples, explore our documentation:
173 |
174 | - [User Guide](https://strandsagents.com/)
175 | - [Quick Start Guide](https://strandsagents.com/latest/user-guide/quickstart/)
176 | - [Agent Loop](https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/)
177 | - [Examples](https://strandsagents.com/latest/examples/)
178 | - [API Reference](https://strandsagents.com/latest/api-reference/agent/)
179 | - [Production & Deployment Guide](https://strandsagents.com/latest/user-guide/deploy/operating-agents-in-production/)
180 |
181 | ## Contributing ❤️
182 |
183 | We welcome contributions! See our [Contributing Guide](CONTRIBUTING.md) for details on:
184 | - Reporting bugs & features
185 | - Development setup
186 | - Contributing via Pull Requests
187 | - Code of Conduct
188 | - Reporting of security issues
189 |
190 | ## License
191 |
192 | This project is licensed under the Apache License 2.0 - see the [LICENSE](LICENSE) file for details.
193 |
194 | ## Security
195 |
196 | See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information.
197 |
198 | ## ⚠️ Preview Status
199 |
200 | Strands Agents is currently in public preview. During this period:
201 | - APIs may change as we refine the SDK
202 | - We welcome feedback and contributions
203 |
--------------------------------------------------------------------------------
/STYLE_GUIDE.md:
--------------------------------------------------------------------------------
1 | # Style Guide
2 |
3 | ## Overview
4 |
5 | The Strands Agents style guide aims to establish consistent formatting, naming conventions, and structure across all code in the repository. We strive to make our code clean, readable, and maintainable.
6 |
7 | Where possible, we will codify these style guidelines into our linting rules and pre-commit hooks to automate enforcement and reduce the manual review burden.
8 |
9 | ## Log Formatting
10 |
11 | The format for Strands Agents logs is as follows:
12 |
13 | ```python
14 | logger.debug("field1=<%s>, field2=<%s>, ... | human readable message", field1, field2, ...)
15 | ```
16 |
17 | ### Guidelines
18 |
19 | 1. **Context**:
20 | - Add context as `=` pairs at the beginning of the log
21 | - Many log services (CloudWatch, Splunk, etc.) look for these patterns to extract fields for searching
22 |    - Use commas (`,`) to separate pairs
23 | - Enclose values in `<>` for readability
24 | - This is particularly helpful in displaying empty values (`field=` vs `field=<>`)
25 | - Use `%s` for string interpolation as recommended by Python logging
26 | - This is an optimization to skip string interpolation when the log level is not enabled
27 |
28 | 1. **Messages**:
29 | - Add human-readable messages at the end of the log
30 | - Use lowercase for consistency
31 | - Avoid punctuation (periods, exclamation points, etc.) to reduce clutter
32 | - Keep messages concise and focused on a single statement
33 | - If multiple statements are needed, separate them with the pipe character (`|`)
34 | - Example: `"processing request | starting validation"`
35 |
36 | ### Examples
37 |
38 | #### Good
39 |
40 | ```python
41 | logger.debug("user_id=<%s>, action=<%s> | user performed action", user_id, action)
42 | logger.info("request_id=<%s>, duration_ms=<%d> | request completed", request_id, duration)
43 | logger.warning("attempt=<%d>, max_attempts=<%d> | retry limit approaching", attempt, max_attempts)
44 | ```
45 |
46 | #### Poor
47 |
48 | ```python
49 | # Avoid: No structured fields, direct variable interpolation in message
50 | logger.debug(f"User {user_id} performed action {action}")
51 |
52 | # Avoid: Inconsistent formatting, punctuation
53 | logger.info("Request completed in %d ms.", duration)
54 |
55 | # Avoid: No separation between fields and message
56 | logger.warning("Retry limit approaching! attempt=%d max_attempts=%d", attempt, max_attempts)
57 | ```
58 |
59 | By following these log formatting guidelines, we ensure that logs are both human-readable and machine-parseable, making debugging and monitoring more efficient.
60 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["hatchling", "hatch-vcs"]
3 | build-backend = "hatchling.build"
4 |
5 | [project]
6 | name = "strands-agents"
7 | dynamic = ["version"]
8 | description = "A model-driven approach to building AI agents in just a few lines of code"
9 | readme = "README.md"
10 | requires-python = ">=3.10"
11 | license = {text = "Apache-2.0"}
12 | authors = [
13 | {name = "AWS", email = "opensource@amazon.com"},
14 | ]
15 | classifiers = [
16 | "Development Status :: 3 - Alpha",
17 | "Intended Audience :: Developers",
18 | "License :: OSI Approved :: Apache Software License",
19 | "Operating System :: OS Independent",
20 | "Programming Language :: Python :: 3",
21 | "Programming Language :: Python :: 3.10",
22 | "Programming Language :: Python :: 3.11",
23 | "Programming Language :: Python :: 3.12",
24 | "Programming Language :: Python :: 3.13",
25 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
26 | "Topic :: Software Development :: Libraries :: Python Modules",
27 | ]
28 | dependencies = [
29 | "boto3>=1.26.0,<2.0.0",
30 | "botocore>=1.29.0,<2.0.0",
31 | "docstring_parser>=0.15,<0.16.0",
32 | "mcp>=1.8.0,<2.0.0",
33 | "pydantic>=2.0.0,<3.0.0",
34 | "typing-extensions>=4.13.2,<5.0.0",
35 | "watchdog>=6.0.0,<7.0.0",
36 | "opentelemetry-api>=1.30.0,<2.0.0",
37 | "opentelemetry-sdk>=1.30.0,<2.0.0",
38 | "opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0",
39 | ]
40 |
41 | [project.urls]
42 | Homepage = "https://github.com/strands-agents/sdk-python"
43 | "Bug Tracker" = "https://github.com/strands-agents/sdk-python/issues"
44 | Documentation = "https://strandsagents.com"
45 |
46 | [tool.hatch.build.targets.wheel]
47 | packages = ["src/strands"]
48 |
49 | [project.optional-dependencies]
50 | anthropic = [
51 | "anthropic>=0.21.0,<1.0.0",
52 | ]
53 | dev = [
54 | "commitizen>=4.4.0,<5.0.0",
55 | "hatch>=1.0.0,<2.0.0",
56 | "moto>=5.1.0,<6.0.0",
57 | "mypy>=1.15.0,<2.0.0",
58 | "pre-commit>=3.2.0,<4.2.0",
59 | "pytest>=8.0.0,<9.0.0",
60 | "pytest-asyncio>=0.26.0,<0.27.0",
61 | "ruff>=0.4.4,<0.5.0",
62 | "swagger-parser>=1.0.2,<2.0.0",
63 | ]
64 | docs = [
65 | "sphinx>=5.0.0,<6.0.0",
66 | "sphinx-rtd-theme>=1.0.0,<2.0.0",
67 | "sphinx-autodoc-typehints>=1.12.0,<2.0.0",
68 | ]
69 | litellm = [
70 | "litellm>=1.69.0,<2.0.0",
71 | ]
72 | llamaapi = [
73 | "llama-api-client>=0.1.0,<1.0.0",
74 | ]
75 | ollama = [
76 | "ollama>=0.4.8,<1.0.0",
77 | ]
78 | openai = [
79 | "openai>=1.68.0,<2.0.0",
80 | ]
81 |
82 | [tool.hatch.version]
83 | # Tells Hatch to use your version control system (git) to determine the version.
84 | source = "vcs"
85 |
86 | [tool.hatch.envs.hatch-static-analysis]
87 | features = ["anthropic", "litellm", "llamaapi", "ollama", "openai"]
88 | dependencies = [
89 | "mypy>=1.15.0,<2.0.0",
90 | "ruff>=0.11.6,<0.12.0",
91 | "strands-agents @ {root:uri}"
92 | ]
93 |
94 | [tool.hatch.envs.hatch-static-analysis.scripts]
95 | format-check = [
96 | "ruff format --check"
97 | ]
98 | format-fix = [
99 | "ruff format"
100 | ]
101 | lint-check = [
102 | "ruff check",
103 | "mypy -p src"
104 | ]
105 | lint-fix = [
106 | "ruff check --fix"
107 | ]
108 |
109 | [tool.hatch.envs.hatch-test]
110 | features = ["anthropic", "litellm", "llamaapi", "ollama", "openai"]
111 | extra-dependencies = [
112 | "moto>=5.1.0,<6.0.0",
113 | "pytest>=8.0.0,<9.0.0",
114 | "pytest-asyncio>=0.26.0,<0.27.0",
115 | "pytest-cov>=4.1.0,<5.0.0",
116 | "pytest-xdist>=3.0.0,<4.0.0",
117 | ]
118 | extra-args = [
119 | "-n",
120 | "auto",
121 | "-vv",
122 | ]
123 |
124 | [tool.hatch.envs.dev]
125 | dev-mode = true
126 | features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama"]
127 |
128 |
129 |
130 | [[tool.hatch.envs.hatch-test.matrix]]
131 | python = ["3.13", "3.12", "3.11", "3.10"]
132 |
133 |
134 | [tool.hatch.envs.hatch-test.scripts]
135 | run = [
136 | "pytest{env:HATCH_TEST_ARGS:} {args}"
137 | ]
138 | run-cov = [
139 | "pytest{env:HATCH_TEST_ARGS:} --cov --cov-config=pyproject.toml {args}"
140 | ]
141 |
142 | cov-combine = []
143 | cov-report = []
144 |
145 |
146 | [tool.hatch.envs.default.scripts]
147 | list = [
148 | "echo 'Scripts commands available for default env:'; hatch env show --json | jq --raw-output '.default.scripts | keys[]'"
149 | ]
150 | format = [
151 | "hatch fmt --formatter",
152 | ]
153 | test-format = [
154 | "hatch fmt --formatter --check",
155 | ]
156 | lint = [
157 | "hatch fmt --linter"
158 | ]
159 | test-lint = [
160 | "hatch fmt --linter --check"
161 | ]
162 | test = [
163 | "hatch test --cover --cov-report html --cov-report xml {args}"
164 | ]
165 | test-integ = [
166 | "hatch test tests-integ {args}"
167 | ]
168 |
169 |
170 | [tool.mypy]
171 | python_version = "3.10"
172 | warn_return_any = true
173 | warn_unused_configs = true
174 | disallow_untyped_defs = true
175 | disallow_incomplete_defs = true
176 | check_untyped_defs = true
177 | disallow_untyped_decorators = true
178 | no_implicit_optional = true
179 | warn_redundant_casts = true
180 | warn_unused_ignores = true
181 | warn_no_return = true
182 | warn_unreachable = true
183 | follow_untyped_imports = true
184 | ignore_missing_imports = false
185 |
186 | [[tool.mypy.overrides]]
187 | module = "litellm"
188 | ignore_missing_imports = true
189 |
190 | [tool.ruff]
191 | line-length = 120
192 | include = ["examples/**/*.py", "src/**/*.py", "tests/**/*.py", "tests-integ/**/*.py"]
193 |
194 | [tool.ruff.lint]
195 | select = [
196 | "B", # flake8-bugbear
197 | "D", # pydocstyle
198 | "E", # pycodestyle
199 | "F", # pyflakes
200 | "G", # logging format
201 | "I", # isort
202 | "LOG", # logging
203 | ]
204 |
205 | [tool.ruff.lint.per-file-ignores]
206 | "!src/**/*.py" = ["D"]
207 |
208 | [tool.ruff.lint.pydocstyle]
209 | convention = "google"
210 |
211 | [tool.pytest.ini_options]
212 | testpaths = [
213 | "tests"
214 | ]
215 | asyncio_default_fixture_loop_scope = "function"
216 |
217 | [tool.coverage.run]
218 | branch = true
219 | source = ["src"]
220 | context = "thread"
221 | parallel = true
222 | concurrency = ["thread", "multiprocessing"]
223 |
224 | [tool.coverage.report]
225 | show_missing = true
226 |
227 | [tool.coverage.html]
228 | directory = "build/coverage/html"
229 |
230 | [tool.coverage.xml]
231 | output = "build/coverage/coverage.xml"
232 |
233 | [tool.commitizen]
234 | name = "cz_conventional_commits"
235 | tag_format = "v$version"
236 | bump_message = "chore(release): bump version $current_version -> $new_version"
237 | version_files = [
238 | "pyproject.toml:version",
239 | ]
240 | update_changelog_on_bump = true
241 | style = [
242 | ["qmark", "fg:#ff9d00 bold"],
243 | ["question", "bold"],
244 | ["answer", "fg:#ff9d00 bold"],
245 | ["pointer", "fg:#ff9d00 bold"],
246 | ["highlighted", "fg:#ff9d00 bold"],
247 | ["selected", "fg:#cc5454"],
248 | ["separator", "fg:#cc5454"],
249 | ["instruction", ""],
250 | ["text", ""],
251 | ["disabled", "fg:#858585 italic"]
252 | ]
253 |
--------------------------------------------------------------------------------
/src/strands/__init__.py:
--------------------------------------------------------------------------------
1 | """A framework for building, deploying, and managing AI agents."""
2 |
3 | from . import agent, event_loop, models, telemetry, types
4 | from .agent.agent import Agent
5 | from .tools.decorator import tool
6 | from .tools.thread_pool_executor import ThreadPoolExecutorWrapper
7 |
8 | __all__ = ["Agent", "ThreadPoolExecutorWrapper", "agent", "event_loop", "models", "tool", "types", "telemetry"]
9 |
--------------------------------------------------------------------------------
/src/strands/agent/__init__.py:
--------------------------------------------------------------------------------
1 | """This package provides the core Agent interface and supporting components for building AI agents with the SDK.
2 |
3 | It includes:
4 |
5 | - Agent: The main interface for interacting with AI models and tools
6 | - ConversationManager: Classes for managing conversation history and context windows
7 | """
8 |
9 | from .agent import Agent
10 | from .agent_result import AgentResult
11 | from .conversation_manager import ConversationManager, NullConversationManager, SlidingWindowConversationManager
12 |
13 | __all__ = [
14 | "Agent",
15 | "AgentResult",
16 | "ConversationManager",
17 | "NullConversationManager",
18 | "SlidingWindowConversationManager",
19 | ]
20 |
--------------------------------------------------------------------------------
/src/strands/agent/agent_result.py:
--------------------------------------------------------------------------------
1 | """Agent result handling for SDK.
2 |
3 | This module defines the AgentResult class which encapsulates the complete response from an agent's processing cycle.
4 | """
5 |
6 | from dataclasses import dataclass
7 | from typing import Any
8 |
9 | from ..telemetry.metrics import EventLoopMetrics
10 | from ..types.content import Message
11 | from ..types.streaming import StopReason
12 |
13 |
@dataclass
class AgentResult:
    """Represents the last result of invoking an agent with a prompt.

    Attributes:
        stop_reason: The reason why the agent's processing stopped.
        message: The last message generated by the agent.
        metrics: Performance metrics collected during processing.
        state: Additional state information from the event loop.
    """

    stop_reason: StopReason
    message: Message
    metrics: EventLoopMetrics
    state: Any

    def __str__(self) -> str:
        """Get the agent's last message as a string.

        Concatenates every text block in the final message (each followed by a newline),
        skipping non-text content such as images or structured data.

        Returns:
            The agent's last message as a string.
        """
        text_blocks = [
            block.get("text", "") + "\n"
            for block in self.message.get("content", [])
            if isinstance(block, dict) and "text" in block
        ]
        return "".join(text_blocks)
47 |
--------------------------------------------------------------------------------
/src/strands/agent/conversation_manager/__init__.py:
--------------------------------------------------------------------------------
1 | """This package provides classes for managing conversation history during agent execution.
2 |
3 | It includes:
4 |
5 | - ConversationManager: Abstract base class defining the conversation management interface
6 | - NullConversationManager: A no-op implementation that does not modify conversation history
7 | - SlidingWindowConversationManager: An implementation that maintains a sliding window of messages to control context
8 | size while preserving conversation coherence
9 |
10 | Conversation managers help control memory usage and context length while maintaining relevant conversation state, which
11 | is critical for effective agent interactions.
12 | """
13 |
14 | from .conversation_manager import ConversationManager
15 | from .null_conversation_manager import NullConversationManager
16 | from .sliding_window_conversation_manager import SlidingWindowConversationManager
17 |
18 | __all__ = ["ConversationManager", "NullConversationManager", "SlidingWindowConversationManager"]
19 |
--------------------------------------------------------------------------------
/src/strands/agent/conversation_manager/conversation_manager.py:
--------------------------------------------------------------------------------
1 | """Abstract interface for conversation history management."""
2 |
3 | from abc import ABC, abstractmethod
4 | from typing import TYPE_CHECKING, Optional
5 |
6 | if TYPE_CHECKING:
7 | from ...agent.agent import Agent
8 |
9 |
class ConversationManager(ABC):
    """Abstract base class for managing conversation history.

    This class provides an interface for implementing conversation management strategies to control the size of message
    arrays/conversation histories, helping to:

    - Manage memory usage
    - Control context length
    - Maintain relevant conversation state
    """

    @abstractmethod
    # pragma: no cover
    def apply_management(self, agent: "Agent") -> None:
        """Applies management strategy to the provided agent.

        Processes the conversation history to maintain appropriate size by modifying the messages list in-place.
        Implementations should handle message pruning, summarization, or other size management techniques to keep the
        conversation context within desired bounds.

        Args:
            agent: The agent whose conversation history will be managed.
                The agent's messages list is modified in-place.
        """
        pass

    @abstractmethod
    # pragma: no cover
    def reduce_context(self, agent: "Agent", e: Optional[Exception] = None) -> None:
        """Called when the model's context window is exceeded.

        This method should implement the specific strategy for reducing the window size when a context overflow occurs.
        It is typically called after a ContextWindowOverflowException is caught.

        Implementations might use strategies such as:

        - Removing the N oldest messages
        - Summarizing older context
        - Applying importance-based filtering
        - Maintaining critical conversation markers

        Args:
            agent: The agent whose conversation history will be reduced.
                The agent's messages list is modified in-place.
            e: The exception that triggered the context reduction, if any.
        """
        pass
57 |
--------------------------------------------------------------------------------
/src/strands/agent/conversation_manager/null_conversation_manager.py:
--------------------------------------------------------------------------------
1 | """Null implementation of conversation management."""
2 |
3 | from typing import TYPE_CHECKING, Optional
4 |
5 | if TYPE_CHECKING:
6 | from ...agent.agent import Agent
7 |
8 | from ...types.exceptions import ContextWindowOverflowException
9 | from .conversation_manager import ConversationManager
10 |
11 |
class NullConversationManager(ConversationManager):
    """A no-op conversation manager that does not modify the conversation history.

    Useful for:

    - Testing scenarios where conversation management should be disabled
    - Cases where conversation history is managed externally
    - Situations where the full conversation history should be preserved
    """

    def apply_management(self, _agent: "Agent") -> None:
        """Does nothing to the conversation history.

        Args:
            _agent: The agent whose conversation history will remain unmodified.
        """
        # Intentionally a no-op: history is left exactly as the caller provided it.
        pass

    def reduce_context(self, _agent: "Agent", e: Optional[Exception] = None) -> None:
        """Does not reduce context and raises an exception.

        Args:
            _agent: The agent whose conversation history will remain unmodified.
            e: The exception that triggered the context reduction, if any.

        Raises:
            e: If provided.
            ContextWindowOverflowException: If e is None.
        """
        # No reduction strategy exists here, so surface the failure to the caller.
        if not e:
            raise ContextWindowOverflowException("Context window overflowed!")
        raise e
45 |
--------------------------------------------------------------------------------
/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py:
--------------------------------------------------------------------------------
1 | """Sliding window conversation history management."""
2 |
3 | import logging
4 | from typing import TYPE_CHECKING, Optional
5 |
6 | if TYPE_CHECKING:
7 | from ...agent.agent import Agent
8 |
9 | from ...types.content import Message, Messages
10 | from ...types.exceptions import ContextWindowOverflowException
11 | from .conversation_manager import ConversationManager
12 |
13 | logger = logging.getLogger(__name__)
14 |
15 |
def is_user_message(message: Message) -> bool:
    """Determine whether the given message was authored by the user.

    Args:
        message: The message object to check.

    Returns:
        True if the message has the user role, False otherwise.
    """
    role = message["role"]
    return role == "user"
26 |
27 |
def is_assistant_message(message: Message) -> bool:
    """Determine whether the given message was authored by the assistant.

    Args:
        message: The message object to check.

    Returns:
        True if the message has the assistant role, False otherwise.
    """
    role = message["role"]
    return role == "assistant"
38 |
39 |
class SlidingWindowConversationManager(ConversationManager):
    """Implements a sliding window strategy for managing conversation history.

    This class handles the logic of maintaining a conversation window that preserves tool usage pairs and avoids
    invalid window states.
    """

    def __init__(self, window_size: int = 40):
        """Initialize the sliding window conversation manager.

        Args:
            window_size: Maximum number of messages to keep in the agent's history.
                Defaults to 40 messages.
        """
        self.window_size = window_size

    def apply_management(self, agent: "Agent") -> None:
        """Apply the sliding window to the agent's messages array to maintain a manageable history size.

        This method is called after every event loop cycle, as the messages array may have been modified with tool
        results and assistant responses. It first removes any dangling messages that might create an invalid
        conversation state, then applies the sliding window if the message count exceeds the window size.

        Special handling is implemented to ensure we don't leave a user message with toolResult
        as the first message in the array. It also ensures that all toolUse blocks have corresponding toolResult
        blocks to maintain conversation coherence.

        Args:
            agent: The agent whose messages will be managed.
                This list is modified in-place.
        """
        messages = agent.messages
        self._remove_dangling_messages(messages)

        if len(messages) <= self.window_size:
            # Bug fix: the two log arguments were previously swapped (window_size was
            # logged as the message count and vice versa); they now match the field
            # names in the format string.
            logger.debug(
                "window_size=<%s>, message_count=<%s> | skipping context reduction", self.window_size, len(messages)
            )
            return
        self.reduce_context(agent)

    def _remove_dangling_messages(self, messages: Messages) -> None:
        """Remove dangling messages that would create an invalid conversation state.

        After the event loop cycle is executed, we expect the messages array to end with either an assistant tool use
        request followed by the pairing user tool result or an assistant response with no tool use request. If the
        event loop cycle fails, we may end up in an invalid message state, and so this method will remove problematic
        messages from the end of the array.

        This method handles two specific cases:

        - User with no tool result: Indicates that event loop failed to generate an assistant tool use request
        - Assistant with tool use request: Indicates that event loop failed to generate a pairing user tool result

        Args:
            messages: The messages to clean up.
                This list is modified in-place.
        """
        # remove any dangling user messages with no ToolResult
        if len(messages) > 0 and is_user_message(messages[-1]):
            if not any("toolResult" in content for content in messages[-1]["content"]):
                messages.pop()

        # remove any dangling assistant messages with ToolUse
        if len(messages) > 0 and is_assistant_message(messages[-1]):
            if any("toolUse" in content for content in messages[-1]["content"]):
                messages.pop()
                # remove remaining dangling user messages with no ToolResult after we popped off an assistant message
                if len(messages) > 0 and is_user_message(messages[-1]):
                    if not any("toolResult" in content for content in messages[-1]["content"]):
                        messages.pop()

    def reduce_context(self, agent: "Agent", e: Optional[Exception] = None) -> None:
        """Trim the oldest messages to reduce the conversation context size.

        The method handles special cases where trimming the messages leads to:
        - toolResult with no corresponding toolUse
        - toolUse with no corresponding toolResult

        Args:
            agent: The agent whose messages will be reduced.
                This list is modified in-place.
            e: The exception that triggered the context reduction, if any.

        Raises:
            ContextWindowOverflowException: If the context cannot be reduced further.
                Such as when the conversation is already minimal or when tool result messages cannot be properly
                converted.
        """
        messages = agent.messages
        # If the number of messages is less than the window_size, then we default to 2, otherwise, trim to window size
        trim_index = 2 if len(messages) <= self.window_size else len(messages) - self.window_size

        # Find the next valid trim_index
        while trim_index < len(messages):
            if (
                # Oldest message cannot be a toolResult because it needs a toolUse preceding it
                any("toolResult" in content for content in messages[trim_index]["content"])
                or (
                    # Oldest message can be a toolUse only if a toolResult immediately follows it.
                    any("toolUse" in content for content in messages[trim_index]["content"])
                    and trim_index + 1 < len(messages)
                    and not any("toolResult" in content for content in messages[trim_index + 1]["content"])
                )
            ):
                trim_index += 1
            else:
                break
        else:
            # If we didn't find a valid trim_index, then we throw
            raise ContextWindowOverflowException("Unable to trim conversation context!") from e

        # Overwrite message history
        messages[:] = messages[trim_index:]
154 |
--------------------------------------------------------------------------------
/src/strands/event_loop/__init__.py:
--------------------------------------------------------------------------------
1 | """This package provides the core event loop implementation for the agents SDK.
2 |
3 | The event loop enables conversational AI agents to process messages, execute tools, and handle errors in a controlled,
4 | iterative manner.
5 | """
6 |
7 | from . import error_handler, event_loop, message_processor
8 |
9 | __all__ = ["error_handler", "event_loop", "message_processor"]
10 |
--------------------------------------------------------------------------------
/src/strands/event_loop/error_handler.py:
--------------------------------------------------------------------------------
1 | """This module provides specialized error handlers for common issues that may occur during event loop execution.
2 |
3 | Examples include throttling exceptions and context window overflow errors. These handlers implement recovery strategies
4 | like exponential backoff for throttling and message truncation for context window limitations.
5 | """
6 |
7 | import logging
8 | import time
9 | from typing import Any, Dict, Optional, Tuple
10 |
11 | from ..telemetry.metrics import EventLoopMetrics
12 | from ..types.content import Message, Messages
13 | from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
14 | from ..types.models import Model
15 | from ..types.streaming import StopReason
16 | from .message_processor import find_last_message_with_tool_results, truncate_tool_results
17 |
18 | logger = logging.getLogger(__name__)
19 |
20 |
def handle_throttling_error(
    e: ModelThrottledException,
    attempt: int,
    max_attempts: int,
    current_delay: int,
    max_delay: int,
    callback_handler: Any,
    kwargs: Dict[str, Any],
) -> Tuple[bool, int]:
    """Handle throttling exceptions from the model provider with exponential backoff.

    Args:
        e: The exception that occurred during model invocation.
        attempt: Number of times event loop has attempted model invocation.
        max_attempts: Maximum number of retry attempts allowed.
        current_delay: Current delay in seconds before retrying.
        max_delay: Maximum delay in seconds (cap for exponential growth).
        callback_handler: Callback for processing events as they happen.
        kwargs: Additional arguments to pass to the callback handler.

    Returns:
        A tuple containing:
            - bool: True if retry should be attempted, False otherwise
            - int: The new delay to use for the next retry attempt
    """
    # On the final attempt there is no retry left: signal a forced stop and bail out.
    if attempt >= max_attempts - 1:
        callback_handler(force_stop=True, force_stop_reason=str(e))
        return False, current_delay

    logger.debug(
        "retry_delay_seconds=<%s>, max_attempts=<%s>, current_attempt=<%s> "
        "| throttling exception encountered "
        "| delaying before next retry",
        current_delay,
        max_attempts,
        attempt + 1,
    )
    callback_handler(event_loop_throttled_delay=current_delay, **kwargs)
    time.sleep(current_delay)

    # Exponential backoff: double the delay each retry, capped at max_delay.
    return True, min(current_delay * 2, max_delay)
62 |
63 |
def handle_input_too_long_error(
    e: ContextWindowOverflowException,
    messages: Messages,
    model: Model,
    system_prompt: Optional[str],
    tool_config: Any,
    callback_handler: Any,
    tool_handler: Any,
    kwargs: Dict[str, Any],
) -> Tuple[StopReason, Message, EventLoopMetrics, Any]:
    """Handle 'Input is too long' errors by truncating tool results.

    When a context window overflow exception occurs (input too long for the model), this function attempts to recover
    by finding and truncating the most recent tool results in the conversation history. If truncation is successful, the
    function will make a call to the event loop.

    Args:
        e: The ContextWindowOverflowException that occurred.
        messages: The conversation message history.
        model: Model provider for running inference.
        system_prompt: System prompt for the model.
        tool_config: Tool configuration for the conversation.
        callback_handler: Callback for processing events as they happen.
        tool_handler: Handler for tool execution.
        kwargs: Additional arguments for the event loop.

    Returns:
        The results from the event loop call if successful.

    Raises:
        ContextWindowOverflowException: If messages cannot be truncated.
    """
    from .event_loop import recurse_event_loop  # Import here to avoid circular imports

    # Locate the most recent message containing tool results, if any.
    last_message_with_tool_results = find_last_message_with_tool_results(messages)

    # Without tool results to truncate, the overflow cannot be recovered here.
    if last_message_with_tool_results is None:
        callback_handler(force_stop=True, force_stop_reason=str(e))
        logger.error("an exception occurred in event_loop_cycle | %s", e)
        raise ContextWindowOverflowException() from e

    logger.debug("message_index=<%s> | found message with tool results at index", last_message_with_tool_results)

    # Shrink the tool results in that message, then retry via the event loop.
    truncate_tool_results(messages, last_message_with_tool_results)

    return recurse_event_loop(
        model=model,
        system_prompt=system_prompt,
        messages=messages,
        tool_config=tool_config,
        callback_handler=callback_handler,
        tool_handler=tool_handler,
        **kwargs,
    )
122 |
--------------------------------------------------------------------------------
/src/strands/event_loop/message_processor.py:
--------------------------------------------------------------------------------
1 | """This module provides utilities for processing and manipulating conversation messages within the event loop.
2 |
3 | It includes functions for cleaning up orphaned tool uses, finding messages with specific content types, and truncating
4 | large tool results to prevent context window overflow.
5 | """
6 |
7 | import logging
8 | from typing import Dict, Optional, Set, Tuple
9 |
10 | from ..types.content import Messages
11 |
12 | logger = logging.getLogger(__name__)
13 |
14 |
def clean_orphaned_empty_tool_uses(messages: Messages) -> bool:
    """Clean up orphaned empty tool uses in conversation messages.

    This function identifies and removes any toolUse entries with empty input that don't have a corresponding
    toolResult. This prevents validation errors that occur when the model expects matching toolResult blocks for each
    toolUse.

    The function applies fixes by either:

    1. Replacing a message containing only an orphaned toolUse with a context message
    2. Removing the orphaned toolUse entry from a message with multiple content items

    Args:
        messages: The conversation message history.

    Returns:
        True if any fixes were applied, False otherwise.
    """
    if not messages:
        return False

    # Dictionary to track empty toolUse entries: {tool_id: (msg_index, content_index, tool_name)}
    empty_tool_uses: Dict[str, Tuple[int, int, str]] = {}

    # Set to track toolResults that have been seen
    tool_results: Set[str] = set()

    # Identify empty toolUse entries
    for i, msg in enumerate(messages):
        if msg.get("role") != "assistant":
            continue

        for j, content in enumerate(msg.get("content", [])):
            if isinstance(content, dict) and "toolUse" in content:
                tool_use = content.get("toolUse", {})
                tool_id = tool_use.get("toolUseId")
                tool_input = tool_use.get("input", {})
                tool_name = tool_use.get("name", "unknown tool")

                # An empty/falsy input marks the toolUse as empty (the redundant
                # `tool_input == {}` check was dropped; `not tool_input` covers it)
                if tool_id and not tool_input:
                    empty_tool_uses[tool_id] = (i, j, tool_name)

    # Identify toolResults
    for msg in messages:
        if msg.get("role") != "user":
            continue

        for content in msg.get("content", []):
            if isinstance(content, dict) and "toolResult" in content:
                tool_result = content.get("toolResult", {})
                tool_id = tool_result.get("toolUseId")
                if tool_id:
                    tool_results.add(tool_id)

    # Filter for orphaned empty toolUses (no corresponding toolResult)
    orphaned_tool_uses = {tool_id: info for tool_id, info in empty_tool_uses.items() if tool_id not in tool_results}

    # Apply fixes in reverse order of occurrence (to avoid index shifting)
    if not orphaned_tool_uses:
        return False

    # Sort by message index and content index in reverse order
    sorted_orphaned = sorted(orphaned_tool_uses.items(), key=lambda x: (x[1][0], x[1][1]), reverse=True)

    # Apply fixes
    for tool_id, (msg_idx, content_idx, tool_name) in sorted_orphaned:
        # Fix: add the missing "| " separator between structured fields and the
        # human-readable message, per the repository log formatting style guide.
        logger.debug(
            "tool_name=<%s>, tool_id=<%s>, message_index=<%s>, content_index=<%s> "
            "| fixing orphaned empty tool use at message index",
            tool_name,
            tool_id,
            msg_idx,
            content_idx,
        )
        try:
            # Check if this is the sole content in the message
            if len(messages[msg_idx]["content"]) == 1:
                # Replace with a message indicating the attempted tool
                messages[msg_idx]["content"] = [{"text": f"[Attempted to use {tool_name}, but operation was canceled]"}]
                logger.debug("message_index=<%s> | replaced content with context message", msg_idx)
            else:
                # Simply remove the orphaned toolUse entry
                messages[msg_idx]["content"].pop(content_idx)
                logger.debug(
                    "message_index=<%s>, content_index=<%s> | removed content item from message", msg_idx, content_idx
                )
        except Exception as e:
            # Best-effort cleanup: a malformed message should not abort the whole pass
            logger.warning("failed to fix orphaned tool use | %s", e)

    return True
106 |
107 |
def find_last_message_with_tool_results(messages: Messages) -> Optional[int]:
    """Find the index of the last message containing tool results.

    This is useful for identifying messages that might need to be truncated to reduce context size.

    Args:
        messages: The conversation message history.

    Returns:
        Index of the last message with tool results, or None if no such message exists.
    """
    # Walk the history from newest to oldest and stop at the first match.
    for idx in reversed(range(len(messages))):
        content_blocks = messages[idx].get("content", [])
        if any(isinstance(block, dict) and "toolResult" in block for block in content_blocks):
            return idx

    return None
134 |
135 |
def truncate_tool_results(messages: Messages, msg_idx: int) -> bool:
    """Truncate tool results in a message to reduce context size.

    When a message contains tool results that are too large for the model's context window, this function replaces the
    content of those tool results with a simple error message.

    Args:
        messages: The conversation message history.
        msg_idx: Index of the message containing tool results to truncate.

    Returns:
        True if any changes were made to the message, False otherwise.
    """
    # Reject out-of-range indices; negative values are not treated as "from the end".
    if not 0 <= msg_idx < len(messages):
        return False

    modified = False

    for block in messages[msg_idx].get("content", []):
        if isinstance(block, dict) and "toolResult" in block:
            # Swap the oversized payload for a short error marker in place.
            block["toolResult"]["status"] = "error"
            block["toolResult"]["content"] = [{"text": "The tool result was too large!"}]
            modified = True

    return modified
163 |
--------------------------------------------------------------------------------
/src/strands/handlers/__init__.py:
--------------------------------------------------------------------------------
1 | """Various handlers for performing custom actions on agent state.
2 |
3 | Examples include:
4 |
5 | - Processing tool invocations
6 | - Displaying events from the event stream
7 | """
8 |
9 | from .callback_handler import CompositeCallbackHandler, PrintingCallbackHandler, null_callback_handler
10 |
11 | __all__ = ["CompositeCallbackHandler", "null_callback_handler", "PrintingCallbackHandler"]
12 |
--------------------------------------------------------------------------------
/src/strands/handlers/callback_handler.py:
--------------------------------------------------------------------------------
1 | """This module provides handlers for formatting and displaying events from the agent."""
2 |
3 | from collections.abc import Callable
4 | from typing import Any
5 |
6 |
class PrintingCallbackHandler:
    """Handler for streaming text output and tool invocations to stdout."""

    def __init__(self) -> None:
        """Initialize handler."""
        # Number of distinct tool uses announced so far.
        self.tool_count = 0
        # Last tool-use payload seen, used to avoid announcing the same tool twice.
        self.previous_tool_use = None

    def __call__(self, **kwargs: Any) -> None:
        """Stream text output and tool invocations to stdout.

        Args:
            **kwargs: Callback event data including:
                - reasoningText (Optional[str]): Reasoning text to print if provided.
                - data (str): Text content to stream.
                - complete (bool): Whether this is the final chunk of a response.
                - current_tool_use (dict): Information about the current tool being used.
        """
        reasoning_text = kwargs.get("reasoningText", False)
        if reasoning_text:
            print(reasoning_text, end="")

        data = kwargs.get("data", "")
        complete = kwargs.get("complete", False)
        if data:
            # Only terminate the line once the response is complete.
            print(data, end="\n" if complete else "")

        tool_use = kwargs.get("current_tool_use", {})
        if tool_use and tool_use.get("name"):
            # Tool-use info streams in as many delta events; announce each tool only once.
            if self.previous_tool_use != tool_use:
                self.previous_tool_use = tool_use
                self.tool_count += 1
                print(f"\nTool #{self.tool_count}: {tool_use.get('name', 'Unknown tool')}")

        if complete and data:
            print("\n")
45 |
46 |
class CompositeCallbackHandler:
    """Class-based callback handler that combines multiple callback handlers.

    This handler allows multiple callback handlers to be invoked for the same events,
    enabling different processing or output formats for the same stream data.
    """

    def __init__(self, *handlers: Callable) -> None:
        """Initialize handler.

        Args:
            *handlers: Callback handlers to fan events out to, invoked in the order given.
        """
        self.handlers = handlers

    def __call__(self, **kwargs: Any) -> None:
        """Invoke all handlers in the chain.

        Args:
            **kwargs: Event data forwarded unchanged to every registered handler.
        """
        for delegate in self.handlers:
            delegate(**kwargs)
62 |
63 |
def null_callback_handler(**_kwargs: Any) -> None:
    """Callback handler that discards all output.

    Useful when a callback handler is required but no output should be produced.

    Args:
        **_kwargs: Event data (ignored).
    """
    return None
71 |
--------------------------------------------------------------------------------
/src/strands/handlers/tool_handler.py:
--------------------------------------------------------------------------------
1 | """This module provides handlers for managing tool invocations."""
2 |
3 | import logging
4 | from typing import Any, List, Optional
5 |
6 | from ..tools.registry import ToolRegistry
7 | from ..types.models import Model
8 | from ..types.tools import ToolConfig, ToolHandler, ToolResult, ToolUse
9 |
10 | logger = logging.getLogger(__name__)
11 |
12 |
class AgentToolHandler(ToolHandler):
    """Handler for processing tool invocations in agent.

    This class implements the ToolHandler interface and provides functionality for looking up tools in a registry and
    invoking them with the appropriate parameters.
    """

    def __init__(self, tool_registry: ToolRegistry) -> None:
        """Initialize handler.

        Args:
            tool_registry: Registry of available tools.
        """
        self.tool_registry = tool_registry

    def preprocess(
        self,
        tool: ToolUse,
        tool_config: ToolConfig,
        **kwargs: Any,
    ) -> Optional[ToolResult]:
        """Preprocess a tool before invocation (not implemented).

        Args:
            tool: The tool use object to preprocess.
            tool_config: Configuration for the tool.
            **kwargs: Additional keyword arguments.

        Returns:
            Result of preprocessing, if any.
        """
        pass

    def process(
        self,
        tool: Any,
        *,
        model: Model,
        system_prompt: Optional[str],
        messages: List[Any],
        tool_config: Any,
        callback_handler: Any,
        **kwargs: Any,
    ) -> Any:
        """Process a tool invocation.

        Looks up the tool in the registry and invokes it with the provided parameters.

        Args:
            tool: The tool object to process, containing name and parameters.
            model: The model being used for the agent.
            system_prompt: The system prompt for the agent.
            messages: The conversation history.
            tool_config: Configuration for the tool.
            callback_handler: Callback for processing events as they happen.
            **kwargs: Additional keyword arguments passed to the tool.

        Returns:
            The result of the tool invocation, or an error response if the tool fails or is not found.
        """
        logger.debug("tool=<%s> | invoking", tool)
        tool_use_id = tool["toolUseId"]
        tool_name = tool["name"]

        # Dynamically loaded tools take precedence over statically registered ones.
        tool_func = self.tool_registry.dynamic_tools.get(tool_name)
        if tool_func is None:
            tool_func = self.tool_registry.registry.get(tool_name)

        try:
            if not tool_func:
                # Unknown tool: surface the error back to the model as a tool result.
                logger.error(
                    "tool_name=<%s>, available_tools=<%s> | tool not found in registry",
                    tool_name,
                    list(self.tool_registry.registry.keys()),
                )
                return {
                    "toolUseId": tool_use_id,
                    "status": "error",
                    "content": [{"text": f"Unknown tool: {tool_name}"}],
                }

            # Inject the standard framework arguments every Python tool receives.
            kwargs["model"] = model
            kwargs["system_prompt"] = system_prompt
            kwargs["messages"] = messages
            kwargs["tool_config"] = tool_config
            kwargs["callback_handler"] = callback_handler

            return tool_func.invoke(tool, **kwargs)

        except Exception as e:
            # Any failure is converted into an error tool result rather than propagated.
            logger.exception("tool_name=<%s> | failed to process tool", tool_name)
            return {
                "toolUseId": tool_use_id,
                "status": "error",
                "content": [{"text": f"Error: {str(e)}"}],
            }
114 |
--------------------------------------------------------------------------------
/src/strands/models/__init__.py:
--------------------------------------------------------------------------------
1 | """SDK model providers.
2 |
3 | This package includes an abstract base Model class along with concrete implementations for specific providers.
4 | """
5 |
6 | from . import bedrock
7 | from .bedrock import BedrockModel
8 |
9 | __all__ = ["bedrock", "BedrockModel"]
10 |
--------------------------------------------------------------------------------
/src/strands/models/litellm.py:
--------------------------------------------------------------------------------
1 | """LiteLLM model provider.
2 |
3 | - Docs: https://docs.litellm.ai/
4 | """
5 |
6 | import logging
7 | from typing import Any, Optional, TypedDict, cast
8 |
9 | import litellm
10 | from typing_extensions import Unpack, override
11 |
12 | from ..types.content import ContentBlock
13 | from .openai import OpenAIModel
14 |
15 | logger = logging.getLogger(__name__)
16 |
17 |
class LiteLLMModel(OpenAIModel):
    """LiteLLM model provider implementation."""

    class LiteLLMConfig(TypedDict, total=False):
        """Configuration options for LiteLLM models.

        Attributes:
            model_id: Model ID (e.g., "openai/gpt-4o", "anthropic/claude-3-sonnet").
                For a complete list of supported models, see https://docs.litellm.ai/docs/providers.
            params: Model parameters (e.g., max_tokens).
                For a complete list of supported parameters, see
                https://docs.litellm.ai/docs/completion/input#input-params-1.
        """

        model_id: str
        params: Optional[dict[str, Any]]

    def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[LiteLLMConfig]) -> None:
        """Initialize provider instance.

        Args:
            client_args: Arguments for the LiteLLM client.
                For a complete list of supported arguments, see
                https://github.com/BerriAI/litellm/blob/main/litellm/main.py.
            **model_config: Configuration options for the LiteLLM model.
        """
        self.config = dict(model_config)

        logger.debug("config=<%s> | initializing", self.config)

        self.client = litellm.LiteLLM(**(client_args or {}))

    @override
    def update_config(self, **model_config: Unpack[LiteLLMConfig]) -> None:  # type: ignore[override]
        """Update the LiteLLM model configuration with the provided arguments.

        Args:
            **model_config: Configuration overrides.
        """
        self.config.update(model_config)

    @override
    def get_config(self) -> LiteLLMConfig:
        """Get the LiteLLM model configuration.

        Returns:
            The LiteLLM model configuration.
        """
        return cast(LiteLLMModel.LiteLLMConfig, self.config)

    @override
    @classmethod
    def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]:
        """Format a LiteLLM content block.

        Args:
            content: Message content.

        Returns:
            LiteLLM formatted content block.

        Raises:
            TypeError: If the content block type cannot be converted to a LiteLLM-compatible format.
        """
        # Reasoning blocks map to LiteLLM's "thinking" content type.
        if "reasoningContent" in content:
            reasoning = content["reasoningContent"]["reasoningText"]
            return {
                "signature": reasoning["signature"],
                "thinking": reasoning["text"],
                "type": "thinking",
            }

        # Video blocks have no OpenAI equivalent, so handle them here.
        if "video" in content:
            return {
                "type": "video_url",
                "video_url": {
                    "detail": "auto",
                    "url": content["video"]["source"]["bytes"],
                },
            }

        # Everything else is OpenAI-compatible; defer to the parent implementation.
        return super().format_request_message_content(content)
100 |
--------------------------------------------------------------------------------
/src/strands/models/openai.py:
--------------------------------------------------------------------------------
1 | """OpenAI model provider.
2 |
3 | - Docs: https://platform.openai.com/docs/overview
4 | """
5 |
6 | import logging
7 | from typing import Any, Iterable, Optional, Protocol, TypedDict, cast
8 |
9 | import openai
10 | from typing_extensions import Unpack, override
11 |
12 | from ..types.models import OpenAIModel as SAOpenAIModel
13 |
14 | logger = logging.getLogger(__name__)
15 |
16 |
class Client(Protocol):
    """Protocol defining the OpenAI-compatible interface for the underlying provider client."""

    @property
    def chat(self) -> Any:  # pragma: no cover
        """Chat completions interface."""
        ...
25 |
26 |
class OpenAIModel(SAOpenAIModel):
    """OpenAI model provider implementation."""

    client: Client

    class OpenAIConfig(TypedDict, total=False):
        """Configuration options for OpenAI models.

        Attributes:
            model_id: Model ID (e.g., "gpt-4o").
                For a complete list of supported models, see https://platform.openai.com/docs/models.
            params: Model parameters (e.g., max_tokens).
                For a complete list of supported parameters, see
                https://platform.openai.com/docs/api-reference/chat/create.
        """

        model_id: str
        params: Optional[dict[str, Any]]

    def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[OpenAIConfig]) -> None:
        """Initialize provider instance.

        Args:
            client_args: Arguments for the OpenAI client.
                For a complete list of supported arguments, see https://pypi.org/project/openai/.
            **model_config: Configuration options for the OpenAI model.
        """
        self.config = dict(model_config)

        logger.debug("config=<%s> | initializing", self.config)

        client_args = client_args or {}
        self.client = openai.OpenAI(**client_args)

    @override
    def update_config(self, **model_config: Unpack[OpenAIConfig]) -> None:  # type: ignore[override]
        """Update the OpenAI model configuration with the provided arguments.

        Args:
            **model_config: Configuration overrides.
        """
        self.config.update(model_config)

    @override
    def get_config(self) -> OpenAIConfig:
        """Get the OpenAI model configuration.

        Returns:
            The OpenAI model configuration.
        """
        return cast(OpenAIModel.OpenAIConfig, self.config)

    @override
    def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
        """Send the request to the OpenAI model and get the streaming response.

        Args:
            request: The formatted request to send to the OpenAI model.

        Returns:
            An iterable of response events from the OpenAI model.
        """
        response = self.client.chat.completions.create(**request)

        yield {"chunk_type": "message_start"}
        yield {"chunk_type": "content_start", "data_type": "text"}

        tool_calls: dict[int, list[Any]] = {}
        # Track these inside the loop so the yields below never reference an unbound loop
        # variable when the stream produces no events (or none with choices).
        finish_reason: Optional[str] = None
        last_event: Any = None

        for event in response:
            last_event = event

            # Defensive: skip events with empty or missing choices
            if not getattr(event, "choices", None):
                continue
            choice = event.choices[0]

            if choice.delta.content:
                yield {"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content}

            # Tool calls arrive as deltas; group them by index so each call can be replayed whole.
            for tool_call in choice.delta.tool_calls or []:
                tool_calls.setdefault(tool_call.index, []).append(tool_call)

            if choice.finish_reason:
                finish_reason = choice.finish_reason
                break

        yield {"chunk_type": "content_stop", "data_type": "text"}

        for tool_deltas in tool_calls.values():
            yield {"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]}

            for tool_delta in tool_deltas:
                yield {"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta}

            yield {"chunk_type": "content_stop", "data_type": "tool"}

        yield {"chunk_type": "message_stop", "data": finish_reason}

        # Drain remaining events as we don't have use for anything except the final usage payload
        for event in response:
            last_event = event

        yield {"chunk_type": "metadata", "data": last_event.usage if last_event is not None else None}
128 |
--------------------------------------------------------------------------------
/src/strands/py.typed:
--------------------------------------------------------------------------------
1 | # Marker file that indicates this package supports typing
2 |
--------------------------------------------------------------------------------
/src/strands/telemetry/__init__.py:
--------------------------------------------------------------------------------
1 | """Telemetry module.
2 |
3 | This module provides metrics and tracing functionality.
4 | """
5 |
6 | from .metrics import EventLoopMetrics, Trace, metrics_to_string
7 | from .tracer import Tracer, get_tracer
8 |
9 | __all__ = [
10 | "EventLoopMetrics",
11 | "Trace",
12 | "metrics_to_string",
13 | "Tracer",
14 | "get_tracer",
15 | ]
16 |
--------------------------------------------------------------------------------
/src/strands/tools/__init__.py:
--------------------------------------------------------------------------------
1 | """Agent tool interfaces and utilities.
2 |
3 | This module provides the core functionality for creating, managing, and executing tools through agents.
4 | """
5 |
6 | from .decorator import tool
7 | from .thread_pool_executor import ThreadPoolExecutorWrapper
8 | from .tools import FunctionTool, InvalidToolUseNameException, PythonAgentTool, normalize_schema, normalize_tool_spec
9 |
10 | __all__ = [
11 | "tool",
12 | "FunctionTool",
13 | "PythonAgentTool",
14 | "InvalidToolUseNameException",
15 | "normalize_schema",
16 | "normalize_tool_spec",
17 | "ThreadPoolExecutorWrapper",
18 | ]
19 |
--------------------------------------------------------------------------------
/src/strands/tools/executor.py:
--------------------------------------------------------------------------------
1 | """Tool execution functionality for the event loop."""
2 |
3 | import logging
4 | import time
5 | from concurrent.futures import TimeoutError
6 | from typing import Any, Callable, List, Optional, Tuple
7 |
8 | from opentelemetry import trace
9 |
10 | from ..telemetry.metrics import EventLoopMetrics, Trace
11 | from ..telemetry.tracer import get_tracer
12 | from ..tools.tools import InvalidToolUseNameException, validate_tool_use
13 | from ..types.content import Message
14 | from ..types.event_loop import ParallelToolExecutorInterface
15 | from ..types.tools import ToolResult, ToolUse
16 |
17 | logger = logging.getLogger(__name__)
18 |
19 |
def run_tools(
    handler: Callable[[ToolUse], ToolResult],
    tool_uses: List[ToolUse],
    event_loop_metrics: EventLoopMetrics,
    request_state: Any,
    invalid_tool_use_ids: List[str],
    tool_results: List[ToolResult],
    cycle_trace: Trace,
    parent_span: Optional[trace.Span] = None,
    parallel_tool_executor: Optional[ParallelToolExecutorInterface] = None,
) -> bool:
    """Execute tools either in parallel or sequentially.

    Args:
        handler: Tool handler processing function.
        tool_uses: List of tool uses to execute.
        event_loop_metrics: Metrics collection object.
        request_state: Current request state (not read by this function).
        invalid_tool_use_ids: List of invalid tool use IDs.
        tool_results: List to populate with tool results. In the parallel path, results are
            appended in completion order, which may differ from the order of `tool_uses`.
        cycle_trace: Parent trace for the current cycle.
        parent_span: Parent span for the current cycle.
        parallel_tool_executor: Optional executor for parallel processing.

    Returns:
        bool: True if any tool failed, False otherwise.
    """

    def _handle_tool_execution(tool: ToolUse) -> Tuple[bool, Optional[ToolResult]]:
        """Run a single tool, recording its trace and metrics.

        Returns:
            (succeeded, result) — result is None when the tool was skipped (flagged as an
            invalid tool-use id) or when execution raised.
        """
        result = None
        tool_succeeded = False

        tracer = get_tracer()
        tool_call_span = tracer.start_tool_call_span(tool, parent_span)

        try:
            # Skip tools whose ids were flagged invalid during validation; error results
            # for those were already recorded upstream.
            if "toolUseId" not in tool or tool["toolUseId"] not in invalid_tool_use_ids:
                tool_name = tool["name"]
                tool_trace = Trace(f"Tool: {tool_name}", parent_id=cycle_trace.id, raw_name=tool_name)
                tool_start_time = time.time()
                result = handler(tool)
                tool_success = result.get("status") == "success"
                if tool_success:
                    tool_succeeded = True

                tool_duration = time.time() - tool_start_time
                # Wrap the result as a user message so metrics can account for its tokens/shape.
                message = Message(role="user", content=[{"toolResult": result}])
                event_loop_metrics.add_tool_usage(tool, tool_duration, tool_trace, tool_success, message)
                cycle_trace.add_child(tool_trace)

                if tool_call_span:
                    tracer.end_tool_call_span(tool_call_span, result)
        except Exception as e:
            # Exceptions are recorded on the span but otherwise swallowed: the caller sees
            # (False, None), so no toolResult is produced for this toolUse.
            if tool_call_span:
                tracer.end_span_with_error(tool_call_span, str(e), e)

        return tool_succeeded, result

    any_tool_failed = False
    if parallel_tool_executor:
        logger.debug(
            "tool_count=<%s>, tool_executor=<%s> | executing tools in parallel",
            len(tool_uses),
            type(parallel_tool_executor).__name__,
        )
        # Submit all tasks with their associated tools
        future_to_tool = {
            parallel_tool_executor.submit(_handle_tool_execution, tool_use): tool_use for tool_use in tool_uses
        }
        logger.debug("tool_count=<%s> | submitted tasks to parallel executor", len(tool_uses))

        # Collect results truly in parallel using the provided executor's as_completed method
        completed_results = []
        try:
            for future in parallel_tool_executor.as_completed(future_to_tool):
                try:
                    succeeded, result = future.result()
                    if result is not None:
                        completed_results.append(result)
                    if not succeeded:
                        any_tool_failed = True
                except Exception as e:
                    tool = future_to_tool[future]
                    logger.debug("tool_name=<%s> | tool execution failed | %s", tool["name"], e)
                    any_tool_failed = True
        except TimeoutError:
            logger.error("timeout_seconds=<%s> | parallel tool execution timed out", parallel_tool_executor.timeout)
            # Process any completed tasks
            for future in future_to_tool:
                if future.done():  # type: ignore
                    try:
                        # timeout=0: the future is already done, so this cannot block.
                        succeeded, result = future.result(timeout=0)
                        if result is not None:
                            completed_results.append(result)
                    except Exception as tool_e:
                        tool = future_to_tool[future]
                        logger.debug("tool_name=<%s> | tool execution failed | %s", tool["name"], tool_e)
                else:
                    # This future didn't complete within the timeout
                    tool = future_to_tool[future]
                    logger.debug("tool_name=<%s> | tool execution timed out", tool["name"])

                    any_tool_failed = True

        # Add completed results to tool_results
        tool_results.extend(completed_results)
    else:
        # Sequential execution fallback
        for tool_use in tool_uses:
            succeeded, result = _handle_tool_execution(tool_use)
            if result is not None:
                tool_results.append(result)
            if not succeeded:
                any_tool_failed = True

    return any_tool_failed
136 |
137 |
def validate_and_prepare_tools(
    message: Message,
    tool_uses: List[ToolUse],
    tool_results: List[ToolResult],
    invalid_tool_use_ids: List[str],
) -> None:
    """Validate tool uses and prepare them for execution.

    Args:
        message: Current message.
        tool_uses: List to populate with tool uses.
        tool_results: List to populate with tool results for invalid tools.
        invalid_tool_use_ids: List to populate with invalid tool use IDs.
    """
    # Collect every toolUse block from the message content.
    tool_uses.extend(block["toolUse"] for block in message["content"] if isinstance(block, dict) and "toolUse" in block)

    # Validate against a snapshot so `tool_uses` can be reordered while we iterate.
    for tool in list(tool_uses):
        try:
            validate_tool_use(tool)
        except InvalidToolUseNameException as e:
            # Move the invalid toolUse to the end of the list under a placeholder name and
            # surface the validation error back to the LLM as a ToolResult.
            tool_uses.remove(tool)
            tool["name"] = "INVALID_TOOL_NAME"
            invalid_tool_use_ids.append(tool["toolUseId"])
            tool_uses.append(tool)
            tool_results.append(
                {
                    "toolUseId": tool["toolUseId"],
                    "status": "error",
                    "content": [{"text": f"Error: {str(e)}"}],
                }
            )
176 |
--------------------------------------------------------------------------------
/src/strands/tools/mcp/__init__.py:
--------------------------------------------------------------------------------
1 | """Model Context Protocol (MCP) integration.
2 |
3 | This package provides integration with the Model Context Protocol (MCP), allowing agents to use tools provided by MCP
4 | servers.
5 |
6 | - Docs: https://www.anthropic.com/news/model-context-protocol
7 | """
8 |
9 | from .mcp_agent_tool import MCPAgentTool
10 | from .mcp_client import MCPClient
11 | from .mcp_types import MCPTransport
12 |
13 | __all__ = ["MCPAgentTool", "MCPClient", "MCPTransport"]
14 |
--------------------------------------------------------------------------------
/src/strands/tools/mcp/mcp_agent_tool.py:
--------------------------------------------------------------------------------
1 | """MCP Agent Tool module for adapting Model Context Protocol tools to the agent framework.
2 |
3 | This module provides the MCPAgentTool class which serves as an adapter between
4 | MCP (Model Context Protocol) tools and the agent framework's tool interface.
5 | It allows MCP tools to be seamlessly integrated and used within the agent ecosystem.
6 | """
7 |
8 | import logging
9 | from typing import TYPE_CHECKING, Any
10 |
11 | from mcp.types import Tool as MCPTool
12 |
13 | from ...types.tools import AgentTool, ToolResult, ToolSpec, ToolUse
14 |
15 | if TYPE_CHECKING:
16 | from .mcp_client import MCPClient
17 |
18 | logger = logging.getLogger(__name__)
19 |
20 |
class MCPAgentTool(AgentTool):
    """Adapter class that wraps an MCP tool and exposes it as an AgentTool.

    This class bridges the gap between the MCP protocol's tool representation
    and the agent framework's tool interface, allowing MCP tools to be used
    seamlessly within the agent framework.
    """

    def __init__(self, mcp_tool: MCPTool, mcp_client: "MCPClient") -> None:
        """Initialize a new MCPAgentTool instance.

        Args:
            mcp_tool: The MCP tool to adapt
            mcp_client: The MCP server connection to use for tool invocation
        """
        super().__init__()
        logger.debug("tool_name=<%s> | creating mcp agent tool", mcp_tool.name)
        self.mcp_tool = mcp_tool
        self.mcp_client = mcp_client

    @property
    def tool_name(self) -> str:
        """Get the name of the tool.

        Returns:
            str: The name of the MCP tool
        """
        return self.mcp_tool.name

    @property
    def tool_spec(self) -> ToolSpec:
        """Get the specification of the tool.

        This method converts the MCP tool specification to the agent framework's
        ToolSpec format, including the input schema and description.

        Returns:
            ToolSpec: The tool specification in the agent framework format
        """
        # MCP servers are not required to provide a description, so fall back to a generic one.
        description: str = self.mcp_tool.description or f"Tool which performs {self.mcp_tool.name}"
        return {
            "inputSchema": {"json": self.mcp_tool.inputSchema},
            "name": self.mcp_tool.name,
            "description": description,
        }

    @property
    def tool_type(self) -> str:
        """Get the type of the tool.

        Returns:
            str: The type of the tool, always "python" for MCP tools
        """
        return "python"

    def invoke(self, tool: ToolUse, *args: Any, **kwargs: Any) -> ToolResult:
        """Invoke the MCP tool.

        This method delegates the tool invocation to the MCP server connection,
        passing the tool use ID, tool name, and input arguments.

        Args:
            tool: The tool use request containing the tool use ID and input arguments.
            *args: Additional positional arguments (unused).
            **kwargs: Additional keyword arguments (unused).

        Returns:
            ToolResult: The result of the tool invocation from the MCP server.
        """
        # Log in the project-wide "key=<value> | message" format for consistency with the
        # rest of the SDK (e.g. the debug log in __init__).
        logger.debug("tool_name=<%s>, tool_use_id=<%s> | invoking", self.tool_name, tool["toolUseId"])
        return self.mcp_client.call_tool_sync(
            tool_use_id=tool["toolUseId"], name=self.tool_name, arguments=tool["input"]
        )
86 |
--------------------------------------------------------------------------------
/src/strands/tools/mcp/mcp_types.py:
--------------------------------------------------------------------------------
1 | """Type definitions for MCP integration."""
2 |
3 | from contextlib import AbstractAsyncContextManager
4 |
5 | from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
6 | from mcp.client.streamable_http import GetSessionIdCallback
7 | from mcp.shared.memory import MessageStream
8 | from mcp.shared.message import SessionMessage
9 |
10 | """
11 | MCPTransport defines the interface for MCP transport implementations. This abstracts
12 | communication with an MCP server, hiding details of the underlying transport mechanism (WebSocket, stdio, etc.).
13 |
14 | It represents an async context manager that yields a tuple of read and write streams for MCP communication.
15 | When used with `async with`, it should establish the connection and yield the streams, then clean up
16 | when the context is exited.
17 |
18 | The read stream receives messages from the client (or exceptions if parsing fails), while the write
19 | stream sends messages to the client.
20 |
21 | Example implementation (simplified):
22 | ```python
23 | @contextlib.asynccontextmanager
24 | async def my_transport_implementation():
25 | # Set up connection
26 | read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
27 | write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
28 |
29 | # Start background tasks to handle actual I/O
30 | async with anyio.create_task_group() as tg:
31 | tg.start_soon(reader_task, read_stream_writer)
32 | tg.start_soon(writer_task, write_stream_reader)
33 |
34 | # Yield the streams to the caller
35 | yield (read_stream, write_stream)
36 | ```
37 | """
# GetSessionIdCallback was added for HTTP Streaming but was not applied to the MessageStream type
# https://github.com/modelcontextprotocol/python-sdk/blob/ed25167fa5d715733437996682e20c24470e8177/src/mcp/client/streamable_http.py#L418
_MessageStreamWithGetSessionIdCallback = tuple[
    MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage], GetSessionIdCallback
]
# Async context manager that yields the (read, write) stream pair -- optionally with a third
# GetSessionIdCallback element, as produced by the streamable-HTTP transport.
MCPTransport = AbstractAsyncContextManager[MessageStream | _MessageStreamWithGetSessionIdCallback]
44 |
--------------------------------------------------------------------------------
/src/strands/tools/thread_pool_executor.py:
--------------------------------------------------------------------------------
1 | """Thread pool execution management for parallel tool calls."""
2 |
3 | import concurrent.futures
4 | from concurrent.futures import ThreadPoolExecutor
5 | from typing import Any, Callable, Iterable, Iterator, Optional
6 |
7 | from ..types.event_loop import Future, ParallelToolExecutorInterface
8 |
9 |
class ThreadPoolExecutorWrapper(ParallelToolExecutorInterface):
    """Adapter exposing a standard ThreadPoolExecutor through ParallelToolExecutorInterface.

    The agent event loop runs tools in parallel via the
    strands.types.event_loop.ParallelToolExecutorInterface protocol. This wrapper satisfies that
    protocol by delegating each operation (submit, as_completed, shutdown) straight to the wrapped
    concurrent.futures.ThreadPoolExecutor.

    Attributes:
        thread_pool: The wrapped ThreadPoolExecutor instance.
    """

    def __init__(self, thread_pool: ThreadPoolExecutor):
        """Store the ThreadPoolExecutor that backs this wrapper.

        Args:
            thread_pool: The ThreadPoolExecutor to delegate to.
        """
        self.thread_pool = thread_pool

    def submit(self, fn: Callable[..., Any], /, *args: Any, **kwargs: Any) -> Future:
        """Schedule fn(*args, **kwargs) for execution on the thread pool.

        Args:
            fn: The callable to execute.
            *args: Positional arguments for the callable.
            **kwargs: Keyword arguments for the callable.

        Returns:
            A Future tracking the execution of the callable.
        """
        future = self.thread_pool.submit(fn, *args, **kwargs)
        return future

    def as_completed(self, futures: Iterable[Future], timeout: Optional[int] = None) -> Iterator[Future]:
        """Yield the given futures one by one as each finishes or is cancelled.

        Args:
            futures: The futures to wait on.
            timeout: Maximum number of seconds to wait; None means wait indefinitely.

        Returns:
            An iterator producing the futures in completion order.

        Raises:
            concurrent.futures.TimeoutError: If the timeout elapses before all futures complete.
        """
        # The protocol's Future is structurally compatible with concurrent.futures.Future.
        return concurrent.futures.as_completed(futures, timeout=timeout)  # type: ignore

    def shutdown(self, wait: bool = True) -> None:
        """Shut down the underlying thread pool.

        Args:
            wait: If True, block until all running futures have finished executing.
        """
        self.thread_pool.shutdown(wait=wait)
70 |
--------------------------------------------------------------------------------
/src/strands/tools/watcher.py:
--------------------------------------------------------------------------------
1 | """Tool watcher for hot reloading tools during development.
2 |
3 | This module provides functionality to watch tool directories for changes and automatically reload tools when they are
4 | modified.
5 | """
6 |
7 | import logging
8 | from pathlib import Path
9 | from typing import Any, Dict, Set
10 |
11 | from watchdog.events import FileSystemEventHandler
12 | from watchdog.observers import Observer
13 |
14 | from .registry import ToolRegistry
15 |
16 | logger = logging.getLogger(__name__)
17 |
18 |
class ToolWatcher:
    """Watches tool directories for changes and reloads tools when they are modified."""

    # This class uses class variables for the observer and handlers because watchdog allows only one Observer instance
    # per directory. Using class variables ensures that all ToolWatcher instances share a single Observer, with the
    # MasterChangeHandler routing file system events to the appropriate individual handlers for each registry. This
    # design pattern avoids conflicts when multiple tool registries are watching the same directories.

    _shared_observer = None  # single watchdog Observer shared by every ToolWatcher instance
    _watched_dirs: Set[str] = set()  # directories that already have a MasterChangeHandler scheduled
    _observer_started = False  # whether the shared observer thread has been started
    _registry_handlers: Dict[str, Dict[int, "ToolWatcher.ToolChangeHandler"]] = {}  # dir path -> {registry id -> handler}

    def __init__(self, tool_registry: ToolRegistry) -> None:
        """Initialize a tool watcher for the given tool registry.

        Watching begins immediately; no separate call to start() is required.

        Args:
            tool_registry: The tool registry to report changes.
        """
        self.tool_registry = tool_registry
        self.start()

    class ToolChangeHandler(FileSystemEventHandler):
        """Handler for tool file changes."""

        def __init__(self, tool_registry: ToolRegistry) -> None:
            """Initialize a tool change handler.

            Args:
                tool_registry: The tool registry to update when tools change.
            """
            self.tool_registry = tool_registry

        def on_modified(self, event: Any) -> None:
            """Reload tool if file modification detected.

            Args:
                event: The file system event that triggered this handler.
            """
            if event.src_path.endswith(".py"):
                tool_path = Path(event.src_path)
                tool_name = tool_path.stem

                # Package markers such as __init__.py are not tools and are never reloaded.
                if tool_name not in ["__init__"]:
                    logger.debug("tool_name=<%s> | tool change detected", tool_name)
                    try:
                        self.tool_registry.reload_tool(tool_name)
                    except Exception as e:
                        # A failed reload is logged but must not kill the watchdog observer thread.
                        logger.error("tool_name=<%s>, exception=<%s> | failed to reload tool", tool_name, str(e))

    class MasterChangeHandler(FileSystemEventHandler):
        """Master handler that delegates to all registered handlers."""

        def __init__(self, dir_path: str) -> None:
            """Initialize a master change handler for a specific directory.

            Args:
                dir_path: The directory path to watch.
            """
            self.dir_path = dir_path

        def on_modified(self, event: Any) -> None:
            """Delegate file modification events to all registered handlers.

            Args:
                event: The file system event that triggered this handler.
            """
            if event.src_path.endswith(".py"):
                tool_path = Path(event.src_path)
                tool_name = tool_path.stem

                if tool_name not in ["__init__"]:
                    # Delegate to all registered handlers for this directory
                    for handler in ToolWatcher._registry_handlers.get(self.dir_path, {}).values():
                        try:
                            handler.on_modified(event)
                        except Exception as e:
                            # One registry's failure must not prevent delivery to the others.
                            logger.error("exception=<%s> | handler error", str(e))

    def start(self) -> None:
        """Start watching all tools directories for changes."""
        # Initialize shared observer if not already done
        if ToolWatcher._shared_observer is None:
            ToolWatcher._shared_observer = Observer()

        # Create handler for this instance
        self.tool_change_handler = self.ToolChangeHandler(self.tool_registry)
        registry_id = id(self.tool_registry)

        # Get tools directories to watch
        tools_dirs = self.tool_registry.get_tools_dirs()

        for tools_dir in tools_dirs:
            dir_str = str(tools_dir)

            # Initialize the registry handlers dict for this directory if needed
            if dir_str not in ToolWatcher._registry_handlers:
                ToolWatcher._registry_handlers[dir_str] = {}

            # Store this handler with its registry id
            ToolWatcher._registry_handlers[dir_str][registry_id] = self.tool_change_handler

            # Schedule or update the master handler for this directory
            if dir_str not in ToolWatcher._watched_dirs:
                # First time seeing this directory, create a master handler
                master_handler = self.MasterChangeHandler(dir_str)
                ToolWatcher._shared_observer.schedule(master_handler, dir_str, recursive=False)
                ToolWatcher._watched_dirs.add(dir_str)
                logger.debug("tools_dir=<%s> | started watching tools directory", tools_dir)
            else:
                # Directory already being watched, just log it
                logger.debug("tools_dir=<%s> | directory already being watched", tools_dir)

        # Start the observer if not already started
        if not ToolWatcher._observer_started:
            ToolWatcher._shared_observer.start()
            ToolWatcher._observer_started = True
            logger.debug("tool directory watching initialized")
137 |
--------------------------------------------------------------------------------
/src/strands/types/__init__.py:
--------------------------------------------------------------------------------
1 | """SDK type definitions."""
2 |
--------------------------------------------------------------------------------
/src/strands/types/content.py:
--------------------------------------------------------------------------------
1 | """Content-related type definitions for the SDK.
2 |
3 | This module defines the types used to represent messages, content blocks, and other content-related structures in the
4 | SDK. These types are modeled after the Bedrock API.
5 |
6 | - Bedrock docs: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_Types_Amazon_Bedrock_Runtime.html
7 | """
8 |
9 | from typing import Dict, List, Literal, Optional
10 |
11 | from typing_extensions import TypedDict
12 |
13 | from .media import DocumentContent, ImageContent, VideoContent
14 | from .tools import ToolResult, ToolUse
15 |
16 |
class GuardContentText(TypedDict):
    """Text content to be evaluated by guardrails.

    Modeled after the Bedrock guardrail content block types.

    Attributes:
        qualifiers: The qualifiers describing how the guardrail should treat the text block.
        text: The input text details to be evaluated by the guardrail.
    """

    qualifiers: List[Literal["grounding_source", "query", "guard_content"]]
    text: str


class GuardContent(TypedDict):
    """Content block to be evaluated by guardrails.

    Attributes:
        text: Text within content block to be evaluated by the guardrail.
    """

    text: GuardContentText
37 |
38 |
class ReasoningTextBlock(TypedDict, total=False):
    """Contains the reasoning that the model used to return the output.

    All keys are optional (total=False).

    Attributes:
        signature: A token that verifies that the reasoning text was generated by the model.
        text: The reasoning that the model used to return the output.
    """

    signature: Optional[str]
    text: str


class ReasoningContentBlock(TypedDict, total=False):
    """Contains content regarding the reasoning that is carried out by the model.

    All keys are optional (total=False).

    Attributes:
        reasoningText: The reasoning that the model used to return the output.
        redactedContent: The content in the reasoning that was encrypted by the model provider for safety reasons.
    """

    reasoningText: ReasoningTextBlock
    redactedContent: bytes
61 |
62 |
class CachePoint(TypedDict):
    """A cache point configuration for optimizing conversation history.

    Modeled after the Bedrock cachePoint content block.

    Attributes:
        type: The type of cache point, typically "default".
    """

    type: str
71 |
72 |
class ContentBlock(TypedDict, total=False):
    """A block of content for a message that you pass to, or receive from, a model.

    All keys are optional (total=False); which keys are present depends on the kind of content
    the block carries.

    Attributes:
        cachePoint: A cache point configuration to optimize conversation history.
        document: A document to include in the message.
        guardContent: Contains the content to assess with the guardrail.
        image: Image to include in the message.
        reasoningContent: Contains content regarding the reasoning that is carried out by the model.
        text: Text to include in the message.
        toolResult: The result for a tool request that a model makes.
        toolUse: Information about a tool use request from a model.
        video: Video to include in the message.
    """

    cachePoint: CachePoint
    document: DocumentContent
    guardContent: GuardContent
    image: ImageContent
    reasoningContent: ReasoningContentBlock
    text: str
    toolResult: ToolResult
    toolUse: ToolUse
    video: VideoContent
97 |
98 |
class SystemContentBlock(TypedDict, total=False):
    """Contains configurations for instructions to provide the model for how to handle input.

    All keys are optional (total=False).

    Attributes:
        guardContent: A content block to assess with the guardrail.
        text: A system prompt for the model.
    """

    guardContent: GuardContent
    text: str
109 |
110 |
class DeltaContent(TypedDict, total=False):
    """A block of content in a streaming response.

    Attributes:
        text: The content text.
        toolUse: Information about a tool that the model is requesting to use;
            the incremental tool input fragment is carried under the "input" key.
    """

    text: str
    toolUse: Dict[Literal["input"], str]
121 |
122 |
class ContentBlockStartToolUse(TypedDict):
    """The start of a tool use block.

    Attributes:
        name: The name of the tool that the model is requesting to use.
        toolUseId: The ID for the tool request.
    """

    name: str
    toolUseId: str


class ContentBlockStart(TypedDict, total=False):
    """Content block start information.

    Attributes:
        toolUse: Information about a tool that the model is requesting to use,
            or None when the block is not a tool use.
    """

    toolUse: Optional[ContentBlockStartToolUse]
143 |
144 |
class ContentBlockDelta(TypedDict):
    """The content block delta event.

    Attributes:
        contentBlockIndex: The block index for a content block delta event.
        delta: The delta for a content block delta event.
    """

    contentBlockIndex: int
    delta: DeltaContent


class ContentBlockStop(TypedDict):
    """A content block stop event.

    Attributes:
        contentBlockIndex: The index for a content block.
    """

    contentBlockIndex: int
165 |
166 |
Role = Literal["user", "assistant"]
"""Role of a message sender.

- "user": Messages from the user to the assistant
- "assistant": Messages from the assistant to the user
"""


class Message(TypedDict):
    """A message in a conversation with the agent.

    Attributes:
        content: The message content, as an ordered list of content blocks.
        role: The role of the message sender.
    """

    content: List[ContentBlock]
    role: Role


Messages = List[Message]
"""A list of messages representing a conversation."""
189 |
--------------------------------------------------------------------------------
/src/strands/types/event_loop.py:
--------------------------------------------------------------------------------
1 | """Event loop-related type definitions for the SDK."""
2 |
3 | from typing import Any, Callable, Iterable, Iterator, Literal, Optional, Protocol
4 |
5 | from typing_extensions import TypedDict, runtime_checkable
6 |
7 |
class Usage(TypedDict):
    """Token usage information for model interactions.

    Attributes:
        inputTokens: Number of tokens sent in the request to the model.
        outputTokens: Number of tokens that the model generated for the request.
        totalTokens: Total number of tokens (input + output).
    """

    inputTokens: int
    outputTokens: int
    totalTokens: int
20 |
21 |
class Metrics(TypedDict):
    """Performance metrics for model interactions.

    Attributes:
        latencyMs: Latency of the model request in milliseconds.
    """

    latencyMs: int


StopReason = Literal[
    "content_filtered",
    "end_turn",
    "guardrail_intervened",
    "max_tokens",
    "stop_sequence",
    "tool_use",
]
"""Reason for the model ending its response generation.

- "content_filtered": Content was filtered due to policy violation
- "end_turn": Normal completion of the response
- "guardrail_intervened": Guardrail system intervened
- "max_tokens": Maximum token limit reached
- "stop_sequence": Stop sequence encountered
- "tool_use": Model requested to use a tool
"""
49 |
50 |
@runtime_checkable
class Future(Protocol):
    """Interface representing the result of an asynchronous computation.

    Structural protocol: any object with a compatible result() method satisfies it
    (e.g. concurrent.futures.Future).
    """

    def result(self, timeout: Optional[int] = None) -> Any:
        """Return the result of the call that the future represents.

        This method will block until the asynchronous operation completes or until the specified timeout is reached.

        Args:
            timeout: The number of seconds to wait for the result.
                If None, then there is no limit on the wait time.

        Returns:
            Any: The result of the asynchronous operation.
        """
67 |
68 |
@runtime_checkable
class ParallelToolExecutorInterface(Protocol):
    """Interface for parallel tool execution.

    Attributes:
        timeout: Default timeout in seconds for futures.
    """

    timeout: int = 900  # default 15 minute timeout for futures

    def submit(self, fn: Callable[..., Any], /, *args: Any, **kwargs: Any) -> Future:
        """Submit a callable to be executed with the given arguments.

        Schedules the callable to be executed as fn(*args, **kwargs) and returns a Future instance representing the
        execution of the callable.

        Args:
            fn: The callable to execute.
            *args: Positional arguments to pass to the callable.
            **kwargs: Keyword arguments to pass to the callable.

        Returns:
            Future: A Future representing the given call.
        """

    # NOTE: the default below is evaluated while the class body executes, so it binds the
    # class-level ``timeout`` value (900) defined above, not a per-instance attribute.
    def as_completed(self, futures: Iterable[Future], timeout: Optional[int] = timeout) -> Iterator[Future]:
        """Iterate over the given futures, yielding each as it completes.

        Args:
            futures: The sequence of Futures to iterate over.
            timeout: The maximum number of seconds to wait.
                If None, then there is no limit on the wait time.

        Returns:
            An iterator that yields the given Futures as they complete (finished or cancelled).
        """

    def shutdown(self, wait: bool = True) -> None:
        """Shutdown the executor and free associated resources.

        Args:
            wait: If True, shutdown will not return until all running futures have finished executing.
        """
112 |
--------------------------------------------------------------------------------
/src/strands/types/exceptions.py:
--------------------------------------------------------------------------------
1 | """Exception-related type definitions for the SDK."""
2 |
3 | from typing import Any
4 |
5 |
class EventLoopException(Exception):
    """Exception raised by the event loop."""

    def __init__(self, original_exception: Exception, request_state: Any = None) -> None:
        """Wrap the underlying failure together with the request state at failure time.

        Args:
            original_exception: The original exception that was raised.
            request_state: The state of the request at the time of the exception.
        """
        self.original_exception = original_exception
        # Default to an empty dict rather than None so callers can treat the state uniformly.
        self.request_state = {} if request_state is None else request_state
        super().__init__(str(original_exception))
19 |
20 |
class ContextWindowOverflowException(Exception):
    """Exception raised when the context window is exceeded.

    Raised when the input to a model is larger than the maximum context window the model can handle.
    This typically happens when the combined length of the conversation history, system prompt, and
    current message is too large for the model to process.
    """
30 |
31 |
class MCPClientInitializationError(Exception):
    """Raised when the MCP server fails to initialize properly."""
36 |
37 |
class ModelThrottledException(Exception):
    """Exception raised when the model is throttled.

    This exception is raised when the model is throttled by the service. This typically occurs when the service is
    throttling the requests from the client.

    Attributes:
        message: The throttling message reported by the service.
    """

    def __init__(self, message: str) -> None:
        """Initialize exception.

        Args:
            message: The message from the service that describes the throttling.
        """
        self.message = message
        super().__init__(message)
55 |
--------------------------------------------------------------------------------
/src/strands/types/media.py:
--------------------------------------------------------------------------------
1 | """Media-related type definitions for the SDK.
2 |
3 | These types are modeled after the Bedrock API.
4 |
5 | - Bedrock docs: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_Types_Amazon_Bedrock_Runtime.html
6 | """
7 |
8 | from typing import Literal
9 |
10 | from typing_extensions import TypedDict
11 |
DocumentFormat = Literal["pdf", "csv", "doc", "docx", "xls", "xlsx", "html", "txt", "md"]
"""Supported document formats."""


class DocumentSource(TypedDict):
    """Contains the content of a document.

    Attributes:
        bytes: The binary content of the document.
    """

    bytes: bytes


class DocumentContent(TypedDict):
    """A document to include in a message.

    Attributes:
        format: The format of the document (e.g., "pdf", "txt").
        name: The name of the document.
        source: The source containing the document's binary content.
    """

    # Reference the DocumentFormat alias so the accepted formats are defined in exactly one
    # place, mirroring how ImageContent/VideoContent reference ImageFormat/VideoFormat.
    format: DocumentFormat
    name: str
    source: DocumentSource
38 |
39 |
ImageFormat = Literal["png", "jpeg", "gif", "webp"]
"""Supported image formats."""


class ImageSource(TypedDict):
    """Contains the content of an image.

    Attributes:
        bytes: The binary content of the image.
    """

    bytes: bytes


class ImageContent(TypedDict):
    """An image to include in a message.

    Attributes:
        format: The format of the image (e.g., "png", "jpeg").
        source: The source containing the image's binary content.
    """

    format: ImageFormat
    source: ImageSource
64 |
65 |
VideoFormat = Literal["flv", "mkv", "mov", "mpeg", "mpg", "mp4", "three_gp", "webm", "wmv"]
"""Supported video formats."""


class VideoSource(TypedDict):
    """Contains the content of a video.

    Attributes:
        bytes: The binary content of the video.
    """

    bytes: bytes


class VideoContent(TypedDict):
    """A video to include in a message.

    Attributes:
        format: The format of the video (e.g., "mp4", "mov").
        source: The source containing the video's binary content.
    """

    format: VideoFormat
    source: VideoSource
90 |
--------------------------------------------------------------------------------
/src/strands/types/models/__init__.py:
--------------------------------------------------------------------------------
1 | """Model-related type definitions for the SDK."""
2 |
3 | from .model import Model
4 | from .openai import OpenAIModel
5 |
6 | __all__ = ["Model", "OpenAIModel"]
7 |
--------------------------------------------------------------------------------
/src/strands/types/models/model.py:
--------------------------------------------------------------------------------
1 | """Model-related type definitions for the SDK."""
2 |
3 | import abc
4 | import logging
5 | from typing import Any, Iterable, Optional
6 |
7 | from ..content import Messages
8 | from ..streaming import StreamEvent
9 | from ..tools import ToolSpec
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 |
class Model(abc.ABC):
    """Abstract base class for AI model implementations.

    This class defines the interface for all model implementations in the Strands Agents SDK. It provides a
    standardized way to configure, format, and process requests for different AI model providers.
    """

    # NOTE: "# pragma: no cover" must sit on the "def" line itself; on a standalone comment line
    # between the decorator and the def it excludes nothing, because comment lines are not executable.

    @abc.abstractmethod
    def update_config(self, **model_config: Any) -> None:  # pragma: no cover
        """Update the model configuration with the provided arguments.

        Args:
            **model_config: Configuration overrides.
        """

    @abc.abstractmethod
    def get_config(self) -> Any:  # pragma: no cover
        """Return the model configuration.

        Returns:
            The model's configuration.
        """

    @abc.abstractmethod
    def format_request(
        self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None
    ) -> Any:  # pragma: no cover
        """Format a streaming request to the underlying model.

        Args:
            messages: List of message objects to be processed by the model.
            tool_specs: List of tool specifications to make available to the model.
            system_prompt: System prompt to provide context to the model.

        Returns:
            The formatted request.
        """

    @abc.abstractmethod
    def format_chunk(self, event: Any) -> StreamEvent:  # pragma: no cover
        """Format the model response events into standardized message chunks.

        Args:
            event: A response event from the model.

        Returns:
            The formatted chunk.
        """

    @abc.abstractmethod
    def stream(self, request: Any) -> Iterable[Any]:  # pragma: no cover
        """Send the request to the model and get a streaming response.

        Args:
            request: The formatted request to send to the model.

        Returns:
            The model's response.

        Raises:
            ModelThrottledException: When the model service is throttling requests from the client.
        """

    def converse(
        self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None
    ) -> Iterable[StreamEvent]:
        """Converse with the model.

        This method handles the full lifecycle of conversing with the model:
        1. Format the messages, tool specs, and configuration into a streaming request
        2. Send the request to the model
        3. Yield the formatted message chunks

        Args:
            messages: List of message objects to be processed by the model.
            tool_specs: List of tool specifications to make available to the model.
            system_prompt: System prompt to provide context to the model.

        Yields:
            Formatted message chunks from the model.

        Raises:
            ModelThrottledException: When the model service is throttling requests from the client.
        """
        logger.debug("formatting request")
        request = self.format_request(messages, tool_specs, system_prompt)

        logger.debug("invoking model")
        response = self.stream(request)

        logger.debug("got response from model")
        for event in response:
            yield self.format_chunk(event)

        logger.debug("finished streaming response from model")
119 |
--------------------------------------------------------------------------------
/src/strands/types/streaming.py:
--------------------------------------------------------------------------------
1 | """Streaming-related type definitions for the SDK.
2 |
3 | These types are modeled after the Bedrock API.
4 |
5 | - Bedrock docs: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_Types_Amazon_Bedrock_Runtime.html
6 | """
7 |
8 | from typing import Optional, Union
9 |
10 | from typing_extensions import TypedDict
11 |
12 | from .content import ContentBlockStart, Role
13 | from .event_loop import Metrics, StopReason, Usage
14 | from .guardrails import Trace
15 |
16 |
class MessageStartEvent(TypedDict):
    """Event signaling the start of a message in a streaming response.

    Attributes:
        role: The role of the message sender (e.g., "assistant", "user").
    """

    role: Role


class ContentBlockStartEvent(TypedDict, total=False):
    """Event signaling the start of a content block in a streaming response.

    Attributes:
        contentBlockIndex: Index of the content block within the message.
            This is optional to accommodate different model providers.
        start: Information about the content block being started.
    """

    contentBlockIndex: Optional[int]
    start: ContentBlockStart
38 |
39 |
class ContentBlockDeltaText(TypedDict):
    """Text content delta in a streaming response.

    Attributes:
        text: The text fragment being streamed.
    """

    text: str


class ContentBlockDeltaToolUse(TypedDict):
    """Tool use input delta in a streaming response.

    Attributes:
        input: The tool input fragment being streamed.
    """

    input: str
58 |
59 |
class ReasoningContentBlockDelta(TypedDict, total=False):
    """Delta for reasoning content block in a streaming response.

    All keys are optional (total=False).

    Attributes:
        redactedContent: The content in the reasoning that was encrypted by the model provider for safety reasons.
        signature: A token that verifies that the reasoning text was generated by the model.
        text: The reasoning that the model used to return the output.
    """

    redactedContent: Optional[bytes]
    signature: Optional[str]
    text: Optional[str]


class ContentBlockDelta(TypedDict, total=False):
    """A block of content in a streaming response.

    Attributes:
        reasoningContent: Contains content regarding the reasoning that is carried out by the model.
        text: Text fragment being streamed.
        toolUse: Tool use input fragment being streamed.
    """

    reasoningContent: ReasoningContentBlockDelta
    text: str
    toolUse: ContentBlockDeltaToolUse
86 |
87 |
class ContentBlockDeltaEvent(TypedDict, total=False):
    """Event containing a delta update for a content block in a streaming response.

    Attributes:
        contentBlockIndex: Index of the content block within the message.
            This is optional to accommodate different model providers.
        delta: The incremental content update for the content block.
    """

    contentBlockIndex: Optional[int]
    delta: ContentBlockDelta


class ContentBlockStopEvent(TypedDict, total=False):
    """Event signaling the end of a content block in a streaming response.

    Attributes:
        contentBlockIndex: Index of the content block within the message.
            This is optional to accommodate different model providers.
    """

    contentBlockIndex: Optional[int]
110 |
111 |
class MessageStopEvent(TypedDict, total=False):
    """Event signaling the end of a message in a streaming response.

    Attributes:
        additionalModelResponseFields: Additional fields to include in model response.
        stopReason: The reason why the model stopped generating content.
    """

    additionalModelResponseFields: Optional[Union[dict, list, int, float, str, bool, None]]
    stopReason: StopReason


class MetadataEvent(TypedDict, total=False):
    """Event containing metadata about the streaming response.

    Attributes:
        metrics: Performance metrics related to the model invocation.
        trace: Trace information for debugging and monitoring.
        usage: Resource usage information for the model invocation.
    """

    metrics: Metrics
    trace: Optional[Trace]
    usage: Usage
136 |
137 |
class ExceptionEvent(TypedDict):
    """Base event for exceptions in a streaming response.

    Every exception event carries at least a human-readable error message;
    subclasses (e.g., ModelStreamErrorEvent) add provider-specific details.

    Attributes:
        message: The error message describing what went wrong.
    """

    message: str
146 |
147 |
class ModelStreamErrorEvent(ExceptionEvent):
    """Event for model streaming errors.

    Extends ExceptionEvent with the model provider's original error details.

    Attributes:
        originalMessage: The original error message from the model provider.
        originalStatusCode: The HTTP status code returned by the model provider.
    """

    originalMessage: str
    originalStatusCode: int
158 |
159 |
class RedactContentEvent(TypedDict, total=False):
    """Event for redacting content.

    Attributes:
        redactUserContentMessage: The string to overwrite the user's input with.
        redactAssistantContentMessage: The string to overwrite the assistant's output with.
    """

    redactUserContentMessage: Optional[str]
    redactAssistantContentMessage: Optional[str]
171 |
172 |
class StreamEvent(TypedDict, total=False):
    """The messages output stream.

    Each streamed event populates exactly the key matching its event type;
    all keys are optional (``total=False``).

    Attributes:
        contentBlockDelta: Delta content for a content block.
        contentBlockStart: Start of a content block.
        contentBlockStop: End of a content block.
        internalServerException: Internal server error information.
        messageStart: Start of a message.
        messageStop: End of a message.
        metadata: Metadata about the streaming response.
        modelStreamErrorException: Model streaming error information.
        redactContent: Instructions for overwriting user/assistant content.
        serviceUnavailableException: Service unavailable error information.
        throttlingException: Throttling error information.
        validationException: Validation error information.
    """

    contentBlockDelta: ContentBlockDeltaEvent
    contentBlockStart: ContentBlockStartEvent
    contentBlockStop: ContentBlockStopEvent
    internalServerException: ExceptionEvent
    messageStart: MessageStartEvent
    messageStop: MessageStopEvent
    metadata: MetadataEvent
    redactContent: RedactContentEvent
    modelStreamErrorException: ModelStreamErrorEvent
    serviceUnavailableException: ExceptionEvent
    throttlingException: ExceptionEvent
    validationException: ExceptionEvent
202 |
--------------------------------------------------------------------------------
/src/strands/types/traces.py:
--------------------------------------------------------------------------------
1 | """Tracing type definitions for the SDK."""
2 |
3 | from typing import List, Union
4 |
5 | AttributeValue = Union[str, bool, float, int, List[str], List[bool], List[float], List[int]]
6 |
--------------------------------------------------------------------------------
/tests-integ/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/strands-agents/sdk-python/a64e80dfa5f5f6053a3fe53767f59fc3c5c1af95/tests-integ/__init__.py
--------------------------------------------------------------------------------
/tests-integ/echo_server.py:
--------------------------------------------------------------------------------
1 | """
2 | Echo Server for MCP Integration Testing
3 |
4 | This module implements a simple echo server using the Model Context Protocol (MCP).
5 | It provides a basic tool that echoes back any input string, which is useful for
6 | testing the MCP communication flow and validating that messages are properly
7 | transmitted between the client and server.
8 |
9 | The server runs with stdio transport, making it suitable for integration tests
10 | where the client can spawn this process and communicate with it through standard
11 | input/output streams.
12 |
13 | Usage:
14 | Run this file directly to start the echo server:
15 | $ python echo_server.py
16 | """
17 |
18 | from mcp.server import FastMCP
19 |
20 |
def start_echo_server():
    """Build and run the MCP echo server over stdio.

    Registers a single 'echo' tool on a FastMCP instance; the tool returns
    its input string unchanged, which lets integration tests validate the
    full round trip through the MCP transport.
    """
    server = FastMCP("Echo Server")

    @server.tool(description="Echos response back to the user")
    def echo(to_echo: str) -> str:
        return to_echo

    server.run(transport="stdio")
36 |
37 |
38 | if __name__ == "__main__":
39 | start_echo_server()
40 |
--------------------------------------------------------------------------------
/tests-integ/test_bedrock_cache_point.py:
--------------------------------------------------------------------------------
1 | from strands import Agent
2 | from strands.types.content import Messages
3 |
4 |
def test_bedrock_cache_point():
    """Seed a cacheable conversation and assert cache token usage is reported."""
    messages: Messages = [
        {
            "role": "user",
            "content": [
                # cachePoint requires at least 1024 tokens of preceding content.
                {"text": "Some really long text!" * 1000},
                {"cachePoint": {"type": "default"}},
            ],
        },
        {"role": "assistant", "content": [{"text": "Blue!"}]},
    ]

    cache_point_usage = 0

    def cache_point_callback_handler(**kwargs):
        nonlocal cache_point_usage
        event = kwargs.get("event") or {}
        metadata = event.get("metadata") or {}
        usage = metadata.get("usage") or {}
        if "cacheReadInputTokens" in usage or "cacheWriteInputTokens" in usage:
            cache_point_usage += 1

    agent = Agent(messages=messages, callback_handler=cache_point_callback_handler, load_tools_from_directory=False)
    agent("What is favorite color?")
    assert cache_point_usage > 0
32 |
--------------------------------------------------------------------------------
/tests-integ/test_bedrock_guardrails.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import boto3
4 | import pytest
5 |
6 | from strands import Agent
7 | from strands.models.bedrock import BedrockModel
8 |
9 | BLOCKED_INPUT = "BLOCKED_INPUT"
10 | BLOCKED_OUTPUT = "BLOCKED_OUTPUT"
11 |
12 |
@pytest.fixture(scope="module")
def boto_session():
    """Module-scoped boto3 session pinned to us-west-2, shared by all tests here."""
    return boto3.Session(region_name="us-west-2")
16 |
17 |
@pytest.fixture(scope="module")
def bedrock_guardrail(boto_session):
    """Return the ID of the test guardrail, creating it if it does not exist.

    The guardrail blocks the word CACTUS on both input and output. A newly
    created guardrail is polled until it becomes active before being returned.
    """
    client = boto_session.client("bedrock")

    guardrail_name = "test-guardrail-block-cactus"
    guardrail_id = get_guardrail_id(client, guardrail_name)

    # Reuse an existing guardrail to keep repeated test runs idempotent.
    if guardrail_id:
        print(f"Guardrail {guardrail_name} already exists with ID: {guardrail_id}")
        return guardrail_id

    print(f"Creating guardrail {guardrail_name}")
    response = client.create_guardrail(
        name=guardrail_name,
        description="Testing Guardrail",
        wordPolicyConfig={
            "wordsConfig": [
                {
                    "text": "CACTUS",
                    "inputAction": "BLOCK",
                    "outputAction": "BLOCK",
                    "inputEnabled": True,
                    "outputEnabled": True,
                },
            ],
        },
        blockedInputMessaging=BLOCKED_INPUT,
        blockedOutputsMessaging=BLOCKED_OUTPUT,
    )
    guardrail_id = response.get("guardrailId")
    print(f"Created test guardrail with ID: {guardrail_id}")
    wait_for_guardrail_active(client, guardrail_id)
    return guardrail_id
54 |
55 |
def get_guardrail_id(client, guardrail_name):
    """Resolve a guardrail name to its ID.

    Args:
        client: The Bedrock client instance
        guardrail_name: Name of the guardrail to look up

    Returns:
        str: The ID of the guardrail if found, None otherwise
    """
    listed = client.list_guardrails().get("guardrails", [])
    return next((entry["id"] for entry in listed if entry["name"] == guardrail_name), None)
72 |
73 |
def wait_for_guardrail_active(bedrock_client, guardrail_id, max_attempts=10, delay=5):
    """Poll the guardrail status until it reports READY.

    Returns True once active. Sleeps `delay` seconds between polls and raises
    RuntimeError if READY is not observed within `max_attempts` polls.
    """
    for _ in range(max_attempts):
        status = bedrock_client.get_guardrail(guardrailIdentifier=guardrail_id).get("status")

        if status == "READY":
            print(f"Guardrail {guardrail_id} is now active")
            return True

        print(f"Waiting for guardrail to become active. Current status: {status}")
        time.sleep(delay)

    print(f"Guardrail did not become active within {max_attempts * delay} seconds.")
    raise RuntimeError("Guardrail did not become active.")
91 |
92 |
def test_guardrail_input_intervention(boto_session, bedrock_guardrail):
    """Blocked input (CACTUS) should trigger the guardrail; normal input should not."""
    bedrock_model = BedrockModel(
        guardrail_id=bedrock_guardrail,
        guardrail_version="DRAFT",
        boto_session=boto_session,
    )

    # load_tools_from_directory=False keeps the agent hermetic, matching the
    # other guardrail tests in this module (it was previously omitted here).
    agent = Agent(
        model=bedrock_model,
        system_prompt="You are a helpful assistant.",
        callback_handler=None,
        load_tools_from_directory=False,
    )

    response1 = agent("CACTUS")
    response2 = agent("Hello!")

    assert response1.stop_reason == "guardrail_intervened"
    assert str(response1).strip() == BLOCKED_INPUT
    assert response2.stop_reason != "guardrail_intervened"
    assert str(response2).strip() != BLOCKED_INPUT
109 |
110 |
@pytest.mark.parametrize("processing_mode", ["sync", "async"])
def test_guardrail_output_intervention(boto_session, bedrock_guardrail, processing_mode):
    """A CACTUS reply should be blocked on output; a benign exchange should pass."""
    guarded_model = BedrockModel(
        guardrail_id=bedrock_guardrail,
        guardrail_version="DRAFT",
        guardrail_redact_output=False,
        guardrail_stream_processing_mode=processing_mode,
        boto_session=boto_session,
    )

    agent = Agent(
        model=guarded_model,
        system_prompt="When asked to say the word, say CACTUS.",
        callback_handler=None,
        load_tools_from_directory=False,
    )

    blocked = agent("Say the word.")
    allowed = agent("Hello!")

    assert blocked.stop_reason == "guardrail_intervened"
    assert BLOCKED_OUTPUT in str(blocked)
    assert allowed.stop_reason != "guardrail_intervened"
    assert BLOCKED_OUTPUT not in str(allowed)
134 |
135 |
@pytest.mark.parametrize("processing_mode", ["sync", "async"])
def test_guardrail_output_intervention_redact_output(bedrock_guardrail, processing_mode):
    """When redaction is on, blocked output is replaced with the redact message."""
    REDACT_MESSAGE = "Redacted."
    guarded_model = BedrockModel(
        guardrail_id=bedrock_guardrail,
        guardrail_version="DRAFT",
        guardrail_stream_processing_mode=processing_mode,
        guardrail_redact_output=True,
        guardrail_redact_output_message=REDACT_MESSAGE,
        region_name="us-west-2",
    )

    agent = Agent(
        model=guarded_model,
        system_prompt="When asked to say the word, say CACTUS.",
        callback_handler=None,
        load_tools_from_directory=False,
    )

    redacted = agent("Say the word.")
    untouched = agent("Hello!")

    assert redacted.stop_reason == "guardrail_intervened"
    assert REDACT_MESSAGE in str(redacted)
    assert untouched.stop_reason != "guardrail_intervened"
    assert REDACT_MESSAGE not in str(untouched)
161 |
--------------------------------------------------------------------------------
/tests-integ/test_context_overflow.py:
--------------------------------------------------------------------------------
1 | from strands import Agent
2 | from strands.types.content import Messages
3 |
4 |
def test_context_window_overflow():
    """An over-long conversation should be handled gracefully, not crash the agent."""
    oversized_text = "Too much text!" * 100000
    messages: Messages = [
        {"role": "user", "content": [{"text": oversized_text}]},
        {"role": "assistant", "content": [{"text": "That was a lot of text!"}]},
    ]

    agent = Agent(messages=messages, load_tools_from_directory=False)
    agent("Hi!")

    # After the turn, the managed history should hold exactly two messages.
    assert len(agent.messages) == 2
14 |
--------------------------------------------------------------------------------
/tests-integ/test_function_tools.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | Test script for function-based tools
4 | """
5 |
6 | import logging
7 | from typing import Optional
8 |
9 | from strands import Agent, tool
10 |
11 | logging.getLogger("strands").setLevel(logging.DEBUG)
12 | logging.basicConfig(format="%(levelname)s | %(name)s | %(message)s", handlers=[logging.StreamHandler()])
13 |
14 |
@tool
def word_counter(text: str) -> str:
    """
    Count words in text.

    Args:
        text: Text to analyze
    """
    # Docstring kept verbatim: the @tool decorator exposes it in the tool spec.
    return f"Word count: {len(text.split())}"
25 |
26 |
27 | @tool(name="count_chars", description="Count characters in text")
28 | def count_chars(text: str, include_spaces: Optional[bool] = True) -> str:
29 | """
30 | Count characters in text.
31 |
32 | Args:
33 | text: Text to analyze
34 | include_spaces: Whether to include spaces in the count
35 | """
36 | if not include_spaces:
37 | text = text.replace(" ", "")
38 | return f"Character count: {len(text)}"
39 |
40 |
# Initialize agent with the two function-based tools defined above.
agent = Agent(tools=[word_counter, count_chars])

print("\n===== Testing Direct Tool Access =====")
# Call the tools directly via agent.tool.<name>(...), bypassing the model.
word_result = agent.tool.word_counter(text="Hello world, this is a test")
print(f"\nWord counter result: {word_result}")

# Keyword arguments map onto the decorated function's parameters.
char_result = agent.tool.count_chars(text="Hello world!", include_spaces=False)
print(f"\nCharacter counter result: {char_result}")

print("\n===== Testing Natural Language Access =====")
# Ask in natural language; the model is expected to pick word_counter itself.
nl_result = agent("Count the words in this sentence: 'The quick brown fox jumps over the lazy dog'")
print(f"\nNL Result: {nl_result}")
56 |
--------------------------------------------------------------------------------
/tests-integ/test_hot_tool_reload_decorator.py:
--------------------------------------------------------------------------------
1 | """
2 | Integration test for hot tool reloading functionality with the @tool decorator.
3 |
4 | This test verifies that the Strands Agent can automatically detect and load
5 | new tools created with the @tool decorator when they are added to a tools directory.
6 | """
7 |
8 | import logging
9 | import os
10 | import time
11 | from pathlib import Path
12 |
13 | from strands import Agent
14 |
15 | logging.getLogger("strands").setLevel(logging.DEBUG)
16 | logging.basicConfig(format="%(levelname)s | %(name)s | %(message)s", handlers=[logging.StreamHandler()])
17 |
18 |
def test_hot_reload_decorator():
    """
    Test that the Agent automatically loads tools created with @tool decorator
    when added to the current working directory's tools folder.
    """
    # Set up the tools directory in current working directory
    # NOTE(review): depends on the runner's CWD — the agent watches ./tools.
    tools_dir = Path.cwd() / "tools"
    os.makedirs(tools_dir, exist_ok=True)

    # Tool path that will need cleanup
    test_tool_path = tools_dir / "uppercase.py"

    try:
        # Create an Agent instance without any tools
        agent = Agent()

        # Create a test tool using @tool decorator
        with open(test_tool_path, "w") as f:
            f.write("""
from strands import tool

@tool
def uppercase(text: str) -> str:
    \"\"\"Convert text to uppercase.\"\"\"
    return f"Input: {text}, Output: {text.upper()}"
""")

        # Wait for tool detection
        # NOTE(review): fixed 3s sleep assumes the file watcher fires in time; flaky otherwise.
        time.sleep(3)

        # Verify the tool was automatically loaded
        assert "uppercase" in agent.tool_names, "Agent should have detected and loaded the uppercase tool"

        # Test calling the dynamically loaded tool
        result = agent.tool.uppercase(text="hello world")

        # Check that the result is successful
        assert result.get("status") == "success", "Tool call should be successful"

        # Check the content of the response
        content_list = result.get("content", [])
        assert len(content_list) > 0, "Tool response should have content"

        # Check that the expected message is in the content
        text_content = next((item.get("text") for item in content_list if "text" in item), "")
        assert "Input: hello world, Output: HELLO WORLD" in text_content

    finally:
        # Clean up - remove the test file
        if test_tool_path.exists():
            os.remove(test_tool_path)
70 |
71 |
def test_hot_reload_decorator_update():
    """
    Test that the Agent detects updates to tools created with @tool decorator.
    """
    # Set up the tools directory in current working directory
    tools_dir = Path.cwd() / "tools"
    os.makedirs(tools_dir, exist_ok=True)

    # Tool path that will need cleanup - make sure filename matches function name
    test_tool_path = tools_dir / "greeting.py"

    try:
        # Create an Agent instance
        agent = Agent()

        # Create the initial version of the tool
        with open(test_tool_path, "w") as f:
            f.write("""
from strands import tool

@tool
def greeting(name: str) -> str:
    \"\"\"Generate a simple greeting.\"\"\"
    return f"Hello, {name}!"
""")

        # Wait for tool detection
        # NOTE(review): fixed 3s sleep assumes the file watcher fires in time; flaky otherwise.
        time.sleep(3)

        # Verify the tool was loaded
        assert "greeting" in agent.tool_names, "Agent should have detected and loaded the greeting tool"

        # Test calling the tool
        result1 = agent.tool.greeting(name="Strands")
        text_content1 = next((item.get("text") for item in result1.get("content", []) if "text" in item), "")
        assert "Hello, Strands!" in text_content1, "Tool should return simple greeting"

        # Update the tool with new functionality (adds a 'formal' parameter)
        with open(test_tool_path, "w") as f:
            f.write("""
from strands import tool
import datetime

@tool
def greeting(name: str, formal: bool = False) -> str:
    \"\"\"Generate a greeting with optional formality.\"\"\"
    current_hour = datetime.datetime.now().hour
    time_of_day = "morning" if current_hour < 12 else "afternoon" if current_hour < 18 else "evening"

    if formal:
        return f"Good {time_of_day}, {name}. It's a pleasure to meet you."
    else:
        return f"Hey {name}! How's your {time_of_day} going?"
""")

        # Wait for hot reload to detect the change
        time.sleep(3)

        # Test calling the updated tool
        result2 = agent.tool.greeting(name="Strands", formal=True)
        text_content2 = next((item.get("text") for item in result2.get("content", []) if "text" in item), "")
        assert "Good" in text_content2 and "Strands" in text_content2 and "pleasure to meet you" in text_content2

        # Test with informal parameter
        result3 = agent.tool.greeting(name="Strands", formal=False)
        text_content3 = next((item.get("text") for item in result3.get("content", []) if "text" in item), "")
        assert "Hey Strands!" in text_content3 and "going" in text_content3

    finally:
        # Clean up - remove the test file
        if test_tool_path.exists():
            os.remove(test_tool_path)
144 |
--------------------------------------------------------------------------------
/tests-integ/test_image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/strands-agents/sdk-python/a64e80dfa5f5f6053a3fe53767f59fc3c5c1af95/tests-integ/test_image.png
--------------------------------------------------------------------------------
/tests-integ/test_mcp_client.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import threading
3 | import time
4 | from typing import List, Literal
5 |
6 | from mcp import StdioServerParameters, stdio_client
7 | from mcp.client.sse import sse_client
8 | from mcp.client.streamable_http import streamablehttp_client
9 | from mcp.types import ImageContent as MCPImageContent
10 |
11 | from strands import Agent
12 | from strands.tools.mcp.mcp_client import MCPClient
13 | from strands.tools.mcp.mcp_types import MCPTransport
14 | from strands.types.content import Message
15 | from strands.types.tools import ToolUse
16 |
17 |
def start_calculator_server(transport: Literal["sse", "streamable-http"], port: int = 8000):
    """
    Initialize and start an MCP calculator server for integration testing.

    Creates a FastMCP server exposing a calculator tool (integer addition)
    and a custom-image generator tool, then serves them over the requested
    transport on the given port.

    Args:
        transport: MCP transport to serve ("sse" or "streamable-http").
        port: TCP port to bind. Previously declared as ``port=int``, which
            bound the ``int`` type object as the default instead of annotating
            the parameter; fixed to a proper annotation with a usable default.
    """
    from mcp.server import FastMCP

    mcp = FastMCP("Calculator Server", port=port)

    @mcp.tool(description="Calculator tool which performs calculations")
    def calculator(x: int, y: int) -> int:
        return x + y

    @mcp.tool(description="Generates a custom image")
    def generate_custom_image() -> MCPImageContent:
        try:
            with open("tests-integ/test_image.png", "rb") as image_file:
                encoded_image = base64.b64encode(image_file.read())
                return MCPImageContent(type="image", data=encoded_image, mimeType="image/png")
        except Exception as e:
            # Best-effort tool: log and fall through (returns None) on failure.
            print("Error while generating custom image: {}".format(e))

    mcp.run(transport=transport)
45 |
46 |
def test_mcp_client():
    """
    Drive calculator (SSE) and echo (stdio) MCP servers through a single agent.

    The agent is expected to invoke both the 'calculator' and 'echo' tools
    while answering, and then to identify the color of a generated image.
    """
    server_thread = threading.Thread(
        target=start_calculator_server, kwargs={"transport": "sse", "port": 8000}, daemon=True
    )
    server_thread.start()
    time.sleep(2)  # wait for server to startup completely

    sse_mcp_client = MCPClient(lambda: sse_client("http://127.0.0.1:8000/sse"))
    stdio_mcp_client = MCPClient(
        lambda: stdio_client(StdioServerParameters(command="python", args=["tests-integ/echo_server.py"]))
    )
    with sse_mcp_client, stdio_mcp_client:
        agent = Agent(tools=sse_mcp_client.list_tools_sync() + stdio_mcp_client.list_tools_sync())
        agent("add 1 and 2, then echo the result back to me")

        tool_uses = _messages_to_content_blocks(agent.messages)
        assert any(block["name"] == "echo" for block in tool_uses)
        assert any(block["name"] == "calculator" for block in tool_uses)

        image_prompt = """
        Generate a custom image, then tell me if the image is red, blue, yellow, pink, orange, or green.
        RESPOND ONLY WITH THE COLOR
        """
        answer_blocks = agent(image_prompt).message["content"]
        assert any(
            "yellow".casefold() in block["text"].casefold() for block in answer_blocks if "text" in block
        )
87 |
88 |
def test_can_reuse_mcp_client():
    """An MCPClient must be usable again after its first context manager exits.

    Fixes: removed a stray `pass` statement left after a non-empty `with` body,
    and dropped the unnecessary list materialization inside `any(...)`.
    """
    stdio_mcp_client = MCPClient(
        lambda: stdio_client(StdioServerParameters(command="python", args=["tests-integ/echo_server.py"]))
    )
    # First session: open, list tools, close cleanly.
    with stdio_mcp_client:
        stdio_mcp_client.list_tools_sync()
    # Second session: the same client instance must start a fresh session.
    with stdio_mcp_client:
        agent = Agent(tools=stdio_mcp_client.list_tools_sync())
        agent("echo the following to me DOG")

        tool_use_content_blocks = _messages_to_content_blocks(agent.messages)
        assert any(block["name"] == "echo" for block in tool_use_content_blocks)
102 |
103 |
def test_streamable_http_mcp_client():
    """The agent should reach the calculator tool over the streamable-http transport."""
    calculator_thread = threading.Thread(
        target=start_calculator_server, kwargs={"transport": "streamable-http", "port": 8001}, daemon=True
    )
    calculator_thread.start()
    time.sleep(2)  # wait for server to startup completely

    def transport_callback() -> MCPTransport:
        return streamablehttp_client(url="http://127.0.0.1:8001/mcp")

    streamable_http_client = MCPClient(transport_callback)
    with streamable_http_client:
        agent = Agent(tools=streamable_http_client.list_tools_sync())
        agent("add 1 and 2 using a calculator")

        tool_uses = _messages_to_content_blocks(agent.messages)
        assert any(block["name"] == "calculator" for block in tool_uses)
121 |
122 |
def _messages_to_content_blocks(messages: List[Message]) -> List[ToolUse]:
    """Collect every toolUse payload across all messages' content blocks."""
    tool_uses: List[ToolUse] = []
    for message in messages:
        for block in message["content"]:
            if "toolUse" in block:
                tool_uses.append(block["toolUse"])
    return tool_uses
125 |
--------------------------------------------------------------------------------
/tests-integ/test_model_anthropic.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import pytest
4 |
5 | import strands
6 | from strands import Agent
7 | from strands.models.anthropic import AnthropicModel
8 |
9 |
@pytest.fixture
def model():
    """Claude 3.7 Sonnet via the Anthropic API, keyed from the environment."""
    client_args = {"api_key": os.getenv("ANTHROPIC_API_KEY")}
    return AnthropicModel(
        client_args=client_args,
        model_id="claude-3-7-sonnet-20250219",
        max_tokens=512,
    )
19 |
20 |
@pytest.fixture
def tools():
    """Two canned tools (fixed time and weather) for the agent to call."""

    # The tool functions return constants so the test can assert their values
    # appear verbatim in the agent's final answer.
    @strands.tool
    def tool_time() -> str:
        return "12:00"

    @strands.tool
    def tool_weather() -> str:
        return "sunny"

    return [tool_time, tool_weather]
32 |
33 |
@pytest.fixture
def system_prompt():
    # The "&" quirk gives the test a cheap marker that the system prompt was honored.
    return "You are an AI assistant that uses & instead of ."
37 |
38 |
@pytest.fixture
def agent(model, tools, system_prompt):
    """Agent under test, wired with the model, stub tools, and system prompt."""
    return Agent(model=model, tools=tools, system_prompt=system_prompt)
42 |
43 |
@pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY environment variable missing")
def test_agent(agent):
    """The answer should reflect both tool results and the system-prompt quirk."""
    reply = agent("What is the time and weather in New York?")
    answer = reply.message["content"][0]["text"].lower()

    for token in ("12:00", "sunny", "&"):
        assert token in answer
50 |
--------------------------------------------------------------------------------
/tests-integ/test_model_bedrock.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | import strands
4 | from strands import Agent
5 | from strands.models import BedrockModel
6 |
7 |
@pytest.fixture
def system_prompt():
    # The "&" quirk gives the tests a cheap marker that the system prompt was honored.
    return "You are an AI assistant that uses & instead of ."
11 |
12 |
@pytest.fixture
def streaming_model():
    """Claude 3.7 Sonnet on Bedrock with response streaming enabled."""
    model_id = "us.anthropic.claude-3-7-sonnet-20250219-v1:0"
    return BedrockModel(model_id=model_id, streaming=True)
19 |
20 |
@pytest.fixture
def non_streaming_model():
    """Llama 3.2 90B on Bedrock with streaming disabled."""
    model_id = "us.meta.llama3-2-90b-instruct-v1:0"
    return BedrockModel(model_id=model_id, streaming=False)
27 |
28 |
@pytest.fixture
def streaming_agent(streaming_model, system_prompt):
    """Agent over the streaming model; directory tool loading disabled for hermeticity."""
    return Agent(model=streaming_model, system_prompt=system_prompt, load_tools_from_directory=False)
32 |
33 |
@pytest.fixture
def non_streaming_agent(non_streaming_model, system_prompt):
    """Agent over the non-streaming model; directory tool loading disabled for hermeticity."""
    return Agent(model=non_streaming_model, system_prompt=system_prompt, load_tools_from_directory=False)
37 |
38 |
def test_streaming_agent(streaming_agent):
    """Test agent with streaming model."""
    result = streaming_agent("Hello!")

    # Smoke test: any non-empty response counts as success.
    assert len(str(result)) > 0
44 |
45 |
def test_non_streaming_agent(non_streaming_agent):
    """Test agent with non-streaming model."""
    result = non_streaming_agent("Hello!")

    # Smoke test: any non-empty response counts as success.
    assert len(str(result)) > 0
51 |
52 |
def test_streaming_model_events(streaming_model):
    """The streaming converse call should emit the standard event envelope."""
    messages = [{"role": "user", "content": [{"text": "Hello"}]}]

    # Drain the stream so every event is available for inspection.
    events = list(streaming_model.converse(messages))

    for expected_key in ("messageStart", "contentBlockDelta", "messageStop"):
        assert any(expected_key in event for event in events)
64 |
65 |
def test_non_streaming_model_events(non_streaming_model):
    """Non-streaming converse should still emit the standard event envelope."""
    messages = [{"role": "user", "content": [{"text": "Hello"}]}]

    # Drain the generator so every event is available for inspection.
    events = list(non_streaming_model.converse(messages))

    for expected_key in ("messageStart", "contentBlockDelta", "messageStop"):
        assert any(expected_key in event for event in events)
77 |
78 |
def test_tool_use_streaming(streaming_model):
    """Test tool use with streaming model."""
    import json

    tool_was_called = False

    @strands.tool
    def calculator(expression: str) -> float:
        """Calculate the result of a mathematical expression."""
        nonlocal tool_was_called
        tool_was_called = True
        # Input comes from our own prompt/model loop, not from end users.
        return eval(expression)

    math_agent = Agent(model=streaming_model, tools=[calculator], load_tools_from_directory=False)
    result = math_agent("What is 123 + 456?")

    # Print the full message content for debugging
    print("\nFull message content:")
    print(json.dumps(result.message["content"], indent=2))

    assert tool_was_called
102 |
103 |
def test_tool_use_non_streaming(non_streaming_model):
    """Test tool use with non-streaming model."""
    tool_was_called = False

    @strands.tool
    def calculator(expression: str) -> float:
        """Calculate the result of a mathematical expression."""
        nonlocal tool_was_called
        tool_was_called = True
        # Input comes from our own prompt/model loop, not from end users.
        return eval(expression)

    math_agent = Agent(model=non_streaming_model, tools=[calculator], load_tools_from_directory=False)
    math_agent("What is 123 + 456?")

    assert tool_was_called
121 |
--------------------------------------------------------------------------------
/tests-integ/test_model_litellm.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | import strands
4 | from strands import Agent
5 | from strands.models.litellm import LiteLLMModel
6 |
7 |
@pytest.fixture
def model():
    """Claude 3.7 Sonnet on Bedrock, accessed through LiteLLM."""
    bedrock_model_id = "us.anthropic.claude-3-7-sonnet-20250219-v1:0"
    return LiteLLMModel(model_id=bedrock_model_id)
11 |
12 |
@pytest.fixture
def tools():
    """Two canned tools (fixed time and weather) for the agent to call."""

    # The tool functions return constants so the test can assert their values
    # appear verbatim in the agent's final answer.
    @strands.tool
    def tool_time() -> str:
        return "12:00"

    @strands.tool
    def tool_weather() -> str:
        return "sunny"

    return [tool_time, tool_weather]
24 |
25 |
@pytest.fixture
def agent(model, tools):
    """Agent under test, wired with the LiteLLM model and stub tools."""
    return Agent(model=model, tools=tools)
29 |
30 |
def test_agent(agent):
    """Both stub tools' outputs should appear in the agent's final answer."""
    reply = agent("What is the time and weather in New York?")
    answer = reply.message["content"][0]["text"].lower()

    for token in ("12:00", "sunny"):
        assert token in answer
36 |
--------------------------------------------------------------------------------
/tests-integ/test_model_llamaapi.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates
2 | import os
3 |
4 | import pytest
5 |
6 | import strands
7 | from strands import Agent
8 | from strands.models.llamaapi import LlamaAPIModel
9 |
10 |
@pytest.fixture
def model():
    """Llama 4 Maverick served by the Llama API, keyed from the environment."""
    client_args = {"api_key": os.getenv("LLAMA_API_KEY")}
    return LlamaAPIModel(
        model_id="Llama-4-Maverick-17B-128E-Instruct-FP8",
        client_args=client_args,
    )
19 |
20 |
@pytest.fixture
def tools():
    """Two canned tools (fixed time and weather) for the agent to call."""

    # The tool functions return constants so the test can assert their values
    # appear verbatim in the agent's final answer.
    @strands.tool
    def tool_time() -> str:
        return "12:00"

    @strands.tool
    def tool_weather() -> str:
        return "sunny"

    return [tool_time, tool_weather]
32 |
33 |
34 | @pytest.fixture
35 | def agent(model, tools):
36 | return Agent(model=model, tools=tools)
37 |
38 |
39 | @pytest.mark.skipif(
40 | "LLAMA_API_KEY" not in os.environ,
41 | reason="LLAMA_API_KEY environment variable missing",
42 | )
43 | def test_agent(agent):
44 | result = agent("What is the time and weather in New York?")
45 | text = result.message["content"][0]["text"].lower()
46 |
47 | assert all(string in text for string in ["12:00", "sunny"])
48 |
--------------------------------------------------------------------------------
/tests-integ/test_model_openai.py:
--------------------------------------------------------------------------------
import os

import pytest

import strands
from strands import Agent
from strands.models.openai import OpenAIModel


@pytest.fixture
def model():
    """OpenAI gpt-4o model configured from the OPENAI_API_KEY environment variable."""
    return OpenAIModel(
        model_id="gpt-4o",
        client_args={"api_key": os.getenv("OPENAI_API_KEY")},
    )


@pytest.fixture
def tools():
    """Two canned tools whose outputs the assertion looks for."""

    @strands.tool
    def tool_time() -> str:
        return "12:00"

    @strands.tool
    def tool_weather() -> str:
        return "sunny"

    return [tool_time, tool_weather]


@pytest.fixture
def agent(model, tools):
    """Agent wired with the OpenAI model and the canned tools."""
    return Agent(model=model, tools=tools)


@pytest.mark.skipif(
    "OPENAI_API_KEY" not in os.environ,
    reason="OPENAI_API_KEY environment variable missing",
)
def test_agent(agent):
    """The final answer should surface both tool results."""
    result = agent("What is the time and weather in New York?")
    answer = result.message["content"][0]["text"].lower()

    for marker in ("12:00", "sunny"):
        assert marker in answer
47 |
--------------------------------------------------------------------------------
/tests-integ/test_stream_agent.py:
--------------------------------------------------------------------------------
1 | """
2 | Test script for Strands' custom callback handler functionality.
3 | Demonstrates different patterns of callback handling and processing.
4 | """
5 |
6 | import logging
7 |
8 | from strands import Agent
9 |
10 | logging.getLogger("strands").setLevel(logging.DEBUG)
11 | logging.basicConfig(format="%(levelname)s | %(name)s | %(message)s", handlers=[logging.StreamHandler()])
12 |
13 |
class ToolCountingCallbackHandler:
    """Callback handler that counts tool executions and messages as events stream by."""

    def __init__(self):
        # Running totals updated by callback_handler.
        self.tool_count = 0
        self.message_count = 0

    def callback_handler(self, **kwargs) -> None:
        """
        Custom callback handler that processes and displays different types of events.

        Args:
            **kwargs: Callback event data including:
                - data: Regular output
                - complete: Completion status
                - message: Message processing
                - current_tool_use: Tool execution
        """
        # Extract event data
        data = kwargs.get("data", "")
        complete = kwargs.get("complete", False)
        message = kwargs.get("message", {})
        current_tool_use = kwargs.get("current_tool_use", {})

        # Handle regular data output
        if data:
            print(f"🔄 Data: {data}")

        # Handle tool execution events
        if current_tool_use:
            self.tool_count += 1
            tool_name = current_tool_use.get("name", "")
            tool_input = current_tool_use.get("input", {})
            print(f"🛠️ Tool Execution #{self.tool_count}\nTool: {tool_name}\nInput: {tool_input}")

        # Handle message processing
        if message:
            self.message_count += 1
            print(f"📝 Message #{self.message_count}")

        # Handle completion.
        # Bug fix: the original called self.console.print(..., style=...), but no
        # `console` attribute is ever assigned on this class (leftover from a
        # rich.Console-based version), so any complete=True event raised
        # AttributeError. Use the builtin print like the other branches.
        if complete:
            print("✨ Callback Complete")
55 |
56 |
def test_basic_interaction():
    """Exercise a simple prompt through an agent using the counting callback handler."""
    print("\nTesting Basic Interaction")

    # Wire the agent to the custom handler; skip directory tool loading.
    counting_handler = ToolCountingCallbackHandler()
    agent = Agent(
        callback_handler=counting_handler.callback_handler,
        load_tools_from_directory=False,
    )

    # A simple prompt is enough to drive callback events.
    agent("Tell me a short joke from your general knowledge")

    print("\nBasic Interaction Complete")
71 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/strands-agents/sdk-python/a64e80dfa5f5f6053a3fe53767f59fc3c5c1af95/tests/__init__.py
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
import configparser
import logging
import os
import sys

import boto3
import moto
import pytest

## Moto

# Configure strands logging from the LOG_LEVEL environment variable (default INFO).
log_level = os.environ.get("LOG_LEVEL", "INFO").upper()

logging.getLogger("strands").setLevel(log_level)
logging.basicConfig(
    format="%(levelname)s | %(name)s | %(message)s", handlers=[logging.StreamHandler(stream=sys.stdout)]
)


@pytest.fixture
def moto_env(monkeypatch):
    """Install fake AWS credentials/region and strip OTEL exporter settings."""
    for name, value in {
        "AWS_ACCESS_KEY_ID": "test",
        "AWS_SECRET_ACCESS_KEY": "test",
        "AWS_SECURITY_TOKEN": "test",
        "AWS_DEFAULT_REGION": "us-west-2",
    }.items():
        monkeypatch.setenv(name, value)
    for name in ("OTEL_EXPORTER_OTLP_ENDPOINT", "OTEL_EXPORTER_OTLP_HEADERS"):
        monkeypatch.delenv(name, raising=False)


@pytest.fixture
def moto_mock_aws():
    """Run the test inside moto's mocked AWS environment."""
    with moto.mock_aws():
        yield


@pytest.fixture
def moto_cloudwatch_client():
    return boto3.client("cloudwatch")


## Boto3


@pytest.fixture
def boto3_profile_name():
    return "test-profile"


@pytest.fixture
def boto3_profile(boto3_profile_name):
    """Credentials config containing a single named profile with test keys."""
    profile_config = configparser.ConfigParser()
    profile_config[boto3_profile_name] = {
        "aws_access_key_id": "test",
        "aws_secret_access_key": "test",
    }

    return profile_config


@pytest.fixture
def boto3_profile_path(boto3_profile, tmp_path, monkeypatch):
    """Write the profile to a temp credentials file and point boto3 at it."""
    path = tmp_path / ".aws/credentials"
    path.parent.mkdir(exist_ok=True)
    with path.open("w") as fp:
        boto3_profile.write(fp)

    monkeypatch.setenv("AWS_SHARED_CREDENTIALS_FILE", str(path))

    return path
71 |
--------------------------------------------------------------------------------
/tests/strands/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/strands-agents/sdk-python/a64e80dfa5f5f6053a3fe53767f59fc3c5c1af95/tests/strands/__init__.py
--------------------------------------------------------------------------------
/tests/strands/agent/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/strands-agents/sdk-python/a64e80dfa5f5f6053a3fe53767f59fc3c5c1af95/tests/strands/agent/__init__.py
--------------------------------------------------------------------------------
/tests/strands/agent/test_agent_result.py:
--------------------------------------------------------------------------------
import unittest.mock
from typing import cast

import pytest

from strands.agent.agent_result import AgentResult
from strands.telemetry.metrics import EventLoopMetrics
from strands.types.content import Message
from strands.types.streaming import StopReason


@pytest.fixture
def mock_metrics():
    return unittest.mock.Mock(spec=EventLoopMetrics)


@pytest.fixture
def simple_message():
    """Message with a single text block."""
    return {"role": "assistant", "content": [{"text": "Hello world!"}]}


@pytest.fixture
def complex_message():
    """Message with several text blocks and one non-text block."""
    return {
        "role": "assistant",
        "content": [
            {"text": "First paragraph"},
            {"text": "Second paragraph"},
            {"non_text_content": "This should be ignored"},
            {"text": "Third paragraph"},
        ],
    }


@pytest.fixture
def empty_message():
    """Message whose content list is empty."""
    return {"role": "assistant", "content": []}


def test__init__(mock_metrics, simple_message: Message):
    """All required fields are stored exactly as given."""
    stop_reason: StopReason = "end_turn"
    state = {"key": "value"}

    result = AgentResult(stop_reason=stop_reason, message=simple_message, metrics=mock_metrics, state=state)

    assert result.stop_reason == stop_reason
    assert result.message == simple_message
    assert result.metrics == mock_metrics
    assert result.state == state


def test__str__simple(mock_metrics, simple_message: Message):
    """str() renders a single text block followed by a newline."""
    result = AgentResult(stop_reason="end_turn", message=simple_message, metrics=mock_metrics, state={})

    assert str(result) == "Hello world!\n"


def test__str__complex(mock_metrics, complex_message: Message):
    """str() joins every text block and skips non-text blocks."""
    result = AgentResult(stop_reason="end_turn", message=complex_message, metrics=mock_metrics, state={})

    assert str(result) == "First paragraph\nSecond paragraph\nThird paragraph\n"


def test__str__empty(mock_metrics, empty_message: Message):
    """str() of a message with no content blocks is the empty string."""
    result = AgentResult(stop_reason="end_turn", message=empty_message, metrics=mock_metrics, state={})

    assert str(result) == ""


def test__str__no_content(mock_metrics):
    """str() tolerates a message missing the content field entirely."""
    message_without_content = cast(Message, {"role": "assistant"})

    result = AgentResult(stop_reason="end_turn", message=message_without_content, metrics=mock_metrics, state={})

    assert str(result) == ""


def test__str__non_dict_content(mock_metrics):
    """str() gracefully skips content items that are not dictionaries."""
    message_with_non_dict = cast(
        Message,
        {"role": "assistant", "content": [{"text": "Valid text"}, "Not a dictionary", {"text": "More valid text"}]},
    )

    result = AgentResult(stop_reason="end_turn", message=message_with_non_dict, metrics=mock_metrics, state={})

    assert str(result) == "Valid text\nMore valid text\n"
98 |
--------------------------------------------------------------------------------
/tests/strands/event_loop/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/strands-agents/sdk-python/a64e80dfa5f5f6053a3fe53767f59fc3c5c1af95/tests/strands/event_loop/__init__.py
--------------------------------------------------------------------------------
/tests/strands/event_loop/test_error_handler.py:
--------------------------------------------------------------------------------
import unittest.mock

import botocore
import pytest

import strands
from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException


@pytest.fixture
def callback_handler():
    return unittest.mock.Mock()


@pytest.fixture
def kwargs():
    return {"request_state": "value"}


@pytest.fixture
def tool_handler():
    return unittest.mock.Mock()


@pytest.fixture
def model():
    return unittest.mock.Mock()


@pytest.fixture
def tool_config():
    return {}


@pytest.fixture
def system_prompt():
    return "prompt."


@pytest.fixture
def sdk_event_loop():
    """Patch recurse_event_loop so re-entry into the event loop is observable."""
    with unittest.mock.patch.object(strands.event_loop.event_loop, "recurse_event_loop") as mock:
        yield mock


@pytest.fixture
def event_stream_error(request):
    """EventStreamError whose message comes from the parametrize value."""
    return botocore.exceptions.EventStreamError({"Error": {"Message": request.param}}, "mock_operation")


def test_handle_throttling_error(callback_handler, kwargs):
    """Retries with doubled delay until attempts are exhausted, then force-stops."""
    exception = ModelThrottledException("ThrottlingException | ConverseStream")
    max_attempts = 2
    delay = 0.1
    max_delay = 1

    observed_retries = []
    observed_delays = []
    for attempt in range(max_attempts):
        retry, delay = strands.event_loop.error_handler.handle_throttling_error(
            exception, attempt, max_attempts, delay, max_delay, callback_handler, kwargs
        )
        observed_retries.append(retry)
        observed_delays.append(delay)

    assert observed_retries == [True, False] and observed_delays == [0.2, 0.2]

    callback_handler.assert_has_calls(
        [
            unittest.mock.call(event_loop_throttled_delay=0.1, request_state="value"),
            unittest.mock.call(force_stop=True, force_stop_reason=str(exception)),
        ]
    )


def test_handle_throttling_error_does_not_exist(callback_handler, kwargs):
    """A non-throttling message is not retried and the delay is left unchanged."""
    exception = ModelThrottledException("Other Error")
    attempt = 0
    max_attempts = 1
    delay = 1
    max_delay = 1

    retry, new_delay = strands.event_loop.error_handler.handle_throttling_error(
        exception, attempt, max_attempts, delay, max_delay, callback_handler, kwargs
    )

    assert (retry, new_delay) == (False, 1)

    callback_handler.assert_called_with(force_stop=True, force_stop_reason=str(exception))


@pytest.mark.parametrize("event_stream_error", ["Input is too long for requested model"], indirect=True)
def test_handle_input_too_long_error(
    sdk_event_loop,
    event_stream_error,
    model,
    system_prompt,
    tool_config,
    callback_handler,
    tool_handler,
    kwargs,
):
    """The oversized tool result is truncated and the event loop is re-entered."""
    sdk_event_loop.return_value = "success"

    messages = [
        {
            "role": "user",
            "content": [
                {"toolResult": {"toolUseId": "t1", "status": "success", "content": [{"text": "needs truncation"}]}}
            ],
        }
    ]

    result = strands.event_loop.error_handler.handle_input_too_long_error(
        event_stream_error,
        messages,
        model,
        system_prompt,
        tool_config,
        callback_handler,
        tool_handler,
        kwargs,
    )

    assert result == "success"

    # The tool result is replaced in place with a truncation error payload.
    assert messages == [
        {
            "role": "user",
            "content": [
                {
                    "toolResult": {
                        "toolUseId": "t1",
                        "status": "error",
                        "content": [{"text": "The tool result was too large!"}],
                    },
                },
            ],
        },
    ]

    sdk_event_loop.assert_called_once_with(
        model=model,
        system_prompt=system_prompt,
        messages=messages,
        tool_config=tool_config,
        callback_handler=callback_handler,
        tool_handler=tool_handler,
        request_state="value",
    )

    callback_handler.assert_not_called()


@pytest.mark.parametrize("event_stream_error", ["Other error"], indirect=True)
def test_handle_input_too_long_error_does_not_exist(
    sdk_event_loop,
    event_stream_error,
    model,
    system_prompt,
    tool_config,
    callback_handler,
    tool_handler,
    kwargs,
):
    """An unrelated stream error is surfaced as a context-overflow exception."""
    messages = []

    with pytest.raises(ContextWindowOverflowException):
        strands.event_loop.error_handler.handle_input_too_long_error(
            event_stream_error,
            messages,
            model,
            system_prompt,
            tool_config,
            callback_handler,
            tool_handler,
            kwargs,
        )

    sdk_event_loop.assert_not_called()
    callback_handler.assert_called_with(force_stop=True, force_stop_reason=str(event_stream_error))


@pytest.mark.parametrize("event_stream_error", ["Input is too long for requested model"], indirect=True)
def test_handle_input_too_long_error_no_tool_result(
    sdk_event_loop,
    event_stream_error,
    model,
    system_prompt,
    tool_config,
    callback_handler,
    tool_handler,
    kwargs,
):
    """With no tool result to truncate, the error is re-raised as overflow."""
    messages = []

    with pytest.raises(ContextWindowOverflowException):
        strands.event_loop.error_handler.handle_input_too_long_error(
            event_stream_error,
            messages,
            model,
            system_prompt,
            tool_config,
            callback_handler,
            tool_handler,
            kwargs,
        )

    sdk_event_loop.assert_not_called()
    callback_handler.assert_called_with(force_stop=True, force_stop_reason=str(event_stream_error))
221 |
--------------------------------------------------------------------------------
/tests/strands/event_loop/test_message_processor.py:
--------------------------------------------------------------------------------
import copy

import pytest

from strands.event_loop import message_processor


@pytest.mark.parametrize(
    "messages,expected,expected_messages",
    [
        # Orphaned toolUse with empty input, no toolResult
        (
            [
                {"role": "assistant", "content": [{"toolUse": {"toolUseId": "1", "input": {}, "name": "foo"}}]},
                {"role": "user", "content": [{"toolResult": {"toolUseId": "2"}}]},
            ],
            True,
            [
                {"role": "assistant", "content": [{"text": "[Attempted to use foo, but operation was canceled]"}]},
                {"role": "user", "content": [{"toolResult": {"toolUseId": "2"}}]},
            ],
        ),
        # toolUse with input, has matching toolResult
        (
            [
                {"role": "assistant", "content": [{"toolUse": {"toolUseId": "1", "input": {"a": 1}, "name": "foo"}}]},
                {"role": "user", "content": [{"toolResult": {"toolUseId": "1"}}]},
            ],
            False,
            [
                {"role": "assistant", "content": [{"toolUse": {"toolUseId": "1", "input": {"a": 1}, "name": "foo"}}]},
                {"role": "user", "content": [{"toolResult": {"toolUseId": "1"}}]},
            ],
        ),
        # No messages
        (
            [],
            False,
            [],
        ),
    ],
)
def test_clean_orphaned_empty_tool_uses(messages, expected, expected_messages):
    """Orphaned empty tool uses are rewritten in place; the flag reports changes."""
    working = copy.deepcopy(messages)
    changed = message_processor.clean_orphaned_empty_tool_uses(working)
    assert changed == expected
    assert working == expected_messages


@pytest.mark.parametrize(
    "messages,expected_idx",
    [
        (
            [
                {"role": "user", "content": [{"text": "hi"}]},
                {"role": "user", "content": [{"toolResult": {"toolUseId": "1"}}]},
                {"role": "assistant", "content": [{"text": "ok"}]},
            ],
            1,
        ),
        (
            [
                {"role": "user", "content": [{"text": "hi"}]},
                {"role": "assistant", "content": [{"text": "ok"}]},
            ],
            None,
        ),
        (
            [],
            None,
        ),
    ],
)
def test_find_last_message_with_tool_results(messages, expected_idx):
    """Returns the index of the last message carrying a toolResult, or None."""
    assert message_processor.find_last_message_with_tool_results(messages) == expected_idx


@pytest.mark.parametrize(
    "messages,msg_idx,expected_changed,expected_content",
    [
        (
            [
                {
                    "role": "user",
                    "content": [{"toolResult": {"toolUseId": "1", "status": "ok", "content": [{"text": "big"}]}}],
                }
            ],
            0,
            True,
            [
                {
                    "toolResult": {
                        "toolUseId": "1",
                        "status": "error",
                        "content": [{"text": "The tool result was too large!"}],
                    }
                }
            ],
        ),
        (
            [{"role": "user", "content": [{"text": "no tool result"}]}],
            0,
            False,
            [{"text": "no tool result"}],
        ),
        (
            [],
            0,
            False,
            [],
        ),
        (
            [{"role": "user", "content": [{"toolResult": {"toolUseId": "1"}}]}],
            2,
            False,
            [{"toolResult": {"toolUseId": "1"}}],
        ),
    ],
)
def test_truncate_tool_results(messages, msg_idx, expected_changed, expected_content):
    """Tool results at msg_idx are truncated in place; out-of-range indices are no-ops."""
    working = copy.deepcopy(messages)
    changed = message_processor.truncate_tool_results(working, msg_idx)
    assert changed == expected_changed
    if 0 <= msg_idx < len(working):
        assert working[msg_idx]["content"] == expected_content
    else:
        assert working == messages
129 |
--------------------------------------------------------------------------------
/tests/strands/handlers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/strands-agents/sdk-python/a64e80dfa5f5f6053a3fe53767f59fc3c5c1af95/tests/strands/handlers/__init__.py
--------------------------------------------------------------------------------
/tests/strands/handlers/test_callback_handler.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests for the SDK callback handler module.
3 |
4 | These tests ensure the basic print-based callback handler in the SDK functions correctly.
5 | """
6 |
7 | import unittest.mock
8 |
9 | import pytest
10 |
11 | from strands.handlers.callback_handler import CompositeCallbackHandler, PrintingCallbackHandler
12 |
13 |
14 | @pytest.fixture
15 | def handler():
16 | """Create a fresh PrintingCallbackHandler instance for testing."""
17 | return PrintingCallbackHandler()
18 |
19 |
20 | @pytest.fixture
21 | def mock_print():
22 | with unittest.mock.patch("builtins.print") as mock:
23 | yield mock
24 |
25 |
26 | def test_call_with_empty_args(handler, mock_print):
27 | """Test calling the handler with no arguments."""
28 | handler()
29 | # No output should be printed
30 | mock_print.assert_not_called()
31 |
32 |
33 | def test_call_handler_reasoningText(handler, mock_print):
34 | """Test calling the handler with reasoningText."""
35 | handler(reasoningText="This is reasoning text")
36 | # Should print reasoning text without newline
37 | mock_print.assert_called_once_with("This is reasoning text", end="")
38 |
39 |
40 | def test_call_without_reasoningText(handler, mock_print):
41 | """Test calling the handler without reasoningText argument."""
42 | handler(data="Some output")
43 | # Should only print data, not reasoningText
44 | mock_print.assert_called_once_with("Some output", end="")
45 |
46 |
47 | def test_call_with_reasoningText_and_data(handler, mock_print):
48 | """Test calling the handler with both reasoningText and data."""
49 | handler(reasoningText="Reasoning", data="Output")
50 | # Should print reasoningText and data, both without newline
51 | calls = [
52 | unittest.mock.call("Reasoning", end=""),
53 | unittest.mock.call("Output", end=""),
54 | ]
55 | mock_print.assert_has_calls(calls)
56 |
57 |
58 | def test_call_with_data_incomplete(handler, mock_print):
59 | """Test calling the handler with data but not complete."""
60 | handler(data="Test output")
61 | # Should print without newline
62 | mock_print.assert_called_once_with("Test output", end="")
63 |
64 |
65 | def test_call_with_data_complete(handler, mock_print):
66 | """Test calling the handler with data and complete=True."""
67 | handler(data="Test output", complete=True)
68 | # Should print with newline
69 | # The handler prints the data, and then also prints a newline when complete=True and data exists
70 | assert mock_print.call_count == 2
71 | mock_print.assert_any_call("Test output", end="\n")
72 | mock_print.assert_any_call("\n")
73 |
74 |
75 | def test_call_with_current_tool_use_new(handler, mock_print):
76 | """Test calling the handler with a new tool use."""
77 | current_tool_use = {"name": "test_tool", "input": {"param": "value"}}
78 |
79 | handler(current_tool_use=current_tool_use)
80 |
81 | # Should print tool information
82 | mock_print.assert_called_once_with("\nTool #1: test_tool")
83 |
84 | # Should update the handler state
85 | assert handler.tool_count == 1
86 | assert handler.previous_tool_use == current_tool_use
87 |
88 |
89 | def test_call_with_current_tool_use_same(handler, mock_print):
90 | """Test calling the handler with the same tool use twice."""
91 | current_tool_use = {"name": "test_tool", "input": {"param": "value"}}
92 |
93 | # First call
94 | handler(current_tool_use=current_tool_use)
95 | mock_print.reset_mock()
96 |
97 | # Second call with same tool use
98 | handler(current_tool_use=current_tool_use)
99 |
100 | # Should not print tool information again
101 | mock_print.assert_not_called()
102 |
103 | # Tool count should not increase
104 | assert handler.tool_count == 1
105 |
106 |
107 | def test_call_with_current_tool_use_different(handler, mock_print):
108 | """Test calling the handler with different tool uses."""
109 | first_tool_use = {"name": "first_tool", "input": {"param": "value1"}}
110 | second_tool_use = {"name": "second_tool", "input": {"param": "value2"}}
111 |
112 | # First call
113 | handler(current_tool_use=first_tool_use)
114 | mock_print.reset_mock()
115 |
116 | # Second call with different tool use
117 | handler(current_tool_use=second_tool_use)
118 |
119 | # Should print info for the new tool
120 | mock_print.assert_called_once_with("\nTool #2: second_tool")
121 |
122 | # Tool count should increase
123 | assert handler.tool_count == 2
124 | assert handler.previous_tool_use == second_tool_use
125 |
126 |
127 | def test_call_with_data_and_complete_extra_newline(handler, mock_print):
128 | """Test that an extra newline is printed when data is complete."""
129 | handler(data="Test output", complete=True)
130 |
131 | # The handler prints the data with newline and an extra newline for completion
132 | assert mock_print.call_count == 2
133 | mock_print.assert_any_call("Test output", end="\n")
134 | mock_print.assert_any_call("\n")
135 |
136 |
137 | def test_call_with_message_no_effect(handler, mock_print):
138 | """Test that passing a message without special content has no effect."""
139 | message = {"role": "user", "content": [{"text": "Hello"}]}
140 |
141 | handler(message=message)
142 |
143 | # No print calls should be made
144 | mock_print.assert_not_called()
145 |
146 |
147 | def test_call_with_multiple_parameters(handler, mock_print):
148 | """Test calling handler with multiple parameters."""
149 | current_tool_use = {"name": "test_tool", "input": {"param": "value"}}
150 |
151 | handler(data="Test output", complete=True, current_tool_use=current_tool_use)
152 |
153 | # Should print data with newline, an extra newline for completion, and tool information
154 | assert mock_print.call_count == 3
155 | mock_print.assert_any_call("Test output", end="\n")
156 | mock_print.assert_any_call("\n")
157 | mock_print.assert_any_call("\nTool #1: test_tool")
158 |
159 |
160 | def test_unknown_tool_name_handling(handler, mock_print):
161 | """Test handling of a tool use without a name."""
162 | # The SDK implementation doesn't have a fallback for tool uses without a name field
163 | # It checks for both presence of current_tool_use and current_tool_use.get("name")
164 | current_tool_use = {"input": {"param": "value"}, "name": "Unknown tool"}
165 |
166 | handler(current_tool_use=current_tool_use)
167 |
168 | # Should print the tool information
169 | mock_print.assert_called_once_with("\nTool #1: Unknown tool")
170 |
171 |
172 | def test_tool_use_empty_object(handler, mock_print):
173 | """Test handling of an empty tool use object."""
174 | # Tool use is an empty dict
175 | current_tool_use = {}
176 |
177 | handler(current_tool_use=current_tool_use)
178 |
179 | # Should not print anything
180 | mock_print.assert_not_called()
181 |
182 | # Should not update state
183 | assert handler.tool_count == 0
184 | assert handler.previous_tool_use is None
185 |
186 |
187 | def test_composite_handler_forwards_to_all_handlers():
188 | mock_handlers = [unittest.mock.Mock() for _ in range(3)]
189 | composite_handler = CompositeCallbackHandler(*mock_handlers)
190 |
191 | """Test that calling the handler forwards the call to all handlers."""
192 | # Create test arguments
193 | kwargs = {
194 | "data": "Test output",
195 | "complete": True,
196 | "current_tool_use": {"name": "test_tool", "input": {"param": "value"}},
197 | }
198 |
199 | # Call the composite handler
200 | composite_handler(**kwargs)
201 |
202 | # Verify each handler was called with the same arguments
203 | for handler in mock_handlers:
204 | handler.assert_called_once_with(**kwargs)
205 |
--------------------------------------------------------------------------------
/tests/strands/handlers/test_tool_handler.py:
--------------------------------------------------------------------------------
import unittest.mock

import pytest

import strands


@pytest.fixture
def tool_registry():
    return strands.tools.registry.ToolRegistry()


@pytest.fixture
def tool_handler(tool_registry):
    return strands.handlers.tool_handler.AgentToolHandler(tool_registry)


@pytest.fixture
def tool_use_identity(tool_registry):
    """Register an identity tool and return a matching tool-use request."""

    @strands.tools.tool
    def identity(a: int) -> int:
        return a

    tool_registry.register_tool(strands.tools.tools.FunctionTool(identity))

    return {"toolUseId": "identity", "name": "identity", "input": {"a": 1}}


@pytest.fixture
def tool_use_error(tool_registry):
    """Register a tool carrying an invalid spec and return a matching request."""

    def error():
        return

    error.TOOL_SPEC = {"invalid": True}

    tool_registry.register_tool(strands.tools.tools.FunctionTool(error))

    return {"toolUseId": "error", "name": "error", "input": {}}


def test_preprocess(tool_handler, tool_use_identity):
    """Preprocessing a registered tool use completes without error."""
    tool_handler.preprocess(tool_use_identity, tool_config={})


def test_process(tool_handler, tool_use_identity):
    """Processing a registered tool returns its result as text content."""
    result = tool_handler.process(
        tool_use_identity,
        model=unittest.mock.Mock(),
        system_prompt="p1",
        messages=[],
        tool_config={},
        callback_handler=unittest.mock.Mock(),
    )

    assert result == {"toolUseId": "identity", "status": "success", "content": [{"text": "1"}]}


def test_process_missing_tool(tool_handler):
    """Processing an unregistered tool yields an error result."""
    result = tool_handler.process(
        tool={"toolUseId": "missing", "name": "missing", "input": {}},
        model=unittest.mock.Mock(),
        system_prompt="p1",
        messages=[],
        tool_config={},
        callback_handler=unittest.mock.Mock(),
    )

    assert result == {
        "toolUseId": "missing",
        "status": "error",
        "content": [{"text": "Unknown tool: missing"}],
    }
77 |
--------------------------------------------------------------------------------
/tests/strands/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/strands-agents/sdk-python/a64e80dfa5f5f6053a3fe53767f59fc3c5c1af95/tests/strands/models/__init__.py
--------------------------------------------------------------------------------
/tests/strands/models/test_litellm.py:
--------------------------------------------------------------------------------
1 | import unittest.mock
2 |
3 | import pytest
4 |
5 | import strands
6 | from strands.models.litellm import LiteLLMModel
7 |
8 |
@pytest.fixture
def litellm_client_cls():
    """Patch the LiteLLM client class; yields the mock class object."""
    with unittest.mock.patch.object(strands.models.litellm.litellm, "LiteLLM") as mock_client_cls:
        yield mock_client_cls


@pytest.fixture
def litellm_client(litellm_client_cls):
    """The mock client instance produced by the patched class."""
    return litellm_client_cls.return_value


@pytest.fixture
def model_id():
    """Placeholder model identifier shared across tests."""
    return "m1"


@pytest.fixture
def model(litellm_client, model_id):
    """A LiteLLMModel built while the client class is patched."""
    # Requested only for its patching side effect; the value itself is unused.
    _ = litellm_client

    return LiteLLMModel(model_id=model_id)


@pytest.fixture
def messages():
    """A minimal single-turn user message list."""
    return [{"role": "user", "content": [{"text": "test"}]}]


@pytest.fixture
def system_prompt():
    """Placeholder system prompt."""
    return "s1"
40 |
41 |
def test__init__(litellm_client_cls, model_id):
    """Constructor stores model config and unpacks client args into litellm.LiteLLM."""
    model = LiteLLMModel({"api_key": "k1"}, model_id=model_id, params={"max_tokens": 1})

    tru_config = model.get_config()
    exp_config = {"model_id": "m1", "params": {"max_tokens": 1}}

    assert tru_config == exp_config

    # The first positional dict is passed through as keyword arguments.
    litellm_client_cls.assert_called_once_with(api_key="k1")
51 |
52 |
def test_update_config(model, model_id):
    """update_config stores the supplied model_id in the model's config."""
    model.update_config(model_id=model_id)

    assert model.get_config().get("model_id") == model_id
60 |
61 |
@pytest.mark.parametrize(
    "content, exp_result",
    [
        # Case 1: Thinking
        (
            {
                "reasoningContent": {
                    "reasoningText": {
                        "signature": "reasoning_signature",
                        "text": "reasoning_text",
                    },
                },
            },
            {
                "signature": "reasoning_signature",
                "thinking": "reasoning_text",
                "type": "thinking",
            },
        ),
        # Case 2: Video
        (
            {
                "video": {
                    "source": {"bytes": "base64encodedvideo"},
                },
            },
            {
                "type": "video_url",
                "video_url": {
                    "detail": "auto",
                    "url": "base64encodedvideo",
                },
            },
        ),
        # Case 3: Text
        (
            {"text": "hello"},
            {"type": "text", "text": "hello"},
        ),
    ],
)
def test_format_request_message_content(content, exp_result):
    """Each Bedrock-style content block maps to its LiteLLM request equivalent."""
    tru_result = LiteLLMModel.format_request_message_content(content)
    assert tru_result == exp_result
106 |
--------------------------------------------------------------------------------
/tests/strands/models/test_openai.py:
--------------------------------------------------------------------------------
1 | import unittest.mock
2 |
3 | import pytest
4 |
5 | import strands
6 | from strands.models.openai import OpenAIModel
7 |
8 |
@pytest.fixture
def openai_client_cls():
    """Patch the openai.OpenAI client class; yields the mock class object."""
    with unittest.mock.patch.object(strands.models.openai.openai, "OpenAI") as mock_client_cls:
        yield mock_client_cls


@pytest.fixture
def openai_client(openai_client_cls):
    """The mock client instance produced by the patched class."""
    return openai_client_cls.return_value


@pytest.fixture
def model_id():
    """Placeholder model identifier shared across tests."""
    return "m1"


@pytest.fixture
def model(openai_client, model_id):
    """An OpenAIModel built while the client class is patched."""
    # Requested only for its patching side effect; the value itself is unused.
    _ = openai_client

    return OpenAIModel(model_id=model_id)


@pytest.fixture
def messages():
    """A minimal single-turn user message list."""
    return [{"role": "user", "content": [{"text": "test"}]}]


@pytest.fixture
def system_prompt():
    """Placeholder system prompt."""
    return "s1"
40 |
41 |
def test__init__(openai_client_cls, model_id):
    """Constructor stores model config and unpacks client args into openai.OpenAI."""
    model = OpenAIModel({"api_key": "k1"}, model_id=model_id, params={"max_tokens": 1})

    assert model.get_config() == {"model_id": "m1", "params": {"max_tokens": 1}}

    openai_client_cls.assert_called_once_with(api_key="k1")
51 |
52 |
def test_update_config(model, model_id):
    """update_config stores the supplied model_id in the model's config."""
    model.update_config(model_id=model_id)

    assert model.get_config().get("model_id") == model_id
60 |
61 |
def test_stream(openai_client, model):
    """stream() emits text deltas first, then one start/delta/stop group per tool-call index."""
    # Two tool calls arrive split across two deltas; the index attribute pairs
    # the parts belonging to the same call.
    mock_tool_call_1_part_1 = unittest.mock.Mock(index=0)
    mock_tool_call_2_part_1 = unittest.mock.Mock(index=1)
    mock_delta_1 = unittest.mock.Mock(
        content="I'll calculate", tool_calls=[mock_tool_call_1_part_1, mock_tool_call_2_part_1]
    )

    mock_tool_call_1_part_2 = unittest.mock.Mock(index=0)
    mock_tool_call_2_part_2 = unittest.mock.Mock(index=1)
    mock_delta_2 = unittest.mock.Mock(
        content="that for you", tool_calls=[mock_tool_call_1_part_2, mock_tool_call_2_part_2]
    )

    mock_delta_3 = unittest.mock.Mock(content="", tool_calls=None)

    mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_1)])
    mock_event_2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_2)])
    mock_event_3 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="tool_calls", delta=mock_delta_3)])
    # Trailing event: its .usage attribute feeds the final metadata chunk.
    mock_event_4 = unittest.mock.Mock()

    openai_client.chat.completions.create.return_value = iter([mock_event_1, mock_event_2, mock_event_3, mock_event_4])

    request = {"model": "m1", "messages": [{"role": "user", "content": [{"type": "text", "text": "calculate 2+2"}]}]}
    response = model.stream(request)

    tru_events = list(response)
    exp_events = [
        {"chunk_type": "message_start"},
        {"chunk_type": "content_start", "data_type": "text"},
        {"chunk_type": "content_delta", "data_type": "text", "data": "I'll calculate"},
        {"chunk_type": "content_delta", "data_type": "text", "data": "that for you"},
        {"chunk_type": "content_stop", "data_type": "text"},
        {"chunk_type": "content_start", "data_type": "tool", "data": mock_tool_call_1_part_1},
        {"chunk_type": "content_delta", "data_type": "tool", "data": mock_tool_call_1_part_1},
        {"chunk_type": "content_delta", "data_type": "tool", "data": mock_tool_call_1_part_2},
        {"chunk_type": "content_stop", "data_type": "tool"},
        {"chunk_type": "content_start", "data_type": "tool", "data": mock_tool_call_2_part_1},
        {"chunk_type": "content_delta", "data_type": "tool", "data": mock_tool_call_2_part_1},
        {"chunk_type": "content_delta", "data_type": "tool", "data": mock_tool_call_2_part_2},
        {"chunk_type": "content_stop", "data_type": "tool"},
        {"chunk_type": "message_stop", "data": "tool_calls"},
        {"chunk_type": "metadata", "data": mock_event_4.usage},
    ]

    assert tru_events == exp_events
    openai_client.chat.completions.create.assert_called_once_with(**request)
108 |
109 |
def test_stream_empty(openai_client, model):
    """A response with no content still produces the start/stop framing chunks."""
    mock_delta = unittest.mock.Mock(content=None, tool_calls=None)
    mock_usage = unittest.mock.Mock(prompt_tokens=0, completion_tokens=0, total_tokens=0)

    mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta)])
    mock_event_2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="stop", delta=mock_delta)])
    # Events after the finish reason; the metadata chunk reads usage from the last one.
    mock_event_3 = unittest.mock.Mock()
    mock_event_4 = unittest.mock.Mock(usage=mock_usage)

    openai_client.chat.completions.create.return_value = iter([mock_event_1, mock_event_2, mock_event_3, mock_event_4])

    request = {"model": "m1", "messages": [{"role": "user", "content": []}]}
    response = model.stream(request)

    tru_events = list(response)
    exp_events = [
        {"chunk_type": "message_start"},
        {"chunk_type": "content_start", "data_type": "text"},
        {"chunk_type": "content_stop", "data_type": "text"},
        {"chunk_type": "message_stop", "data": "stop"},
        {"chunk_type": "metadata", "data": mock_usage},
    ]

    assert tru_events == exp_events
    openai_client.chat.completions.create.assert_called_once_with(**request)
135 |
136 |
def test_stream_with_empty_choices(openai_client, model):
    """Events lacking a choices attribute or with an empty choices list contribute no chunks."""
    mock_delta = unittest.mock.Mock(content="content", tool_calls=None)
    mock_usage = unittest.mock.Mock(prompt_tokens=10, completion_tokens=20, total_tokens=30)

    # Event with no choices attribute
    mock_event_1 = unittest.mock.Mock(spec=[])

    # Event with empty choices list
    mock_event_2 = unittest.mock.Mock(choices=[])

    # Valid event with content
    mock_event_3 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta)])

    # Event with finish reason (carries the same delta, hence the repeated text chunk below)
    mock_event_4 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="stop", delta=mock_delta)])

    # Final event with usage info
    mock_event_5 = unittest.mock.Mock(usage=mock_usage)

    openai_client.chat.completions.create.return_value = iter(
        [mock_event_1, mock_event_2, mock_event_3, mock_event_4, mock_event_5]
    )

    request = {"model": "m1", "messages": [{"role": "user", "content": ["test"]}]}
    response = model.stream(request)

    tru_events = list(response)
    exp_events = [
        {"chunk_type": "message_start"},
        {"chunk_type": "content_start", "data_type": "text"},
        {"chunk_type": "content_delta", "data_type": "text", "data": "content"},
        {"chunk_type": "content_delta", "data_type": "text", "data": "content"},
        {"chunk_type": "content_stop", "data_type": "text"},
        {"chunk_type": "message_stop", "data": "stop"},
        {"chunk_type": "metadata", "data": mock_usage},
    ]

    assert tru_events == exp_events
    openai_client.chat.completions.create.assert_called_once_with(**request)
176 |
--------------------------------------------------------------------------------
/tests/strands/tools/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/strands-agents/sdk-python/a64e80dfa5f5f6053a3fe53767f59fc3c5c1af95/tests/strands/tools/__init__.py
--------------------------------------------------------------------------------
/tests/strands/tools/mcp/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/strands-agents/sdk-python/a64e80dfa5f5f6053a3fe53767f59fc3c5c1af95/tests/strands/tools/mcp/__init__.py
--------------------------------------------------------------------------------
/tests/strands/tools/mcp/test_mcp_agent_tool.py:
--------------------------------------------------------------------------------
1 | from unittest.mock import MagicMock
2 |
3 | import pytest
4 | from mcp.types import Tool as MCPTool
5 |
6 | from strands.tools.mcp import MCPAgentTool, MCPClient
7 |
8 |
@pytest.fixture
def mock_mcp_tool():
    """A mock MCP tool definition with a name, description, and empty schema."""
    mock_tool = MagicMock(spec=MCPTool)
    mock_tool.name = "test_tool"
    mock_tool.description = "A test tool"
    mock_tool.inputSchema = {"type": "object", "properties": {}}
    return mock_tool


@pytest.fixture
def mock_mcp_client():
    """A mock MCPClient whose call_tool_sync always succeeds."""
    mock_server = MagicMock(spec=MCPClient)
    mock_server.call_tool_sync.return_value = {
        "status": "success",
        "toolUseId": "test-123",
        "content": [{"text": "Success result"}],
    }
    return mock_server


@pytest.fixture
def mcp_agent_tool(mock_mcp_tool, mock_mcp_client):
    """The adapter under test, wrapping the mock tool and client."""
    return MCPAgentTool(mock_mcp_tool, mock_mcp_client)
32 |
33 |
def test_tool_name(mcp_agent_tool, mock_mcp_tool):
    """tool_name mirrors the underlying MCP tool's name."""
    name = mcp_agent_tool.tool_name

    assert name == "test_tool"
    assert name == mock_mcp_tool.name
37 |
38 |
def test_tool_type(mcp_agent_tool):
    """MCP-backed tools report the "python" tool type."""
    tool_type = mcp_agent_tool.tool_type

    assert tool_type == "python"
41 |
42 |
def test_tool_spec_with_description(mcp_agent_tool, mock_mcp_tool):
    """tool_spec carries the tool's name, description, and JSON input schema."""
    spec = mcp_agent_tool.tool_spec

    assert spec["name"] == "test_tool"
    assert spec["description"] == "A test tool"
    assert spec["inputSchema"]["json"] == {"type": "object", "properties": {}}
49 |
50 |
def test_tool_spec_without_description(mock_mcp_tool, mock_mcp_client):
    """A missing MCP description falls back to a generated placeholder."""
    mock_mcp_tool.description = None

    agent_tool = MCPAgentTool(mock_mcp_tool, mock_mcp_client)
    tool_spec = agent_tool.tool_spec

    assert tool_spec["description"] == "Tool which performs test_tool"
58 |
59 |
def test_invoke(mcp_agent_tool, mock_mcp_client):
    """invoke forwards the tool-use id, name, and input to call_tool_sync and returns its result."""
    tool_use = {"toolUseId": "test-123", "name": "test_tool", "input": {"param": "value"}}

    result = mcp_agent_tool.invoke(tool_use)

    mock_mcp_client.call_tool_sync.assert_called_once_with(
        tool_use_id="test-123", name="test_tool", arguments={"param": "value"}
    )
    assert result == mock_mcp_client.call_tool_sync.return_value
69 |
--------------------------------------------------------------------------------
/tests/strands/tools/mcp/test_mcp_client.py:
--------------------------------------------------------------------------------
1 | import time
2 | from unittest.mock import AsyncMock, MagicMock, patch
3 |
4 | import pytest
5 | from mcp import ListToolsResult
6 | from mcp.types import CallToolResult as MCPCallToolResult
7 | from mcp.types import TextContent as MCPTextContent
8 | from mcp.types import Tool as MCPTool
9 |
10 | from strands.tools.mcp import MCPClient
11 | from strands.types.exceptions import MCPClientInitializationError
12 |
13 |
@pytest.fixture
def mock_transport():
    """Fake MCP transport: a callable returning an async context manager yielding (read, write) streams."""
    mock_read_stream = AsyncMock()
    mock_write_stream = AsyncMock()
    mock_transport_cm = AsyncMock()
    mock_transport_cm.__aenter__.return_value = (mock_read_stream, mock_write_stream)
    mock_transport_callable = MagicMock(return_value=mock_transport_cm)

    # Expose each piece so tests can assert on / sabotage any layer individually.
    return {
        "read_stream": mock_read_stream,
        "write_stream": mock_write_stream,
        "transport_cm": mock_transport_cm,
        "transport_callable": mock_transport_callable,
    }
28 |
29 |
@pytest.fixture
def mock_session():
    """Patch ClientSession so MCPClient talks to a mock session; yields that session."""
    mock_session = AsyncMock()
    mock_session.initialize = AsyncMock()

    # Create a mock context manager for ClientSession
    mock_session_cm = AsyncMock()
    mock_session_cm.__aenter__.return_value = mock_session

    # Patch ClientSession to return our mock session
    with patch("strands.tools.mcp.mcp_client.ClientSession", return_value=mock_session_cm):
        yield mock_session
42 |
43 |
@pytest.fixture
def mcp_client(mock_transport, mock_session):
    """A started MCPClient wired to the mocked transport and session."""
    with MCPClient(mock_transport["transport_callable"]) as client:
        yield client
48 |
49 |
def test_mcp_client_context_manager(mock_transport, mock_session):
    """Test that the MCPClient context manager properly initializes and cleans up."""
    with MCPClient(mock_transport["transport_callable"]) as client:
        # Entering the context spawns the background thread and resolves the init future.
        assert client._background_thread is not None
        assert client._background_thread.is_alive()
        assert client._init_future.done()

        mock_transport["transport_cm"].__aenter__.assert_called_once()
        mock_session.initialize.assert_called_once()

    # After exiting the context manager, verify that the thread was cleaned up
    # Give a small delay for the thread to fully terminate
    # NOTE(review): a fixed sleep can be flaky on slow machines — joining the
    # thread with a timeout would be more robust.
    time.sleep(0.1)
    assert client._background_thread is None
64 |
65 |
def test_list_tools_sync(mock_transport, mock_session):
    """Test that list_tools_sync correctly retrieves and adapts tools."""
    mock_tool = MCPTool(name="test_tool", description="A test tool", inputSchema={"type": "object", "properties": {}})
    mock_session.list_tools.return_value = ListToolsResult(tools=[mock_tool])

    with MCPClient(mock_transport["transport_callable"]) as client:
        tools = client.list_tools_sync()

        mock_session.list_tools.assert_called_once()

        # Each MCP tool is wrapped in an agent-tool adapter exposing tool_name.
        assert len(tools) == 1
        assert tools[0].tool_name == "test_tool"
78 |
79 |
def test_list_tools_sync_session_not_active():
    """list_tools_sync must fail fast when the client was never started."""
    transport_callable = MagicMock()
    client = MCPClient(transport_callable)

    with pytest.raises(MCPClientInitializationError, match="client.session is not running"):
        client.list_tools_sync()
86 |
87 |
@pytest.mark.parametrize("is_error,expected_status", [(False, "success"), (True, "error")])
def test_call_tool_sync_status(mock_transport, mock_session, is_error, expected_status):
    """Test that call_tool_sync correctly handles success and error results."""
    mock_content = MCPTextContent(type="text", text="Test message")
    mock_session.call_tool.return_value = MCPCallToolResult(isError=is_error, content=[mock_content])

    with MCPClient(mock_transport["transport_callable"]) as client:
        result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"})

        # The third positional argument (None) is forwarded alongside name/arguments.
        mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None)

        # isError maps to the tool-result status; text content passes through.
        assert result["status"] == expected_status
        assert result["toolUseId"] == "test-123"
        assert len(result["content"]) == 1
        assert result["content"][0]["text"] == "Test message"
103 |
104 |
def test_call_tool_sync_session_not_active():
    """call_tool_sync must fail fast when the client was never started."""
    transport_callable = MagicMock()
    client = MCPClient(transport_callable)

    with pytest.raises(MCPClientInitializationError, match="client.session is not running"):
        client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"})
111 |
112 |
def test_call_tool_sync_exception(mock_transport, mock_session):
    """Test that call_tool_sync correctly handles exceptions."""
    mock_session.call_tool.side_effect = Exception("Test exception")

    with MCPClient(mock_transport["transport_callable"]) as client:
        result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"})

        # The exception is converted into an error tool-result rather than raised.
        assert result["status"] == "error"
        assert result["toolUseId"] == "test-123"
        assert len(result["content"]) == 1
        assert "Test exception" in result["content"][0]["text"]
124 |
125 |
def test_enter_with_initialization_exception(mock_transport):
    """Test that start() wraps transport initialization failures in MCPClientInitializationError."""
    # Make the transport context manager fail when it is entered
    mock_transport["transport_cm"].__aenter__.side_effect = Exception("Transport initialization failed")

    client = MCPClient(mock_transport["transport_callable"])

    with pytest.raises(MCPClientInitializationError, match="the client initialization failed"):
        client.start()
135 |
136 |
def test_exception_when_future_not_running():
    """Test exception handling when the future is not running."""
    # Create a client with a mock transport
    mock_transport_callable = MagicMock()
    client = MCPClient(mock_transport_callable)

    # Create a mock future that is not running
    mock_future = MagicMock()
    mock_future.running.return_value = False
    client._init_future = mock_future

    # Create a mock event loop whose run_until_complete raises immediately
    mock_event_loop = MagicMock()
    mock_event_loop.run_until_complete.side_effect = Exception("Test exception")

    # Patch the event loop creation
    with patch("asyncio.new_event_loop", return_value=mock_event_loop):
        # Run the background task which should trigger the exception
        try:
            client._background_task()
        except Exception:
            pass  # We expect an exception to be raised

    # Verify that set_exception was not called since the future was not running
    mock_future.set_exception.assert_not_called()
162 |
--------------------------------------------------------------------------------
/tests/strands/tools/test_registry.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests for the SDK tool registry module.
3 | """
4 |
5 | import pytest
6 |
7 | from strands.tools.registry import ToolRegistry
8 |
9 |
def test_load_tool_from_filepath_failure():
    """A nonexistent tool file surfaces as a ValueError from load_tool_from_filepath."""
    registry = ToolRegistry()
    expected_message = "Failed to load tool failing_tool: Tool file not found: /path/to/failing_tool.py"

    with pytest.raises(ValueError, match=expected_message):
        registry.load_tool_from_filepath("failing_tool", "/path/to/failing_tool.py")
17 |
18 |
def test_process_tools_with_invalid_path():
    """Test that process_tools raises an exception when a non-path string is passed."""
    tool_registry = ToolRegistry()
    invalid_path = "not a filepath"

    # The reported tool name is the portion of the string before the first ".".
    with pytest.raises(ValueError, match=f"Failed to load tool {invalid_path.split('.')[0]}: Tool file not found:.*"):
        tool_registry.process_tools([invalid_path])
26 |
--------------------------------------------------------------------------------
/tests/strands/tools/test_thread_pool_executor.py:
--------------------------------------------------------------------------------
1 | import concurrent
2 |
3 | import pytest
4 |
5 | import strands
6 |
7 |
@pytest.fixture
def thread_pool():
    """A single-worker stdlib thread pool to keep execution order deterministic."""
    return concurrent.futures.ThreadPoolExecutor(max_workers=1)


@pytest.fixture
def thread_pool_wrapper(thread_pool):
    """The strands wrapper around the stdlib pool — the object under test."""
    return strands.tools.ThreadPoolExecutorWrapper(thread_pool)
16 |
17 |
def test_submit(thread_pool_wrapper):
    """submit proxies to the wrapped executor with positional and keyword args."""

    def echo_pair(a, b):
        return (a, b)

    future = thread_pool_wrapper.submit(echo_pair, 1, b=2)

    assert future.result() == (1, 2)
28 |
29 |
def test_as_completed(thread_pool_wrapper):
    """as_completed yields every submitted future's result (order not guaranteed)."""

    def identity(value):
        return value

    futures = [thread_pool_wrapper.submit(identity, value) for value in range(2)]
    results = sorted(f.result() for f in thread_pool_wrapper.as_completed(futures))

    assert results == [0, 1]
40 |
41 |
def test_shutdown(thread_pool_wrapper):
    """After shutdown, further submissions are rejected by the underlying pool."""

    def noop():
        return None

    thread_pool_wrapper.shutdown()

    with pytest.raises(RuntimeError):
        thread_pool_wrapper.submit(noop)
47 |
--------------------------------------------------------------------------------
/tests/strands/tools/test_watcher.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests for the SDK tool watcher module.
3 | """
4 |
5 | from unittest.mock import MagicMock, patch
6 |
7 | import pytest
8 |
9 | from strands.tools.registry import ToolRegistry
10 | from strands.tools.watcher import ToolWatcher
11 |
12 |
def test_tool_watcher_initialization():
    """The watcher keeps a reference to the registry it was constructed with."""
    registry = ToolRegistry()

    watcher = ToolWatcher(registry)

    assert watcher.tool_registry == registry
18 |
19 |
@pytest.mark.parametrize(
    "test_case",
    [
        # Regular Python file - should reload
        {
            "description": "Python file",
            "src_path": "/path/to/test_tool.py",
            "is_directory": False,
            "should_reload": True,
            "expected_tool_name": "test_tool",
        },
        # Non-Python file - should not reload
        {
            "description": "Non-Python file",
            "src_path": "/path/to/test_tool.txt",
            "is_directory": False,
            "should_reload": False,
        },
        # __init__.py file - should not reload
        {
            "description": "Init file",
            "src_path": "/path/to/__init__.py",
            "is_directory": False,
            "should_reload": False,
        },
        # Directory path - should not reload
        {
            "description": "Directory path",
            "src_path": "/path/to/tools_directory",
            "is_directory": True,
            "should_reload": False,
        },
        # Python file marked as directory - should still reload
        {
            "description": "Python file marked as directory",
            "src_path": "/path/to/test_tool2.py",
            "is_directory": True,
            "should_reload": True,
            "expected_tool_name": "test_tool2",
        },
    ],
)
@patch.object(ToolRegistry, "reload_tool")
def test_on_modified_cases(mock_reload_tool, test_case):
    """Test various cases for the on_modified method."""
    tool_registry = ToolRegistry()
    watcher = ToolWatcher(tool_registry)

    # Create a mock event with the specified properties
    event = MagicMock()
    event.src_path = test_case["src_path"]
    if "is_directory" in test_case:
        event.is_directory = test_case["is_directory"]

    # Call the on_modified method
    watcher.tool_change_handler.on_modified(event)

    # Verify the expected behavior
    # (expected_tool_name is the file's stem, per the cases above)
    if test_case["should_reload"]:
        mock_reload_tool.assert_called_once_with(test_case["expected_tool_name"])
    else:
        mock_reload_tool.assert_not_called()
82 |
83 |
@patch.object(ToolRegistry, "reload_tool", side_effect=Exception("Test error"))
def test_on_modified_error_handling(mock_reload_tool):
    """Test that on_modified handles errors during tool reloading."""
    tool_registry = ToolRegistry()
    watcher = ToolWatcher(tool_registry)

    # Create a mock event with a Python file path
    event = MagicMock()
    event.src_path = "/path/to/test_tool.py"

    # Call the on_modified method - should not raise an exception
    # (the handler is expected to swallow the reload error)
    watcher.tool_change_handler.on_modified(event)

    # Verify that reload_tool was called
    mock_reload_tool.assert_called_once_with("test_tool")
99 |
--------------------------------------------------------------------------------
/tests/strands/types/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/strands-agents/sdk-python/a64e80dfa5f5f6053a3fe53767f59fc3c5c1af95/tests/strands/types/models/__init__.py
--------------------------------------------------------------------------------
/tests/strands/types/models/test_model.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from strands.types.models import Model as SAModel
4 |
5 |
class TestModel(SAModel):
    """Minimal concrete Model used to exercise the abstract base's converse flow."""

    def update_config(self, **model_config):
        """Echo the supplied config back; no state is kept."""
        return model_config

    def get_config(self):
        """Return None; configuration is irrelevant for these tests."""
        return

    def format_request(self, messages, tool_specs, system_prompt):
        """Package the inputs into a dict so tests can observe the exact request."""
        return {
            "messages": messages,
            "tool_specs": tool_specs,
            "system_prompt": system_prompt,
        }

    def format_chunk(self, event):
        """Wrap each raw event so chunk formatting is observable in the output."""
        return {"event": event}

    def stream(self, request):
        """Yield a single event echoing the request."""
        yield {"request": request}
25 |
26 |
@pytest.fixture
def model():
    """The stub model under test."""
    return TestModel()


@pytest.fixture
def messages():
    """A minimal single-turn user message list."""
    return [
        {
            "role": "user",
            "content": [{"text": "hello"}],
        },
    ]


@pytest.fixture
def tool_specs():
    """One tool spec with a single required string input."""
    return [
        {
            "name": "test_tool",
            "description": "A test tool",
            "inputSchema": {
                "json": {
                    "type": "object",
                    "properties": {
                        "input": {"type": "string"},
                    },
                    "required": ["input"],
                },
            },
        },
    ]


@pytest.fixture
def system_prompt():
    """Placeholder system prompt."""
    return "s1"
64 |
65 |
def test_converse(model, messages, tool_specs, system_prompt):
    """converse should format the request, stream it, and format each chunk."""
    response = model.converse(messages, tool_specs, system_prompt)

    tru_events = list(response)
    # The stub wraps the formatted request once in stream() ("request") and once
    # in format_chunk() ("event"), producing exactly one nested event.
    exp_events = [
        {
            "event": {
                "request": {
                    "messages": messages,
                    "tool_specs": tool_specs,
                    "system_prompt": system_prompt,
                },
            },
        },
    ]
    assert tru_events == exp_events
82 |
--------------------------------------------------------------------------------