├── .github ├── CODEOWNERS ├── ISSUE_TEMPLATE │ ├── bug_report.yml │ ├── config.yml │ └── feature_request.yml ├── PULL_REQUEST_TEMPLATE.md ├── dependabot.yml └── workflows │ ├── pr-and-push.yml │ ├── pypi-publish-on-release.yml │ └── test-lint.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CONTRIBUTING.md ├── LICENSE ├── NOTICE ├── README.md ├── STYLE_GUIDE.md ├── pyproject.toml ├── src └── strands │ ├── __init__.py │ ├── agent │ ├── __init__.py │ ├── agent.py │ ├── agent_result.py │ └── conversation_manager │ │ ├── __init__.py │ │ ├── conversation_manager.py │ │ ├── null_conversation_manager.py │ │ └── sliding_window_conversation_manager.py │ ├── event_loop │ ├── __init__.py │ ├── error_handler.py │ ├── event_loop.py │ ├── message_processor.py │ └── streaming.py │ ├── handlers │ ├── __init__.py │ ├── callback_handler.py │ └── tool_handler.py │ ├── models │ ├── __init__.py │ ├── anthropic.py │ ├── bedrock.py │ ├── litellm.py │ ├── llamaapi.py │ ├── ollama.py │ └── openai.py │ ├── py.typed │ ├── telemetry │ ├── __init__.py │ ├── metrics.py │ └── tracer.py │ ├── tools │ ├── __init__.py │ ├── decorator.py │ ├── executor.py │ ├── loader.py │ ├── mcp │ │ ├── __init__.py │ │ ├── mcp_agent_tool.py │ │ ├── mcp_client.py │ │ └── mcp_types.py │ ├── registry.py │ ├── thread_pool_executor.py │ ├── tools.py │ └── watcher.py │ └── types │ ├── __init__.py │ ├── content.py │ ├── event_loop.py │ ├── exceptions.py │ ├── guardrails.py │ ├── media.py │ ├── models │ ├── __init__.py │ ├── model.py │ └── openai.py │ ├── streaming.py │ ├── tools.py │ └── traces.py ├── tests-integ ├── __init__.py ├── echo_server.py ├── test_bedrock_cache_point.py ├── test_bedrock_guardrails.py ├── test_context_overflow.py ├── test_function_tools.py ├── test_hot_tool_reload_decorator.py ├── test_image.png ├── test_mcp_client.py ├── test_model_anthropic.py ├── test_model_bedrock.py ├── test_model_litellm.py ├── test_model_llamaapi.py ├── test_model_openai.py └── test_stream_agent.py └── tests ├── 
__init__.py ├── conftest.py └── strands ├── __init__.py ├── agent ├── __init__.py ├── test_agent.py ├── test_agent_result.py └── test_conversation_manager.py ├── event_loop ├── __init__.py ├── test_error_handler.py ├── test_event_loop.py ├── test_message_processor.py └── test_streaming.py ├── handlers ├── __init__.py ├── test_callback_handler.py └── test_tool_handler.py ├── models ├── __init__.py ├── test_anthropic.py ├── test_bedrock.py ├── test_litellm.py ├── test_llamaapi.py ├── test_ollama.py └── test_openai.py ├── telemetry ├── test_metrics.py └── test_tracer.py ├── tools ├── __init__.py ├── mcp │ ├── __init__.py │ ├── test_mcp_agent_tool.py │ └── test_mcp_client.py ├── test_decorator.py ├── test_executor.py ├── test_loader.py ├── test_registry.py ├── test_thread_pool_executor.py ├── test_tools.py └── test_watcher.py └── types └── models ├── __init__.py ├── test_model.py └── test_openai.py /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | # These owners will be the default owners for everything in 2 | # the repo. Unless a later match takes precedence, 3 | # @strands-agents/contributors will be requested for 4 | # review when someone opens a pull request. 5 | * @strands-agents/maintainers -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.yml: -------------------------------------------------------------------------------- 1 | name: Bug Report 2 | description: Report a bug in the Strands Agents SDK 3 | title: "[BUG] " 4 | labels: ["bug", "triage"] 5 | assignees: [] 6 | body: 7 | - type: markdown 8 | attributes: 9 | value: | 10 | Thanks for taking the time to fill out this bug report for Strands SDK! 
11 | - type: checkboxes 12 | id: "checks" 13 | attributes: 14 | label: "Checks" 15 | options: 16 | - label: "I have updated to the lastest minor and patch version of Strands" 17 | required: true 18 | - label: "I have checked the documentation and this is not expected behavior" 19 | required: true 20 | - label: "I have searched [./issues](./issues?q=) and there are no duplicates of my issue" 21 | required: true 22 | - type: input 23 | id: strands-version 24 | attributes: 25 | label: Strands Version 26 | description: Which version of Strands are you using? 27 | placeholder: e.g., 0.5.2 28 | validations: 29 | required: true 30 | - type: input 31 | id: python-version 32 | attributes: 33 | label: Python Version 34 | description: Which version of Python are you using? 35 | placeholder: e.g., 3.10.5 36 | validations: 37 | required: true 38 | - type: input 39 | id: os 40 | attributes: 41 | label: Operating System 42 | description: Which operating system are you using? 43 | placeholder: e.g., macOS 12.6 44 | validations: 45 | required: true 46 | - type: dropdown 47 | id: installation-method 48 | attributes: 49 | label: Installation Method 50 | description: How did you install Strands? 51 | options: 52 | - pip 53 | - git clone 54 | - binary 55 | - other 56 | validations: 57 | required: true 58 | - type: textarea 59 | id: steps-to-reproduce 60 | attributes: 61 | label: Steps to Reproduce 62 | description: Detailed steps to reproduce the behavior 63 | placeholder: | 64 | 1. Install Strands using... 65 | 2. Run the command... 66 | 3. See error... 
67 | validations: 68 | required: true 69 | - type: textarea 70 | id: expected-behavior 71 | attributes: 72 | label: Expected Behavior 73 | description: A clear description of what you expected to happen 74 | validations: 75 | required: true 76 | - type: textarea 77 | id: actual-behavior 78 | attributes: 79 | label: Actual Behavior 80 | description: What actually happened 81 | validations: 82 | required: true 83 | - type: textarea 84 | id: additional-context 85 | attributes: 86 | label: Additional Context 87 | description: Any other relevant information, logs, screenshots, etc. 88 | - type: textarea 89 | id: possible-solution 90 | attributes: 91 | label: Possible Solution 92 | description: Optional - If you have suggestions on how to fix the bug 93 | - type: input 94 | id: related-issues 95 | attributes: 96 | label: Related Issues 97 | description: Optional - Link to related issues if applicable 98 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false 2 | contact_links: 3 | - name: Strands Agents SDK Support 4 | url: https://github.com/strands-agents/sdk-python/discussions 5 | about: Please ask and answer questions here 6 | - name: Strands Agents SDK Documentation 7 | url: https://github.com/strands-agents/docs 8 | about: Visit our documentation for help 9 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.yml: -------------------------------------------------------------------------------- 1 | name: Feature Request 2 | description: Suggest a new feature or enhancement for Strands Agents SDK 3 | title: "[FEATURE] " 4 | labels: ["enhancement", "triage"] 5 | assignees: [] 6 | body: 7 | - type: markdown 8 | attributes: 9 | value: | 10 | Thanks for suggesting a new feature for Strands Agents SDK! 
11 | - type: textarea 12 | id: problem-statement 13 | attributes: 14 | label: Problem Statement 15 | description: Describe the problem you're trying to solve. What is currently difficult or impossible to do? 16 | placeholder: I would like Strands to... 17 | validations: 18 | required: true 19 | - type: textarea 20 | id: proposed-solution 21 | attributes: 22 | label: Proposed Solution 23 | description: Optional - Describe your proposed solution in detail. How would this feature work? 24 | - type: textarea 25 | id: use-case 26 | attributes: 27 | label: Use Case 28 | description: Provide specific use cases for the feature. How would people use it? 29 | placeholder: This would help with... 30 | validations: 31 | required: true 32 | - type: textarea 33 | id: alternatives-solutions 34 | attributes: 35 | label: Alternatives Solutions 36 | description: Optional - Have you considered alternative approaches? What are their pros and cons? 37 | - type: textarea 38 | id: additional-context 39 | attributes: 40 | label: Additional Context 41 | description: Include any other context, screenshots, code examples, or references that might help understand the feature request. 42 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## Description 2 | [Provide a detailed description of the changes in this PR] 3 | 4 | ## Related Issues 5 | [Link to related issues using #issue-number format] 6 | 7 | ## Documentation PR 8 | [Link to related associated PR in the agent-docs repo] 9 | 10 | ## Type of Change 11 | - Bug fix 12 | - New feature 13 | - Breaking change 14 | - Documentation update 15 | - Other (please describe): 16 | 17 | [Choose one of the above types of changes] 18 | 19 | 20 | ## Testing 21 | [How have you tested the change?] 
22 | 23 | * `hatch fmt --linter` 24 | * `hatch fmt --formatter` 25 | * `hatch test --all` 26 | * Verify that the changes do not break functionality or introduce warnings in consuming repositories: agents-docs, agents-tools, agents-cli 27 | 28 | 29 | ## Checklist 30 | - [ ] I have read the CONTRIBUTING document 31 | - [ ] I have added tests that prove my fix is effective or my feature works 32 | - [ ] I have updated the documentation accordingly 33 | - [ ] I have added an appropriate example to the documentation to outline the feature 34 | - [ ] My changes generate no new warnings 35 | - [ ] Any dependent changes have been merged and published 36 | 37 | By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. 38 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "pip" 4 | directory: "/" 5 | schedule: 6 | interval: "daily" 7 | open-pull-requests-limit: 100 8 | commit-message: 9 | prefix: ci 10 | groups: 11 | dev-dependencies: 12 | patterns: 13 | - "pytest" 14 | - package-ecosystem: "github-actions" 15 | directory: "/" 16 | schedule: 17 | interval: "daily" 18 | open-pull-requests-limit: 100 19 | commit-message: 20 | prefix: ci 21 | -------------------------------------------------------------------------------- /.github/workflows/pr-and-push.yml: -------------------------------------------------------------------------------- 1 | name: Pull Request and Push Action 2 | 3 | on: 4 | pull_request: # Safer than pull_request_target for untrusted code 5 | branches: [ main ] 6 | types: [opened, synchronize, reopened, ready_for_review, review_requested, review_request_removed] 7 | push: 8 | branches: [ main ] # Also run on direct pushes to main 9 | concurrency: 10 | group: ${{ github.workflow }}-${{ 
github.event.pull_request.number || github.ref }} 11 | cancel-in-progress: true 12 | 13 | jobs: 14 | call-test-lint: 15 | uses: ./.github/workflows/test-lint.yml 16 | permissions: 17 | contents: read 18 | with: 19 | ref: ${{ github.event.pull_request.head.sha }} 20 | -------------------------------------------------------------------------------- /.github/workflows/pypi-publish-on-release.yml: -------------------------------------------------------------------------------- 1 | name: Publish Python Package 2 | 3 | on: 4 | release: 5 | types: 6 | - published 7 | 8 | jobs: 9 | call-test-lint: 10 | uses: ./.github/workflows/test-lint.yml 11 | permissions: 12 | contents: read 13 | with: 14 | ref: ${{ github.event.release.target_commitish }} 15 | 16 | build: 17 | name: Build distribution 📦 18 | permissions: 19 | contents: read 20 | needs: 21 | - call-test-lint 22 | runs-on: ubuntu-latest 23 | 24 | steps: 25 | - uses: actions/checkout@v4 26 | with: 27 | persist-credentials: false 28 | 29 | - name: Set up Python 30 | uses: actions/setup-python@v5 31 | with: 32 | python-version: '3.10' 33 | 34 | - name: Install dependencies 35 | run: | 36 | python -m pip install --upgrade pip 37 | pip install hatch twine 38 | 39 | - name: Validate version 40 | run: | 41 | version=$(hatch version) 42 | if [[ $version =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]]; then 43 | echo "Valid version format" 44 | exit 0 45 | else 46 | echo "Invalid version format" 47 | exit 1 48 | fi 49 | 50 | - name: Build 51 | run: | 52 | hatch build 53 | 54 | - name: Store the distribution packages 55 | uses: actions/upload-artifact@v4 56 | with: 57 | name: python-package-distributions 58 | path: dist/ 59 | 60 | deploy: 61 | name: Upload release to PyPI 62 | needs: 63 | - build 64 | runs-on: ubuntu-latest 65 | 66 | # environment is used by PyPI Trusted Publisher and is strongly encouraged 67 | # https://docs.pypi.org/trusted-publishers/adding-a-publisher/ 68 | environment: 69 | name: pypi 70 | url: 
https://pypi.org/p/strands-agents 71 | permissions: 72 | # IMPORTANT: this permission is mandatory for Trusted Publishing 73 | id-token: write 74 | 75 | steps: 76 | - name: Download all the dists 77 | uses: actions/download-artifact@v4 78 | with: 79 | name: python-package-distributions 80 | path: dist/ 81 | - name: Publish distribution 📦 to PyPI 82 | uses: pypa/gh-action-pypi-publish@release/v1 83 | -------------------------------------------------------------------------------- /.github/workflows/test-lint.yml: -------------------------------------------------------------------------------- 1 | name: Test and Lint 2 | 3 | on: 4 | workflow_call: 5 | inputs: 6 | ref: 7 | required: true 8 | type: string 9 | 10 | jobs: 11 | unit-test: 12 | name: Unit Tests - Python ${{ matrix.python-version }} - ${{ matrix.os-name }} 13 | permissions: 14 | contents: read 15 | strategy: 16 | matrix: 17 | include: 18 | # Linux 19 | - os: ubuntu-latest 20 | os-name: 'linux' 21 | python-version: "3.10" 22 | - os: ubuntu-latest 23 | os-name: 'linux' 24 | python-version: "3.11" 25 | - os: ubuntu-latest 26 | os-name: 'linux' 27 | python-version: "3.12" 28 | - os: ubuntu-latest 29 | os-name: 'linux' 30 | python-version: "3.13" 31 | # Windows 32 | - os: windows-latest 33 | os-name: 'windows' 34 | python-version: "3.10" 35 | - os: windows-latest 36 | os-name: 'windows' 37 | python-version: "3.11" 38 | - os: windows-latest 39 | os-name: 'windows' 40 | python-version: "3.12" 41 | - os: windows-latest 42 | os-name: 'windows' 43 | python-version: "3.13" 44 | # MacOS - latest only; not enough runners for macOS 45 | - os: macos-latest 46 | os-name: 'macOS' 47 | python-version: "3.13" 48 | fail-fast: true 49 | runs-on: ${{ matrix.os }} 50 | env: 51 | LOG_LEVEL: DEBUG 52 | steps: 53 | - name: Checkout code 54 | uses: actions/checkout@v4 55 | with: 56 | ref: ${{ inputs.ref }} # Explicitly define which commit to check out 57 | persist-credentials: false # Don't persist credentials for subsequent actions 
58 | - name: Set up Python 59 | uses: actions/setup-python@v5 60 | with: 61 | python-version: ${{ matrix.python-version }} 62 | - name: Install dependencies 63 | run: | 64 | pip install --no-cache-dir hatch 65 | - name: Run Unit tests 66 | id: tests 67 | run: hatch test tests --cover 68 | continue-on-error: false 69 | lint: 70 | name: Lint 71 | runs-on: ubuntu-latest 72 | permissions: 73 | contents: read 74 | steps: 75 | - name: Checkout code 76 | uses: actions/checkout@v4 77 | with: 78 | ref: ${{ inputs.ref }} 79 | persist-credentials: false 80 | 81 | - name: Set up Python 82 | uses: actions/setup-python@v5 83 | with: 84 | python-version: '3.10' 85 | cache: 'pip' 86 | 87 | - name: Install dependencies 88 | run: | 89 | pip install --no-cache-dir hatch 90 | 91 | - name: Run lint 92 | id: lint 93 | run: hatch run test-lint 94 | continue-on-error: false 95 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | build 2 | __pycache__* 3 | .coverage* 4 | .env 5 | .venv 6 | .mypy_cache 7 | .pytest_cache 8 | .ruff_cache 9 | *.bak 10 | .vscode 11 | dist -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: local 3 | hooks: 4 | - id: hatch-format 5 | name: Format code 6 | entry: hatch fmt --formatter 7 | language: system 8 | pass_filenames: false 9 | types: [python] 10 | stages: [pre-commit] 11 | - id: hatch-lint 12 | name: Lint code 13 | entry: hatch run test-lint 14 | language: system 15 | pass_filenames: false 16 | types: [python] 17 | stages: [pre-commit] 18 | - id: hatch-test-lint 19 | name: Type linting 20 | entry: hatch run test-lint 21 | language: system 22 | pass_filenames: false 23 | types: [ python ] 24 | stages: [ pre-commit ] 25 | - id: hatch-test 26 | name: Unit tests 27 | entry: 
hatch test 28 | language: system 29 | pass_filenames: false 30 | types: [python] 31 | stages: [pre-commit] 32 | - id: commitizen-check 33 | name: Check commit message 34 | entry: hatch run cz check --commit-msg-file 35 | language: system 36 | stages: [commit-msg] -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the [Bug Reports](../../issues/new?template=bug_report.yml) file to report bugs or [Feature Requests](../../issues/new?template=feature_request.yml) to suggest features. 13 | 14 | For a list of known bugs and feature requests: 15 | - Check [Bug Reports](../../issues?q=is%3Aissue%20state%3Aopen%20label%3Abug) for currently tracked issues 16 | - See [Feature Requests](../../issues?q=is%3Aissue%20state%3Aopen%20label%3Aenhancement) for requested enhancements 17 | 18 | When filing an issue, please check for already tracked items 19 | 20 | Please try to include as much information as you can. 
Details like these are incredibly useful: 21 | 22 | * A reproducible test case or series of steps 23 | * The version of our code being used (commit ID) 24 | * Any modifications you've made relevant to the bug 25 | * Anything unusual about your environment or deployment 26 | 27 | 28 | ## Development Environment 29 | 30 | This project uses [hatchling](https://hatch.pypa.io/latest/build/#hatchling) as the build backend and [hatch](https://hatch.pypa.io/latest/) for development workflow management. 31 | 32 | ### Setting Up Your Development Environment 33 | 34 | 1. Entering virtual environment using `hatch` (recommended), then launch your IDE in the new shell. 35 | ```bash 36 | hatch shell dev 37 | ``` 38 | 39 | Alternatively, install development dependencies in a manually created virtual environment: 40 | ```bash 41 | pip install -e ".[dev]" && pip install -e ".[litellm]" 42 | ``` 43 | 44 | 45 | 2. Set up pre-commit hooks: 46 | ```bash 47 | pre-commit install -t pre-commit -t commit-msg 48 | ``` 49 | This will automatically run formatters and conventional commit checks on your code before each commit. 50 | 51 | 3. Run code formatters manually: 52 | ```bash 53 | hatch fmt --formatter 54 | ``` 55 | 56 | 4. Run linters: 57 | ```bash 58 | hatch fmt --linter 59 | ``` 60 | 61 | 5. Run unit tests: 62 | ```bash 63 | hatch test 64 | ``` 65 | 66 | 6. Run integration tests: 67 | ```bash 68 | hatch run test-integ 69 | ``` 70 | 71 | ### Pre-commit Hooks 72 | 73 | We use [pre-commit](https://pre-commit.com/) to automatically run quality checks before each commit. The hook will run `hatch run format`, `hatch run lint`, `hatch run test`, and `hatch run cz check` on when you make a commit, ensuring code consistency. 
74 | 75 | The pre-commit hook is installed with: 76 | 77 | ```bash 78 | pre-commit install 79 | ``` 80 | 81 | You can also run the hooks manually on all files: 82 | 83 | ```bash 84 | pre-commit run --all-files 85 | ``` 86 | 87 | ### Code Formatting and Style Guidelines 88 | 89 | We use the following tools to ensure code quality: 90 | 1. **ruff** - For formatting and linting 91 | 2. **mypy** - For static type checking 92 | 93 | These tools are configured in the [pyproject.toml](./pyproject.toml) file. Please ensure your code passes all linting and type checks before submitting a pull request: 94 | 95 | ```bash 96 | # Run all checks 97 | hatch fmt --formatter 98 | hatch fmt --linter 99 | ``` 100 | 101 | If you're using an IDE like VS Code or PyCharm, consider configuring it to use these tools automatically. 102 | 103 | For additional details on styling, please see our dedicated [Style Guide](./STYLE_GUIDE.md). 104 | 105 | 106 | ## Contributing via Pull Requests 107 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 108 | 109 | 1. You are working against the latest source on the *main* branch. 110 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 111 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 112 | 113 | To send us a pull request, please: 114 | 115 | 1. Create a branch. 116 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 117 | 3. Format your code using `hatch fmt --formatter`. 118 | 4. Run linting checks with `hatch fmt --linter`. 119 | 5. Ensure local tests pass with `hatch test` and `hatch run test-integ`. 120 | 6. Commit to your branch using clear commit messages following the [Conventional Commits](https://www.conventionalcommits.org) specification. 121 | 7. 
Send us a pull request, answering any default questions in the pull request interface. 122 | 8. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 123 | 124 | 125 | ## Finding contributions to work on 126 | Looking at the existing issues is a great way to find something to contribute to. 127 | 128 | You can check: 129 | - Our known bugs list in [Bug Reports](../../issues?q=is%3Aissue%20state%3Aopen%20label%3Abug) for issues that need fixing 130 | - Feature requests in [Feature Requests](../../issues?q=is%3Aissue%20state%3Aopen%20label%3Aenhancement) for new functionality to implement 131 | 132 | 133 | ## Code of Conduct 134 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 135 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 136 | opensource-codeofconduct@amazon.com with any additional questions or comments. 137 | 138 | 139 | ## Security issue notifications 140 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue. 141 | 142 | 143 | ## Licensing 144 | 145 | See the [LICENSE](./LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 146 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 |
3 | 4 | Strands Agents 5 | 6 |
7 | 8 |

9 | Strands Agents 10 |

11 | 12 |

13 | A model-driven approach to building AI agents in just a few lines of code. 14 |

15 | 16 |
17 | GitHub commit activity 18 | GitHub open issues 19 | GitHub open pull requests 20 | License 21 | PyPI version 22 | Python versions 23 |
24 | 25 |

26 | Documentation 27 | ◆ Samples 28 | ◆ Python SDK 29 | ◆ Tools 30 | ◆ Agent Builder 31 | ◆ MCP Server 32 |

33 |
34 | 35 | Strands Agents is a simple yet powerful SDK that takes a model-driven approach to building and running AI agents. From simple conversational assistants to complex autonomous workflows, from local development to production deployment, Strands Agents scales with your needs. 36 | 37 | ## Feature Overview 38 | 39 | - **Lightweight & Flexible**: Simple agent loop that just works and is fully customizable 40 | - **Model Agnostic**: Support for Amazon Bedrock, Anthropic, LiteLLM, Llama, Ollama, OpenAI, and custom providers 41 | - **Advanced Capabilities**: Multi-agent systems, autonomous agents, and streaming support 42 | - **Built-in MCP**: Native support for Model Context Protocol (MCP) servers, enabling access to thousands of pre-built tools 43 | 44 | ## Quick Start 45 | 46 | ```bash 47 | # Install Strands Agents 48 | pip install strands-agents strands-agents-tools 49 | ``` 50 | 51 | ```python 52 | from strands import Agent 53 | from strands_tools import calculator 54 | agent = Agent(tools=[calculator]) 55 | agent("What is the square root of 1764") 56 | ``` 57 | 58 | > **Note**: For the default Amazon Bedrock model provider, you'll need AWS credentials configured and model access enabled for Claude 3.7 Sonnet in the us-west-2 region. See the [Quickstart Guide](https://strandsagents.com/) for details on configuring other model providers. 59 | 60 | ## Installation 61 | 62 | Ensure you have Python 3.10+ installed, then: 63 | 64 | ```bash 65 | # Create and activate virtual environment 66 | python -m venv .venv 67 | source .venv/bin/activate # On Windows use: .venv\Scripts\activate 68 | 69 | # Install Strands and tools 70 | pip install strands-agents strands-agents-tools 71 | ``` 72 | 73 | ## Features at a Glance 74 | 75 | ### Python-Based Tools 76 | 77 | Easily build tools using Python decorators: 78 | 79 | ```python 80 | from strands import Agent, tool 81 | 82 | @tool 83 | def word_count(text: str) -> int: 84 | """Count words in text. 
85 | 86 | This docstring is used by the LLM to understand the tool's purpose. 87 | """ 88 | return len(text.split()) 89 | 90 | agent = Agent(tools=[word_count]) 91 | response = agent("How many words are in this sentence?") 92 | ``` 93 | 94 | ### MCP Support 95 | 96 | Seamlessly integrate Model Context Protocol (MCP) servers: 97 | 98 | ```python 99 | from strands import Agent 100 | from strands.tools.mcp import MCPClient 101 | from mcp import stdio_client, StdioServerParameters 102 | 103 | aws_docs_client = MCPClient( 104 | lambda: stdio_client(StdioServerParameters(command="uvx", args=["awslabs.aws-documentation-mcp-server@latest"])) 105 | ) 106 | 107 | with aws_docs_client: 108 | agent = Agent(tools=aws_docs_client.list_tools_sync()) 109 | response = agent("Tell me about Amazon Bedrock and how to use it with Python") 110 | ``` 111 | 112 | ### Multiple Model Providers 113 | 114 | Support for various model providers: 115 | 116 | ```python 117 | from strands import Agent 118 | from strands.models import BedrockModel 119 | from strands.models.ollama import OllamaModel 120 | from strands.models.llamaapi import LlamaAPIModel 121 | 122 | # Bedrock 123 | bedrock_model = BedrockModel( 124 | model_id="us.amazon.nova-pro-v1:0", 125 | temperature=0.3, 126 | streaming=True, # Enable/disable streaming 127 | ) 128 | agent = Agent(model=bedrock_model) 129 | agent("Tell me about Agentic AI") 130 | 131 | # Ollama 132 | ollama_model = OllamaModel( 133 | host="http://localhost:11434", 134 | model_id="llama3" 135 | ) 136 | agent = Agent(model=ollama_model) 137 | agent("Tell me about Agentic AI") 138 | 139 | # Llama API 140 | llama_model = LlamaAPIModel( 141 | model_id="Llama-4-Maverick-17B-128E-Instruct-FP8", 142 | ) 143 | agent = Agent(model=llama_model) 144 | response = agent("Tell me about Agentic AI") 145 | ``` 146 | 147 | Built-in providers: 148 | - [Amazon Bedrock](https://strandsagents.com/latest/user-guide/concepts/model-providers/amazon-bedrock/) 149 | - 
[Anthropic](https://strandsagents.com/latest/user-guide/concepts/model-providers/anthropic/) 150 | - [LiteLLM](https://strandsagents.com/latest/user-guide/concepts/model-providers/litellm/) 151 | - [LlamaAPI](https://strandsagents.com/latest/user-guide/concepts/model-providers/llamaapi/) 152 | - [Ollama](https://strandsagents.com/latest/user-guide/concepts/model-providers/ollama/) 153 | - [OpenAI](https://strandsagents.com/latest/user-guide/concepts/model-providers/openai/) 154 | 155 | Custom providers can be implemented using [Custom Providers](https://strandsagents.com/latest/user-guide/concepts/model-providers/custom_model_provider/) 156 | 157 | ### Example tools 158 | 159 | Strands offers an optional strands-agents-tools package with pre-built tools for quick experimentation: 160 | 161 | ```python 162 | from strands import Agent 163 | from strands_tools import calculator 164 | agent = Agent(tools=[calculator]) 165 | agent("What is the square root of 1764") 166 | ``` 167 | 168 | It's also available on GitHub via [strands-agents/tools](https://github.com/strands-agents/tools). 169 | 170 | ## Documentation 171 | 172 | For detailed guidance & examples, explore our documentation: 173 | 174 | - [User Guide](https://strandsagents.com/) 175 | - [Quick Start Guide](https://strandsagents.com/latest/user-guide/quickstart/) 176 | - [Agent Loop](https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/) 177 | - [Examples](https://strandsagents.com/latest/examples/) 178 | - [API Reference](https://strandsagents.com/latest/api-reference/agent/) 179 | - [Production & Deployment Guide](https://strandsagents.com/latest/user-guide/deploy/operating-agents-in-production/) 180 | 181 | ## Contributing ❤️ 182 | 183 | We welcome contributions! 
See our [Contributing Guide](CONTRIBUTING.md) for details on: 184 | - Reporting bugs & features 185 | - Development setup 186 | - Contributing via Pull Requests 187 | - Code of Conduct 188 | - Reporting of security issues 189 | 190 | ## License 191 | 192 | This project is licensed under the Apache License 2.0 - see the [LICENSE](LICENSE) file for details. 193 | 194 | ## Security 195 | 196 | See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information. 197 | 198 | ## ⚠️ Preview Status 199 | 200 | Strands Agents is currently in public preview. During this period: 201 | - APIs may change as we refine the SDK 202 | - We welcome feedback and contributions 203 | -------------------------------------------------------------------------------- /STYLE_GUIDE.md: -------------------------------------------------------------------------------- 1 | # Style Guide 2 | 3 | ## Overview 4 | 5 | The Strands Agents style guide aims to establish consistent formatting, naming conventions, and structure across all code in the repository. We strive to make our code clean, readable, and maintainable. 6 | 7 | Where possible, we will codify these style guidelines into our linting rules and pre-commit hooks to automate enforcement and reduce the manual review burden. 8 | 9 | ## Log Formatting 10 | 11 | The format for Strands Agents logs is as follows: 12 | 13 | ```python 14 | logger.debug("field1=<%s>, field2=<%s>, ... | human readable message", field1, field2, ...) 15 | ``` 16 | 17 | ### Guidelines 18 | 19 | 1. **Context**: 20 | - Add context as `=` pairs at the beginning of the log 21 | - Many log services (CloudWatch, Splunk, etc.) 
look for these patterns to extract fields for searching 22 | - Use `,`'s to separate pairs 23 | - Enclose values in `<>` for readability 24 | - This is particularly helpful in displaying empty values (`field=` vs `field=<>`) 25 | - Use `%s` for string interpolation as recommended by Python logging 26 | - This is an optimization to skip string interpolation when the log level is not enabled 27 | 28 | 1. **Messages**: 29 | - Add human-readable messages at the end of the log 30 | - Use lowercase for consistency 31 | - Avoid punctuation (periods, exclamation points, etc.) to reduce clutter 32 | - Keep messages concise and focused on a single statement 33 | - If multiple statements are needed, separate them with the pipe character (`|`) 34 | - Example: `"processing request | starting validation"` 35 | 36 | ### Examples 37 | 38 | #### Good 39 | 40 | ```python 41 | logger.debug("user_id=<%s>, action=<%s> | user performed action", user_id, action) 42 | logger.info("request_id=<%s>, duration_ms=<%d> | request completed", request_id, duration) 43 | logger.warning("attempt=<%d>, max_attempts=<%d> | retry limit approaching", attempt, max_attempts) 44 | ``` 45 | 46 | #### Poor 47 | 48 | ```python 49 | # Avoid: No structured fields, direct variable interpolation in message 50 | logger.debug(f"User {user_id} performed action {action}") 51 | 52 | # Avoid: Inconsistent formatting, punctuation 53 | logger.info("Request completed in %d ms.", duration) 54 | 55 | # Avoid: No separation between fields and message 56 | logger.warning("Retry limit approaching! attempt=%d max_attempts=%d", attempt, max_attempts) 57 | ``` 58 | 59 | By following these log formatting guidelines, we ensure that logs are both human-readable and machine-parseable, making debugging and monitoring more efficient. 
60 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling", "hatch-vcs"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | name = "strands-agents" 7 | dynamic = ["version"] 8 | description = "A model-driven approach to building AI agents in just a few lines of code" 9 | readme = "README.md" 10 | requires-python = ">=3.10" 11 | license = {text = "Apache-2.0"} 12 | authors = [ 13 | {name = "AWS", email = "opensource@amazon.com"}, 14 | ] 15 | classifiers = [ 16 | "Development Status :: 3 - Alpha", 17 | "Intended Audience :: Developers", 18 | "License :: OSI Approved :: Apache Software License", 19 | "Operating System :: OS Independent", 20 | "Programming Language :: Python :: 3", 21 | "Programming Language :: Python :: 3.10", 22 | "Programming Language :: Python :: 3.11", 23 | "Programming Language :: Python :: 3.12", 24 | "Programming Language :: Python :: 3.13", 25 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 26 | "Topic :: Software Development :: Libraries :: Python Modules", 27 | ] 28 | dependencies = [ 29 | "boto3>=1.26.0,<2.0.0", 30 | "botocore>=1.29.0,<2.0.0", 31 | "docstring_parser>=0.15,<0.16.0", 32 | "mcp>=1.8.0,<2.0.0", 33 | "pydantic>=2.0.0,<3.0.0", 34 | "typing-extensions>=4.13.2,<5.0.0", 35 | "watchdog>=6.0.0,<7.0.0", 36 | "opentelemetry-api>=1.30.0,<2.0.0", 37 | "opentelemetry-sdk>=1.30.0,<2.0.0", 38 | "opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0", 39 | ] 40 | 41 | [project.urls] 42 | Homepage = "https://github.com/strands-agents/sdk-python" 43 | "Bug Tracker" = "https://github.com/strands-agents/sdk-python/issues" 44 | Documentation = "https://strandsagents.com" 45 | 46 | [tool.hatch.build.targets.wheel] 47 | packages = ["src/strands"] 48 | 49 | [project.optional-dependencies] 50 | anthropic = [ 51 | "anthropic>=0.21.0,<1.0.0", 52 | ] 53 | dev = 
[ 54 | "commitizen>=4.4.0,<5.0.0", 55 | "hatch>=1.0.0,<2.0.0", 56 | "moto>=5.1.0,<6.0.0", 57 | "mypy>=1.15.0,<2.0.0", 58 | "pre-commit>=3.2.0,<4.2.0", 59 | "pytest>=8.0.0,<9.0.0", 60 | "pytest-asyncio>=0.26.0,<0.27.0", 61 | "ruff>=0.4.4,<0.5.0", 62 | "swagger-parser>=1.0.2,<2.0.0", 63 | ] 64 | docs = [ 65 | "sphinx>=5.0.0,<6.0.0", 66 | "sphinx-rtd-theme>=1.0.0,<2.0.0", 67 | "sphinx-autodoc-typehints>=1.12.0,<2.0.0", 68 | ] 69 | litellm = [ 70 | "litellm>=1.69.0,<2.0.0", 71 | ] 72 | llamaapi = [ 73 | "llama-api-client>=0.1.0,<1.0.0", 74 | ] 75 | ollama = [ 76 | "ollama>=0.4.8,<1.0.0", 77 | ] 78 | openai = [ 79 | "openai>=1.68.0,<2.0.0", 80 | ] 81 | 82 | [tool.hatch.version] 83 | # Tells Hatch to use your version control system (git) to determine the version. 84 | source = "vcs" 85 | 86 | [tool.hatch.envs.hatch-static-analysis] 87 | features = ["anthropic", "litellm", "llamaapi", "ollama", "openai"] 88 | dependencies = [ 89 | "mypy>=1.15.0,<2.0.0", 90 | "ruff>=0.11.6,<0.12.0", 91 | "strands-agents @ {root:uri}" 92 | ] 93 | 94 | [tool.hatch.envs.hatch-static-analysis.scripts] 95 | format-check = [ 96 | "ruff format --check" 97 | ] 98 | format-fix = [ 99 | "ruff format" 100 | ] 101 | lint-check = [ 102 | "ruff check", 103 | "mypy -p src" 104 | ] 105 | lint-fix = [ 106 | "ruff check --fix" 107 | ] 108 | 109 | [tool.hatch.envs.hatch-test] 110 | features = ["anthropic", "litellm", "llamaapi", "ollama", "openai"] 111 | extra-dependencies = [ 112 | "moto>=5.1.0,<6.0.0", 113 | "pytest>=8.0.0,<9.0.0", 114 | "pytest-asyncio>=0.26.0,<0.27.0", 115 | "pytest-cov>=4.1.0,<5.0.0", 116 | "pytest-xdist>=3.0.0,<4.0.0", 117 | ] 118 | extra-args = [ 119 | "-n", 120 | "auto", 121 | "-vv", 122 | ] 123 | 124 | [tool.hatch.envs.dev] 125 | dev-mode = true 126 | features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama"] 127 | 128 | 129 | 130 | [[tool.hatch.envs.hatch-test.matrix]] 131 | python = ["3.13", "3.12", "3.11", "3.10"] 132 | 133 | 134 | 
[tool.hatch.envs.hatch-test.scripts] 135 | run = [ 136 | "pytest{env:HATCH_TEST_ARGS:} {args}" 137 | ] 138 | run-cov = [ 139 | "pytest{env:HATCH_TEST_ARGS:} --cov --cov-config=pyproject.toml {args}" 140 | ] 141 | 142 | cov-combine = [] 143 | cov-report = [] 144 | 145 | 146 | [tool.hatch.envs.default.scripts] 147 | list = [ 148 | "echo 'Scripts commands available for default env:'; hatch env show --json | jq --raw-output '.default.scripts | keys[]'" 149 | ] 150 | format = [ 151 | "hatch fmt --formatter", 152 | ] 153 | test-format = [ 154 | "hatch fmt --formatter --check", 155 | ] 156 | lint = [ 157 | "hatch fmt --linter" 158 | ] 159 | test-lint = [ 160 | "hatch fmt --linter --check" 161 | ] 162 | test = [ 163 | "hatch test --cover --cov-report html --cov-report xml {args}" 164 | ] 165 | test-integ = [ 166 | "hatch test tests-integ {args}" 167 | ] 168 | 169 | 170 | [tool.mypy] 171 | python_version = "3.10" 172 | warn_return_any = true 173 | warn_unused_configs = true 174 | disallow_untyped_defs = true 175 | disallow_incomplete_defs = true 176 | check_untyped_defs = true 177 | disallow_untyped_decorators = true 178 | no_implicit_optional = true 179 | warn_redundant_casts = true 180 | warn_unused_ignores = true 181 | warn_no_return = true 182 | warn_unreachable = true 183 | follow_untyped_imports = true 184 | ignore_missing_imports = false 185 | 186 | [[tool.mypy.overrides]] 187 | module = "litellm" 188 | ignore_missing_imports = true 189 | 190 | [tool.ruff] 191 | line-length = 120 192 | include = ["examples/**/*.py", "src/**/*.py", "tests/**/*.py", "tests-integ/**/*.py"] 193 | 194 | [tool.ruff.lint] 195 | select = [ 196 | "B", # flake8-bugbear 197 | "D", # pydocstyle 198 | "E", # pycodestyle 199 | "F", # pyflakes 200 | "G", # logging format 201 | "I", # isort 202 | "LOG", # logging 203 | ] 204 | 205 | [tool.ruff.lint.per-file-ignores] 206 | "!src/**/*.py" = ["D"] 207 | 208 | [tool.ruff.lint.pydocstyle] 209 | convention = "google" 210 | 211 | [tool.pytest.ini_options] 
212 | testpaths = [ 213 | "tests" 214 | ] 215 | asyncio_default_fixture_loop_scope = "function" 216 | 217 | [tool.coverage.run] 218 | branch = true 219 | source = ["src"] 220 | context = "thread" 221 | parallel = true 222 | concurrency = ["thread", "multiprocessing"] 223 | 224 | [tool.coverage.report] 225 | show_missing = true 226 | 227 | [tool.coverage.html] 228 | directory = "build/coverage/html" 229 | 230 | [tool.coverage.xml] 231 | output = "build/coverage/coverage.xml" 232 | 233 | [tool.commitizen] 234 | name = "cz_conventional_commits" 235 | tag_format = "v$version" 236 | bump_message = "chore(release): bump version $current_version -> $new_version" 237 | version_files = [ 238 | "pyproject.toml:version", 239 | ] 240 | update_changelog_on_bump = true 241 | style = [ 242 | ["qmark", "fg:#ff9d00 bold"], 243 | ["question", "bold"], 244 | ["answer", "fg:#ff9d00 bold"], 245 | ["pointer", "fg:#ff9d00 bold"], 246 | ["highlighted", "fg:#ff9d00 bold"], 247 | ["selected", "fg:#cc5454"], 248 | ["separator", "fg:#cc5454"], 249 | ["instruction", ""], 250 | ["text", ""], 251 | ["disabled", "fg:#858585 italic"] 252 | ] 253 | -------------------------------------------------------------------------------- /src/strands/__init__.py: -------------------------------------------------------------------------------- 1 | """A framework for building, deploying, and managing AI agents.""" 2 | 3 | from . 
@dataclass
class AgentResult:
    """Represents the last result of invoking an agent with a prompt.

    Attributes:
        stop_reason: The reason why the agent's processing stopped.
        message: The last message generated by the agent.
        metrics: Performance metrics collected during processing.
        state: Additional state information from the event loop.
    """

    stop_reason: StopReason
    message: Message
    metrics: EventLoopMetrics
    state: Any

    def __str__(self) -> str:
        """Get the agent's last message as a string.

        Only content items that are dicts carrying a "text" key contribute to the output; every other item
        (images, tool uses, structured data, ...) is skipped. Each text fragment is terminated by a newline.

        Returns:
            The concatenated text of the last message, or an empty string when the message has no text content.
        """
        content_array = self.message.get("content", [])

        # A single join avoids rebuilding the string on every iteration of a "+=" loop.
        return "".join(
            item["text"] + "\n" for item in content_array if isinstance(item, dict) and "text" in item
        )
class ConversationManager(ABC):
    """Interface that every conversation-history management strategy must implement.

    Concrete strategies bound the size of an agent's message history in order to:

    - Manage memory usage
    - Control context length
    - Maintain relevant conversation state
    """

    @abstractmethod
    # pragma: no cover
    def apply_management(self, agent: "Agent") -> None:
        """Apply the management strategy to the given agent.

        Implementations are expected to keep the agent's conversation history at an appropriate size, for example
        by pruning or summarizing messages, and to do so by modifying the agent's message list in-place.

        Args:
            agent: The agent whose conversation history will be managed.
                Its message list is modified in-place.
        """

    @abstractmethod
    # pragma: no cover
    def reduce_context(self, agent: "Agent", e: Optional[Exception] = None) -> None:
        """Shrink the agent's conversation history after the model's context window was exceeded.

        Implementations define the concrete reduction strategy. This hook is typically invoked after a
        ContextWindowOverflowException has been caught.

        Possible strategies include:

        - Removing the N oldest messages
        - Summarizing older context
        - Applying importance-based filtering
        - Maintaining critical conversation markers

        Args:
            agent: The agent whose conversation history will be reduced.
                Its message list is modified in-place.
            e: The exception that triggered the context reduction, if any.
        """
logger = logging.getLogger(__name__)


def is_user_message(message: Message) -> bool:
    """Check if a message is from a user.

    Args:
        message: The message object to check.

    Returns:
        True if the message has the user role, False otherwise.
    """
    return message["role"] == "user"


def is_assistant_message(message: Message) -> bool:
    """Check if a message is from an assistant.

    Args:
        message: The message object to check.

    Returns:
        True if the message has the assistant role, False otherwise.
    """
    return message["role"] == "assistant"


class SlidingWindowConversationManager(ConversationManager):
    """Implements a sliding window strategy for managing conversation history.

    This class handles the logic of maintaining a conversation window that preserves tool usage pairs and avoids
    invalid window states.
    """

    def __init__(self, window_size: int = 40):
        """Initialize the sliding window conversation manager.

        Args:
            window_size: Maximum number of messages to keep in the agent's history.
                Defaults to 40 messages.
        """
        self.window_size = window_size

    def apply_management(self, agent: "Agent") -> None:
        """Apply the sliding window to the agent's messages array to maintain a manageable history size.

        Dangling messages that would leave the conversation in an invalid state are removed first. The window is
        then applied only if the remaining message count exceeds the configured window size.

        Args:
            agent: The agent whose messages will be managed.
                Its message list is modified in-place.
        """
        messages = agent.messages
        self._remove_dangling_messages(messages)

        if len(messages) <= self.window_size:
            # Argument order must follow the field order of the format string.
            logger.debug(
                "window_size=<%s>, message_count=<%s> | skipping context reduction", self.window_size, len(messages)
            )
            return
        self.reduce_context(agent)

    def _remove_dangling_messages(self, messages: Messages) -> None:
        """Remove dangling messages that would create an invalid conversation state.

        The messages array is expected to end with either an assistant tool use request followed by the pairing
        user tool result, or an assistant response with no tool use request. Trailing messages that break this
        expectation are popped off the end of the array.

        This method handles two specific cases:

        - User with no tool result: the trailing user message is removed
        - Assistant with tool use request: the trailing assistant message is removed, and so is the user message
          before it when that message carries no tool result

        Args:
            messages: The messages to clean up.
                This list is modified in-place.
        """
        # remove any dangling user messages with no ToolResult
        if len(messages) > 0 and is_user_message(messages[-1]):
            if not any("toolResult" in content for content in messages[-1]["content"]):
                messages.pop()

        # remove any dangling assistant messages with ToolUse
        if len(messages) > 0 and is_assistant_message(messages[-1]):
            if any("toolUse" in content for content in messages[-1]["content"]):
                messages.pop()
                # remove remaining dangling user messages with no ToolResult after we popped off an assistant message
                if len(messages) > 0 and is_user_message(messages[-1]):
                    if not any("toolResult" in content for content in messages[-1]["content"]):
                        messages.pop()

    def reduce_context(self, agent: "Agent", e: Optional[Exception] = None) -> None:
        """Trim the oldest messages to reduce the conversation context size.

        The trim point is advanced past positions that would leave the history starting with:
        - a toolResult with no corresponding toolUse
        - a toolUse that is not immediately followed by its toolResult

        Args:
            agent: The agent whose messages will be reduced.
                Its message list is modified in-place.
            e: The exception that triggered the context reduction, if any.

        Raises:
            ContextWindowOverflowException: If no valid trim point exists, i.e. the context cannot be reduced
                further.
        """
        messages = agent.messages
        # When the history already fits in the window, drop the two oldest messages; otherwise trim to window size
        trim_index = 2 if len(messages) <= self.window_size else len(messages) - self.window_size

        # Find the next valid trim_index
        while trim_index < len(messages):
            if (
                # Oldest message cannot be a toolResult because it needs a toolUse preceding it
                any("toolResult" in content for content in messages[trim_index]["content"])
                or (
                    # Oldest message can be a toolUse only if a toolResult immediately follows it.
                    any("toolUse" in content for content in messages[trim_index]["content"])
                    and trim_index + 1 < len(messages)
                    and not any("toolResult" in content for content in messages[trim_index + 1]["content"])
                )
            ):
                trim_index += 1
            else:
                break
        else:
            # The loop ran off the end of the history without finding a valid trim_index
            raise ContextWindowOverflowException("Unable to trim conversation context!") from e

        # Overwrite message history
        messages[:] = messages[trim_index:]
logger = logging.getLogger(__name__)


def handle_throttling_error(
    e: ModelThrottledException,
    attempt: int,
    max_attempts: int,
    current_delay: int,
    max_delay: int,
    callback_handler: Any,
    kwargs: Dict[str, Any],
) -> Tuple[bool, int]:
    """Handle throttling exceptions from the model provider with exponential backoff.

    Args:
        e: The exception that occurred during model invocation.
        attempt: Number of times event loop has attempted model invocation (zero-based).
        max_attempts: Maximum number of retry attempts allowed.
        current_delay: Current delay in seconds before retrying.
        max_delay: Maximum delay in seconds (cap for exponential growth).
        callback_handler: Callback for processing events as they happen.
        kwargs: Additional arguments to pass to the callback handler.

    Returns:
        A tuple containing:
            - bool: True if retry should be attempted, False otherwise
            - int: The new delay to use for the next retry attempt
    """
    if attempt >= max_attempts - 1:
        # Last attempt: do not sleep, just report the failure through the callback.
        callback_handler(force_stop=True, force_stop_reason=str(e))
        return False, current_delay

    logger.debug(
        "retry_delay_seconds=<%s>, max_attempts=<%s>, current_attempt=<%s> "
        "| throttling exception encountered "
        "| delaying before next retry",
        current_delay,
        max_attempts,
        attempt + 1,
    )
    callback_handler(event_loop_throttled_delay=current_delay, **kwargs)
    time.sleep(current_delay)

    # Double the delay for the next retry, capped at max_delay.
    return True, min(current_delay * 2, max_delay)


def handle_input_too_long_error(
    e: ContextWindowOverflowException,
    messages: Messages,
    model: Model,
    system_prompt: Optional[str],
    tool_config: Any,
    callback_handler: Any,
    tool_handler: Any,
    kwargs: Dict[str, Any],
) -> Tuple[StopReason, Message, EventLoopMetrics, Any]:
    """Handle 'Input is too long' errors by truncating tool results.

    When a context window overflow exception occurs, this function looks for the most recent message containing
    tool results, truncates those results, and then re-enters the event loop. If no such message exists, the
    error is reported through the callback handler and re-raised.

    Args:
        e: The ContextWindowOverflowException that occurred.
        messages: The conversation message history.
        model: Model provider for running inference.
        system_prompt: System prompt for the model.
        tool_config: Tool configuration for the conversation.
        callback_handler: Callback for processing events as they happen.
        tool_handler: Handler for tool execution.
        kwargs: Additional arguments for the event loop.

    Returns:
        The results from the event loop call if successful.

    Raises:
        ContextWindowOverflowException: If there are no tool results to truncate. The raised exception carries
            the original error text and is chained to `e`.
    """
    from .event_loop import recurse_event_loop  # Import here to avoid circular imports

    last_message_with_tool_results = find_last_message_with_tool_results(messages)

    if last_message_with_tool_results is None:
        # Nothing can be truncated, so pass the error up.
        callback_handler(force_stop=True, force_stop_reason=str(e))
        logger.error("an exception occurred in event_loop_cycle | %s", e)
        # Keep the original message so str() of the surfaced error is not empty.
        raise ContextWindowOverflowException(str(e)) from e

    logger.debug("message_index=<%s> | found message with tool results at index", last_message_with_tool_results)

    # Truncate the tool results in this message, then retry through the event loop
    truncate_tool_results(messages, last_message_with_tool_results)

    return recurse_event_loop(
        model=model,
        system_prompt=system_prompt,
        messages=messages,
        tool_config=tool_config,
        callback_handler=callback_handler,
        tool_handler=tool_handler,
        **kwargs,
    )
logger = logging.getLogger(__name__)


def clean_orphaned_empty_tool_uses(messages: Messages) -> bool:
    """Clean up orphaned empty tool uses in conversation messages.

    An "orphaned empty tool use" is a toolUse entry in an assistant message that has a toolUseId, an empty input,
    and no toolResult with the same toolUseId in any user message.

    Each orphan is fixed by either:

    1. Replacing the message content with a context message, when the toolUse is the only content item
    2. Removing the toolUse entry, when the message has multiple content items

    Args:
        messages: The conversation message history. Modified in-place.

    Returns:
        True if at least one fix was applied, False otherwise.
    """
    if not messages:
        return False

    # Empty toolUse entries: {tool_id: (msg_index, content_index, tool_name)}
    empty_tool_uses: Dict[str, Tuple[int, int, str]] = {}
    for msg_idx, msg in enumerate(messages):
        if msg.get("role") != "assistant":
            continue

        for content_idx, content in enumerate(msg.get("content", [])):
            if not (isinstance(content, dict) and "toolUse" in content):
                continue

            tool_use = content.get("toolUse", {})
            tool_id = tool_use.get("toolUseId")
            if tool_id and not tool_use.get("input", {}):
                empty_tool_uses[tool_id] = (msg_idx, content_idx, tool_use.get("name", "unknown tool"))

    # Ids of every toolResult found in user messages
    tool_results: Set[str] = set()
    for msg in messages:
        if msg.get("role") != "user":
            continue

        for content in msg.get("content", []):
            if isinstance(content, dict) and "toolResult" in content:
                tool_id = content.get("toolResult", {}).get("toolUseId")
                if tool_id:
                    tool_results.add(tool_id)

    orphaned_tool_uses = {tool_id: info for tool_id, info in empty_tool_uses.items() if tool_id not in tool_results}
    if not orphaned_tool_uses:
        return False

    # Process from the last position to the first so that popping an item never shifts a pending index
    sorted_orphaned = sorted(orphaned_tool_uses.items(), key=lambda x: (x[1][0], x[1][1]), reverse=True)

    fixes_applied = False
    for tool_id, (msg_idx, content_idx, tool_name) in sorted_orphaned:
        logger.debug(
            "tool_name=<%s>, tool_id=<%s>, message_index=<%s>, content_index=<%s> "
            "| fixing orphaned empty tool use",
            tool_name,
            tool_id,
            msg_idx,
            content_idx,
        )
        try:
            if len(messages[msg_idx]["content"]) == 1:
                # Sole content item: replace it with a message indicating the attempted tool
                messages[msg_idx]["content"] = [{"text": f"[Attempted to use {tool_name}, but operation was canceled]"}]
                logger.debug("message_index=<%s> | replaced content with context message", msg_idx)
            else:
                # Other content exists: simply remove the orphaned toolUse entry
                messages[msg_idx]["content"].pop(content_idx)
                logger.debug(
                    "message_index=<%s>, content_index=<%s> | removed content item from message", msg_idx, content_idx
                )
        except Exception as e:
            # Best effort: a failure on one entry must not prevent fixing the others
            logger.warning("failed to fix orphaned tool use | %s", e)
        else:
            fixes_applied = True

    return fixes_applied


def find_last_message_with_tool_results(messages: Messages) -> Optional[int]:
    """Find the index of the last message containing tool results.

    Args:
        messages: The conversation message history.

    Returns:
        Index of the last message with tool results, or None if no such message exists.
    """
    # Iterate backwards through all messages (from newest to oldest)
    for idx in range(len(messages) - 1, -1, -1):
        if any(isinstance(content, dict) and "toolResult" in content for content in messages[idx].get("content", [])):
            return idx

    return None


def truncate_tool_results(messages: Messages, msg_idx: int) -> bool:
    """Truncate tool results in a message to reduce context size.

    Every toolResult in the selected message has its status set to "error" and its content replaced by a short
    error text.

    Args:
        messages: The conversation message history. The selected message is modified in-place.
        msg_idx: Index of the message containing tool results to truncate.

    Returns:
        True if any changes were made to the message, False otherwise (including when msg_idx is out of range).
    """
    if not 0 <= msg_idx < len(messages):
        return False

    changes_made = False
    for content in messages[msg_idx].get("content", []):
        if isinstance(content, dict) and "toolResult" in content:
            # Update status to error with informative message
            content["toolResult"]["status"] = "error"
            content["toolResult"]["content"] = [{"text": "The tool result was too large!"}]
            changes_made = True

    return changes_made
class CompositeCallbackHandler:
    """Callback handler that fans each event out to several other callback handlers.

    Every handler passed at construction time receives the same keyword arguments, in the order the handlers were
    given, so the same stream data can be processed or displayed in several ways at once.
    """

    def __init__(self, *handlers: Callable) -> None:
        """Initialize handler.

        Args:
            *handlers: Callables to invoke for each event. Stored as given (a tuple).
        """
        self.handlers = handlers

    def __call__(self, **kwargs: Any) -> None:
        """Invoke every registered handler with the event data.

        Args:
            **kwargs: Callback event data, forwarded unchanged to each handler.
        """
        for callback in self.handlers:
            callback(**kwargs)
class AgentToolHandler(ToolHandler):
    """Tool handler that resolves tools through a registry and invokes them.

    Implements the ToolHandler interface: a tool use is matched by name against the registry's dynamic tools first
    and its regular registry second, and the resolved tool is invoked with the agent's context.
    """

    def __init__(self, tool_registry: ToolRegistry) -> None:
        """Initialize handler.

        Args:
            tool_registry: Registry of available tools.
        """
        self.tool_registry = tool_registry

    def preprocess(
        self,
        tool: ToolUse,
        tool_config: ToolConfig,
        **kwargs: Any,
    ) -> Optional[ToolResult]:
        """Preprocess a tool before invocation (not implemented).

        Args:
            tool: The tool use object to preprocess.
            tool_config: Configuration for the tool.
            **kwargs: Additional keyword arguments.

        Returns:
            Always None; no preprocessing is performed.
        """
        return None

    def process(
        self,
        tool: Any,
        *,
        model: Model,
        system_prompt: Optional[str],
        messages: List[Any],
        tool_config: Any,
        callback_handler: Any,
        **kwargs: Any,
    ) -> Any:
        """Process a tool invocation.

        Resolves the tool by name and invokes it, passing the agent context as keyword arguments.

        Args:
            tool: The tool use to process; must provide "toolUseId" and "name".
            model: The model being used for the agent.
            system_prompt: The system prompt for the agent.
            messages: The conversation history.
            tool_config: Configuration for the tool.
            callback_handler: Callback for processing events as they happen.
            **kwargs: Additional keyword arguments passed to the tool.

        Returns:
            The result of the tool invocation, or an error result if the tool is not found or raises.
        """
        logger.debug("tool=<%s> | invoking", tool)
        tool_use_id = tool["toolUseId"]
        tool_name = tool["name"]

        # Dynamic tools take precedence over the regular registry.
        registry = self.tool_registry
        resolved_tool = registry.dynamic_tools.get(tool_name)
        if resolved_tool is None:
            resolved_tool = registry.registry.get(tool_name)

        try:
            if not resolved_tool:
                logger.error(
                    "tool_name=<%s>, available_tools=<%s> | tool not found in registry",
                    tool_name,
                    list(registry.registry.keys()),
                )
                return {
                    "toolUseId": tool_use_id,
                    "status": "error",
                    "content": [{"text": f"Unknown tool: {tool_name}"}],
                }

            # Standard agent context overrides any same-named entries the caller passed in kwargs.
            kwargs.update(
                model=model,
                system_prompt=system_prompt,
                messages=messages,
                tool_config=tool_config,
                callback_handler=callback_handler,
            )
            return resolved_tool.invoke(tool, **kwargs)

        except Exception as e:
            logger.exception("tool_name=<%s> | failed to process tool", tool_name)
            return {
                "toolUseId": tool_use_id,
                "status": "error",
                "content": [{"text": f"Error: {str(e)}"}],
            }
class LiteLLMModel(OpenAIModel):
    """Model provider backed by a LiteLLM client, reusing the OpenAI provider's request/stream handling."""

    class LiteLLMConfig(TypedDict, total=False):
        """Configuration options for LiteLLM models.

        Attributes:
            model_id: Model ID (e.g., "openai/gpt-4o", "anthropic/claude-3-sonnet").
                For a complete list of supported models, see https://docs.litellm.ai/docs/providers.
            params: Model parameters (e.g., max_tokens).
                For a complete list of supported parameters, see
                https://docs.litellm.ai/docs/completion/input#input-params-1.
        """

        model_id: str
        params: Optional[dict[str, Any]]

    def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[LiteLLMConfig]) -> None:
        """Initialize provider instance.

        Args:
            client_args: Arguments for the LiteLLM client; None is treated as no arguments.
                For a complete list of supported arguments, see
                https://github.com/BerriAI/litellm/blob/main/litellm/main.py.
            **model_config: Configuration options for the LiteLLM model.
        """
        self.config = dict(model_config)

        logger.debug("config=<%s> | initializing", self.config)

        self.client = litellm.LiteLLM(**(client_args or {}))

    @override
    def update_config(self, **model_config: Unpack[LiteLLMConfig]) -> None:  # type: ignore[override]
        """Merge the given options into the current LiteLLM model configuration.

        Args:
            **model_config: Configuration overrides.
        """
        self.config.update(model_config)

    @override
    def get_config(self) -> LiteLLMConfig:
        """Get the LiteLLM model configuration.

        Returns:
            The LiteLLM model configuration.
        """
        return cast(LiteLLMModel.LiteLLMConfig, self.config)

    @override
    @classmethod
    def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]:
        """Format a content block for a LiteLLM request.

        Reasoning and video blocks are converted here; every other block type is delegated to the parent class.

        Args:
            content: Message content.

        Returns:
            LiteLLM formatted content block.

        Raises:
            TypeError: If the content block type cannot be converted to a LiteLLM-compatible format.
        """
        if "reasoningContent" in content:
            reasoning_text = content["reasoningContent"]["reasoningText"]
            return {
                "signature": reasoning_text["signature"],
                "thinking": reasoning_text["text"],
                "type": "thinking",
            }

        if "video" in content:
            video_url = {
                "detail": "auto",
                "url": content["video"]["source"]["bytes"],
            }
            return {
                "type": "video_url",
                "video_url": video_url,
            }

        return super().format_request_message_content(content)
class Client(Protocol):
    """Protocol defining the OpenAI-compatible interface for the underlying provider client."""

    @property
    # pragma: no cover
    def chat(self) -> Any:
        """Chat completions interface."""
        ...


class OpenAIModel(SAOpenAIModel):
    """OpenAI model provider implementation."""

    client: Client

    class OpenAIConfig(TypedDict, total=False):
        """Configuration options for OpenAI models.

        Attributes:
            model_id: Model ID (e.g., "gpt-4o").
                For a complete list of supported models, see https://platform.openai.com/docs/models.
            params: Model parameters (e.g., max_tokens).
                For a complete list of supported parameters, see
                https://platform.openai.com/docs/api-reference/chat/create.
        """

        model_id: str
        params: Optional[dict[str, Any]]

    def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[OpenAIConfig]) -> None:
        """Initialize provider instance.

        Args:
            client_args: Arguments for the OpenAI client; None is treated as no arguments.
                For a complete list of supported arguments, see https://pypi.org/project/openai/.
            **model_config: Configuration options for the OpenAI model.
        """
        self.config = dict(model_config)

        logger.debug("config=<%s> | initializing", self.config)

        client_args = client_args or {}
        self.client = openai.OpenAI(**client_args)

    @override
    def update_config(self, **model_config: Unpack[OpenAIConfig]) -> None:  # type: ignore[override]
        """Update the OpenAI model configuration with the provided arguments.

        Args:
            **model_config: Configuration overrides.
        """
        self.config.update(model_config)

    @override
    def get_config(self) -> OpenAIConfig:
        """Get the OpenAI model configuration.

        Returns:
            The OpenAI model configuration.
        """
        return cast(OpenAIModel.OpenAIConfig, self.config)

    @override
    def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
        """Send the request to the OpenAI model and get the streaming response.

        Text deltas are yielded as they arrive. Tool-call deltas are buffered per tool-call index and emitted as
        separate content blocks after the text block is closed. The stream ends with a ``message_stop`` chunk and,
        when at least one event was received, a ``metadata`` chunk carrying the ``usage`` of the last event.

        Args:
            request: The formatted request to send to the OpenAI model.

        Returns:
            An iterable of response events from the OpenAI model.
        """
        response = self.client.chat.completions.create(**request)

        yield {"chunk_type": "message_start"}
        yield {"chunk_type": "content_start", "data_type": "text"}

        tool_calls: dict[int, list[Any]] = {}

        # Both stay None when the response yields no usable events, so the code after the loops never touches an
        # unbound name.
        finish_reason = None
        last_event = None

        for last_event in response:
            # Defensive: skip events with empty or missing choices
            if not getattr(last_event, "choices", None):
                continue
            choice = last_event.choices[0]

            if choice.delta.content:
                yield {"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content}

            for tool_call in choice.delta.tool_calls or []:
                tool_calls.setdefault(tool_call.index, []).append(tool_call)

            # Track the most recent value so message_stop reports what the last processed choice carried.
            finish_reason = choice.finish_reason
            if finish_reason:
                break

        yield {"chunk_type": "content_stop", "data_type": "text"}

        for tool_deltas in tool_calls.values():
            yield {"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]}

            for tool_delta in tool_deltas:
                yield {"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta}

            yield {"chunk_type": "content_stop", "data_type": "tool"}

        yield {"chunk_type": "message_stop", "data": finish_reason}

        # Skip remaining events as we don't have use for anything except the final usage payload
        for last_event in response:
            pass

        # With no events at all there is no usage payload to report.
        if last_event is not None:
            yield {"chunk_type": "metadata", "data": last_event.usage}
def run_tools(
    handler: Callable[[ToolUse], ToolResult],
    tool_uses: List[ToolUse],
    event_loop_metrics: EventLoopMetrics,
    request_state: Any,
    invalid_tool_use_ids: List[str],
    tool_results: List[ToolResult],
    cycle_trace: Trace,
    parent_span: Optional[trace.Span] = None,
    parallel_tool_executor: Optional[ParallelToolExecutorInterface] = None,
) -> bool:
    """Execute tools either in parallel or sequentially.

    Each tool use is run through ``handler`` unless its toolUseId is listed in ``invalid_tool_use_ids``. Results
    are appended to ``tool_results``; usage is recorded on ``event_loop_metrics`` and ``cycle_trace``.

    Args:
        handler: Tool handler processing function.
        tool_uses: List of tool uses to execute.
        event_loop_metrics: Metrics collection object.
        request_state: Current request state. Not used by this function.
        invalid_tool_use_ids: List of invalid tool use IDs; tools with these IDs are not executed.
        tool_results: List to populate with tool results.
        cycle_trace: Parent trace for the current cycle.
        parent_span: Parent span for the current cycle.
        parallel_tool_executor: Optional executor for parallel processing. When omitted, tools run sequentially.

    Returns:
        bool: True if any tool failed, False otherwise.
    """

    def _handle_tool_execution(tool: ToolUse) -> Tuple[bool, Optional[ToolResult]]:
        """Run one tool use and return (succeeded, result); result is None if the tool was skipped or raised."""
        result = None
        tool_succeeded = False

        tracer = get_tracer()
        tool_call_span = tracer.start_tool_call_span(tool, parent_span)

        try:
            if "toolUseId" not in tool or tool["toolUseId"] not in invalid_tool_use_ids:
                tool_name = tool["name"]
                tool_trace = Trace(f"Tool: {tool_name}", parent_id=cycle_trace.id, raw_name=tool_name)
                tool_start_time = time.time()
                result = handler(tool)
                tool_success = result.get("status") == "success"
                if tool_success:
                    tool_succeeded = True

                tool_duration = time.time() - tool_start_time
                message = Message(role="user", content=[{"toolResult": result}])
                event_loop_metrics.add_tool_usage(tool, tool_duration, tool_trace, tool_success, message)
                cycle_trace.add_child(tool_trace)

            if tool_call_span:
                tracer.end_tool_call_span(tool_call_span, result)
        except Exception as e:
            # Failures are recorded on the span; the caller sees them as (False, result-so-far).
            if tool_call_span:
                tracer.end_span_with_error(tool_call_span, str(e), e)

        return tool_succeeded, result

    any_tool_failed = False
    if parallel_tool_executor:
        logger.debug(
            "tool_count=<%s>, tool_executor=<%s> | executing tools in parallel",
            len(tool_uses),
            type(parallel_tool_executor).__name__,
        )
        # Submit all tasks with their associated tools
        future_to_tool = {
            parallel_tool_executor.submit(_handle_tool_execution, tool_use): tool_use for tool_use in tool_uses
        }
        logger.debug("tool_count=<%s> | submitted tasks to parallel executor", len(tool_uses))

        # Collect results in completion order using the provided executor's as_completed method
        completed_results = []
        try:
            for future in parallel_tool_executor.as_completed(future_to_tool):
                try:
                    succeeded, result = future.result()
                    if result is not None:
                        completed_results.append(result)
                    if not succeeded:
                        any_tool_failed = True
                except Exception as e:
                    tool = future_to_tool[future]
                    logger.debug("tool_name=<%s> | tool execution failed | %s", tool["name"], e)
                    any_tool_failed = True
        except TimeoutError:
            # Read the timeout defensively: an executor without a `timeout` attribute must not turn this handled
            # timeout into an AttributeError.
            logger.error(
                "timeout_seconds=<%s> | parallel tool execution timed out",
                getattr(parallel_tool_executor, "timeout", None),
            )
            # Process any completed tasks
            for future in future_to_tool:
                if future.done():  # type: ignore
                    try:
                        succeeded, result = future.result(timeout=0)
                        if result is not None:
                            completed_results.append(result)
                    except Exception as tool_e:
                        tool = future_to_tool[future]
                        logger.debug("tool_name=<%s> | tool execution failed | %s", tool["name"], tool_e)
                else:
                    # This future didn't complete within the timeout
                    tool = future_to_tool[future]
                    logger.debug("tool_name=<%s> | tool execution timed out", tool["name"])

            # A timeout always counts as a failure, whatever the completed tools returned.
            any_tool_failed = True

        # Add completed results to tool_results
        tool_results.extend(completed_results)
    else:
        # Sequential execution fallback
        for tool_use in tool_uses:
            succeeded, result = _handle_tool_execution(tool_use)
            if result is not None:
                tool_results.append(result)
            if not succeeded:
                any_tool_failed = True

    return any_tool_failed


def validate_and_prepare_tools(
    message: Message,
    tool_uses: List[ToolUse],
    tool_results: List[ToolResult],
    invalid_tool_use_ids: List[str],
) -> None:
    """Validate tool uses and prepare them for execution.

    All toolUse entries of ``message`` are appended to ``tool_uses`` and then validated. A tool use whose name is
    rejected is renamed to "INVALID_TOOL_NAME", moved to the end of ``tool_uses``, recorded in
    ``invalid_tool_use_ids``, and given an error entry in ``tool_results``.

    Args:
        message: Current message.
        tool_uses: List to populate with tool uses.
        tool_results: List to populate with tool results for invalid tools.
        invalid_tool_use_ids: List to populate with invalid tool use IDs.
    """
    # Extract tool uses from message
    for content in message["content"]:
        if isinstance(content, dict) and "toolUse" in content:
            tool_uses.append(content["toolUse"])

    # Validate tool uses
    # Iterate over a copy because `tool_uses` is modified inside the loop
    tool_uses_copy = tool_uses.copy()
    for tool in tool_uses_copy:
        try:
            validate_tool_use(tool)
        except InvalidToolUseNameException as e:
            # Replace the invalid toolUse name and return invalid name error as ToolResult to the LLM as context
            tool_uses.remove(tool)
            tool["name"] = "INVALID_TOOL_NAME"
            invalid_tool_use_ids.append(tool["toolUseId"])
            tool_uses.append(tool)
            tool_results.append(
                {
                    "toolUseId": tool["toolUseId"],
                    "status": "error",
                    "content": [{"text": f"Error: {str(e)}"}],
                }
            )
class MCPAgentTool(AgentTool):
    """AgentTool adapter around a tool exposed by an MCP server.

    Name and specification are read from the wrapped MCP tool, and invocations are forwarded to the MCP client the
    instance was created with.
    """

    def __init__(self, mcp_tool: MCPTool, mcp_client: "MCPClient") -> None:
        """Initialize a new MCPAgentTool instance.

        Args:
            mcp_tool: The MCP tool to adapt
            mcp_client: The MCP server connection to use for tool invocation
        """
        super().__init__()
        logger.debug("tool_name=<%s> | creating mcp agent tool", mcp_tool.name)
        self.mcp_tool = mcp_tool
        self.mcp_client = mcp_client

    @property
    def tool_name(self) -> str:
        """Get the name of the tool.

        Returns:
            str: The name of the wrapped MCP tool
        """
        return self.mcp_tool.name

    @property
    def tool_spec(self) -> ToolSpec:
        """Get the specification of the tool.

        Builds a ToolSpec from the wrapped MCP tool's input schema, name, and description. When the MCP tool has no
        description, a generic one derived from its name is used.

        Returns:
            ToolSpec: The tool specification in the agent framework format
        """
        wrapped = self.mcp_tool
        description: str = wrapped.description or f"Tool which performs {wrapped.name}"
        spec: ToolSpec = {
            "inputSchema": {"json": wrapped.inputSchema},
            "name": wrapped.name,
            "description": description,
        }
        return spec

    @property
    def tool_type(self) -> str:
        """Get the type of the tool.

        Returns:
            str: The type of the tool, always "python" for MCP tools
        """
        return "python"

    def invoke(self, tool: ToolUse, *args: Any, **kwargs: dict[str, Any]) -> ToolResult:
        """Invoke the MCP tool.

        Forwards the tool use ID, this tool's name, and the tool use input to the MCP client's synchronous call.
        Extra positional and keyword arguments are accepted but not used.

        Args:
            tool: The tool use to execute; must provide "toolUseId" and "input".
            *args: Ignored.
            **kwargs: Ignored.

        Returns:
            The result returned by the MCP client.
        """
        tool_use_id = tool["toolUseId"]
        logger.debug("invoking MCP tool '%s' with tool_use_id=%s", self.tool_name, tool_use_id)
        return self.mcp_client.call_tool_sync(
            tool_use_id=tool["toolUseId"], name=self.tool_name, arguments=tool["input"]
        )
20 | 21 | Example implementation (simplified): 22 | ```python 23 | @contextlib.asynccontextmanager 24 | async def my_transport_implementation(): 25 | # Set up connection 26 | read_stream_writer, read_stream = anyio.create_memory_object_stream(0) 27 | write_stream, write_stream_reader = anyio.create_memory_object_stream(0) 28 | 29 | # Start background tasks to handle actual I/O 30 | async with anyio.create_task_group() as tg: 31 | tg.start_soon(reader_task, read_stream_writer) 32 | tg.start_soon(writer_task, write_stream_reader) 33 | 34 | # Yield the streams to the caller 35 | yield (read_stream, write_stream) 36 | ``` 37 | """ 38 | # GetSessionIdCallback was added for HTTP Streaming but was not applied to the MessageStream type 39 | # https://github.com/modelcontextprotocol/python-sdk/blob/ed25167fa5d715733437996682e20c24470e8177/src/mcp/client/streamable_http.py#L418 40 | _MessageStreamWithGetSessionIdCallback = tuple[ 41 | MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage], GetSessionIdCallback 42 | ] 43 | MCPTransport = AbstractAsyncContextManager[MessageStream | _MessageStreamWithGetSessionIdCallback] 44 | -------------------------------------------------------------------------------- /src/strands/tools/thread_pool_executor.py: -------------------------------------------------------------------------------- 1 | """Thread pool execution management for parallel tool calls.""" 2 | 3 | import concurrent.futures 4 | from concurrent.futures import ThreadPoolExecutor 5 | from typing import Any, Callable, Iterable, Iterator, Optional 6 | 7 | from ..types.event_loop import Future, ParallelToolExecutorInterface 8 | 9 | 10 | class ThreadPoolExecutorWrapper(ParallelToolExecutorInterface): 11 | """Wrapper around ThreadPoolExecutor to implement the strands.types.event_loop.ParallelToolExecutorInterface. 
12 | 13 | This class adapts Python's standard ThreadPoolExecutor to conform to the SDK's ParallelToolExecutorInterface, 14 | allowing it to be used for parallel tool execution within the agent event loop. It provides methods for submitting 15 | tasks, monitoring their completion, and shutting down the executor. 16 | 17 | Attributes: 18 | thread_pool: The underlying ThreadPoolExecutor instance. 19 | """ 20 | 21 | def __init__(self, thread_pool: ThreadPoolExecutor): 22 | """Initialize with a ThreadPoolExecutor instance. 23 | 24 | Args: 25 | thread_pool: The ThreadPoolExecutor to wrap. 26 | """ 27 | self.thread_pool = thread_pool 28 | 29 | def submit(self, fn: Callable[..., Any], /, *args: Any, **kwargs: Any) -> Future: 30 | """Submit a callable to be executed with the given arguments. 31 | 32 | This method schedules the callable to be executed as fn(*args, **kwargs) 33 | and returns a Future instance representing the execution of the callable. 34 | 35 | Args: 36 | fn: The callable to execute. 37 | *args: Positional arguments for the callable. 38 | **kwargs: Keyword arguments for the callable. 39 | 40 | Returns: 41 | A Future instance representing the execution of the callable. 42 | """ 43 | return self.thread_pool.submit(fn, *args, **kwargs) 44 | 45 | def as_completed(self, futures: Iterable[Future], timeout: Optional[int] = None) -> Iterator[Future]: 46 | """Return an iterator over the futures as they complete. 47 | 48 | The returned iterator yields futures as they complete (finished or cancelled). 49 | 50 | Args: 51 | futures: The futures to iterate over. 52 | timeout: The maximum number of seconds to wait. 53 | None means no limit. 54 | 55 | Returns: 56 | An iterator yielding futures as they complete. 57 | 58 | Raises: 59 | concurrent.futures.TimeoutError: If the timeout is reached. 
class ToolWatcher:
    """Watches tool directories for changes and reloads tools when they are modified.

    Observer state lives on the class, so every ToolWatcher instance shares one observer and one table of
    per-directory handlers.
    """

    # This class uses class variables for the observer and handlers because watchdog allows only one Observer instance
    # per directory. Using class variables ensures that all ToolWatcher instances share a single Observer, with the
    # MasterChangeHandler routing file system events to the appropriate individual handlers for each registry. This
    # design pattern avoids conflicts when multiple tool registries are watching the same directories.

    _shared_observer = None
    _watched_dirs: Set[str] = set()
    _observer_started = False
    _registry_handlers: Dict[str, Dict[int, "ToolWatcher.ToolChangeHandler"]] = {}

    def __init__(self, tool_registry: ToolRegistry) -> None:
        """Initialize a tool watcher for the given tool registry.

        Args:
            tool_registry: The tool registry to report changes.
        """
        self.tool_registry = tool_registry
        self.start()

    @staticmethod
    def _changed_tool_name(event: Any) -> "str | None":
        """Return the tool name a file system event refers to, if it is a reloadable tool file.

        Args:
            event: The file system event to inspect.

        Returns:
            The stem of the modified ``.py`` file, or None when the path is not a ``.py`` file or is ``__init__.py``.
        """
        src_path = event.src_path
        if not src_path.endswith(".py"):
            return None

        tool_name = Path(src_path).stem
        if tool_name == "__init__":
            return None
        return tool_name

    class ToolChangeHandler(FileSystemEventHandler):
        """Handler for tool file changes."""

        def __init__(self, tool_registry: ToolRegistry) -> None:
            """Initialize a tool change handler.

            Args:
                tool_registry: The tool registry to update when tools change.
            """
            self.tool_registry = tool_registry

        def on_modified(self, event: Any) -> None:
            """Reload tool if file modification detected.

            Args:
                event: The file system event that triggered this handler.
            """
            tool_name = ToolWatcher._changed_tool_name(event)
            if tool_name is None:
                return

            logger.debug("tool_name=<%s> | tool change detected", tool_name)
            try:
                self.tool_registry.reload_tool(tool_name)
            except Exception as e:
                # Best effort: a broken tool file must not take down the watcher.
                logger.error("tool_name=<%s>, exception=<%s> | failed to reload tool", tool_name, str(e))

    class MasterChangeHandler(FileSystemEventHandler):
        """Master handler that delegates to all registered handlers."""

        def __init__(self, dir_path: str) -> None:
            """Initialize a master change handler for a specific directory.

            Args:
                dir_path: The directory path to watch.
            """
            self.dir_path = dir_path

        def on_modified(self, event: Any) -> None:
            """Delegate file modification events to all registered handlers.

            Args:
                event: The file system event that triggered this handler.
            """
            if ToolWatcher._changed_tool_name(event) is None:
                return

            # Iterate over a snapshot: start() may register another handler for this directory from a different
            # thread while this loop runs, and mutating a dict during iteration raises RuntimeError.
            handlers = list(ToolWatcher._registry_handlers.get(self.dir_path, {}).values())

            # Delegate to all registered handlers for this directory
            for handler in handlers:
                try:
                    handler.on_modified(event)
                except Exception as e:
                    logger.error("exception=<%s> | handler error", str(e))

    def start(self) -> None:
        """Start watching all tools directories for changes."""
        # Initialize shared observer if not already done
        if ToolWatcher._shared_observer is None:
            ToolWatcher._shared_observer = Observer()

        # Create handler for this instance
        self.tool_change_handler = self.ToolChangeHandler(self.tool_registry)
        registry_id = id(self.tool_registry)

        # Get tools directories to watch
        tools_dirs = self.tool_registry.get_tools_dirs()

        for tools_dir in tools_dirs:
            dir_str = str(tools_dir)

            # Store this handler with its registry id, creating the per-directory table on first use
            ToolWatcher._registry_handlers.setdefault(dir_str, {})[registry_id] = self.tool_change_handler

            # Schedule or update the master handler for this directory
            if dir_str not in ToolWatcher._watched_dirs:
                # First time seeing this directory, create a master handler
                master_handler = self.MasterChangeHandler(dir_str)
                ToolWatcher._shared_observer.schedule(master_handler, dir_str, recursive=False)
                ToolWatcher._watched_dirs.add(dir_str)
                logger.debug("tools_dir=<%s> | started watching tools directory", tools_dir)
            else:
                # Directory already being watched, just log it
                logger.debug("tools_dir=<%s> | directory already being watched", tools_dir)

        # Start the observer if not already started
        if not ToolWatcher._observer_started:
            ToolWatcher._shared_observer.start()
            ToolWatcher._observer_started = True
            logger.debug("tool directory watching initialized")
class ContentBlock(TypedDict, total=False):
    """A block of content for a message that you pass to, or receive from, a model.

    The type is declared with ``total=False``, so every key listed below is optional.

    Attributes:
        cachePoint: A cache point configuration to optimize conversation history.
        document: A document to include in the message.
        guardContent: Contains the content to assess with the guardrail.
        image: Image to include in the message.
        reasoningContent: Contains content regarding the reasoning that is carried out by the model.
        text: Text to include in the message.
        toolResult: The result for a tool request that a model makes.
        toolUse: Information about a tool use request from a model.
        video: Video to include in the message.
    """

    cachePoint: CachePoint
    document: DocumentContent
    guardContent: GuardContent
    image: ImageContent
    reasoningContent: ReasoningContentBlock
    text: str
    toolResult: ToolResult
    toolUse: ToolUse
    video: VideoContent
class Usage(TypedDict):
    """Token usage information for model interactions.

    Attributes:
        inputTokens: Number of tokens sent in the request to the model.
        outputTokens: Number of tokens that the model generated for the request.
        totalTokens: Total number of tokens (input + output).
    """

    inputTokens: int
    outputTokens: int
    totalTokens: int
41 | 42 | - "content_filtered": Content was filtered due to policy violation 43 | - "end_turn": Normal completion of the response 44 | - "guardrail_intervened": Guardrail system intervened 45 | - "max_tokens": Maximum token limit reached 46 | - "stop_sequence": Stop sequence encountered 47 | - "tool_use": Model requested to use a tool 48 | """ 49 | 50 | 51 | @runtime_checkable 52 | class Future(Protocol): 53 | """Interface representing the result of an asynchronous computation.""" 54 | 55 | def result(self, timeout: Optional[int] = None) -> Any: 56 | """Return the result of the call that the future represents. 57 | 58 | This method will block until the asynchronous operation completes or until the specified timeout is reached. 59 | 60 | Args: 61 | timeout: The number of seconds to wait for the result. 62 | If None, then there is no limit on the wait time. 63 | 64 | Returns: 65 | Any: The result of the asynchronous operation. 66 | """ 67 | 68 | 69 | @runtime_checkable 70 | class ParallelToolExecutorInterface(Protocol): 71 | """Interface for parallel tool execution. 72 | 73 | Attributes: 74 | timeout: Default timeout in seconds for futures. 75 | """ 76 | 77 | timeout: int = 900 # default 15 minute timeout for futures 78 | 79 | def submit(self, fn: Callable[..., Any], /, *args: Any, **kwargs: Any) -> Future: 80 | """Submit a callable to be executed with the given arguments. 81 | 82 | Schedules the callable to be executed as fn(*args, **kwargs) and returns a Future instance representing the 83 | execution of the callable. 84 | 85 | Args: 86 | fn: The callable to execute. 87 | *args: Positional arguments to pass to the callable. 88 | **kwargs: Keyword arguments to pass to the callable. 89 | 90 | Returns: 91 | Future: A Future representing the given call. 92 | """ 93 | 94 | def as_completed(self, futures: Iterable[Future], timeout: Optional[int] = timeout) -> Iterator[Future]: 95 | """Iterate over the given futures, yielding each as it completes. 
class EventLoopException(Exception):
    """Error wrapper used by the event loop to carry the failing exception plus request state."""

    def __init__(self, original_exception: Exception, request_state: Any = None) -> None:
        """Wrap an exception raised while the event loop was running.

        Args:
            original_exception: The exception being wrapped; its string form becomes this exception's message.
            request_state: The request state at the time of the failure. A new empty dict is stored when None.
        """
        # Only None is replaced; other falsy values (0, "", []) are stored as given.
        if request_state is None:
            request_state = {}

        self.original_exception = original_exception
        self.request_state = request_state
        super().__init__(str(original_exception))
DocumentFormat = Literal["pdf", "csv", "doc", "docx", "xls", "xlsx", "html", "txt", "md"]
"""Supported document formats."""


class DocumentSource(TypedDict):
    """Contains the content of a document.

    Attributes:
        bytes: The binary content of the document.
    """

    bytes: bytes


class DocumentContent(TypedDict):
    """A document to include in a message.

    Attributes:
        format: The format of the document (e.g., "pdf", "txt"); one of the DocumentFormat values.
        name: The name of the document.
        source: The source containing the document's binary content.
    """

    # Reuse the DocumentFormat alias instead of repeating the literal list, so the set of supported formats is
    # defined in exactly one place (ImageContent and VideoContent already use their aliases the same way).
    format: DocumentFormat
    name: str
    source: DocumentSource
class Model(abc.ABC):
    """Base interface that every model provider in the Strands Agents SDK implements.

    Subclasses supply configuration handling, request formatting, streaming, and chunk formatting; `converse` chains
    those pieces into one generator.
    """

    @abc.abstractmethod
    # pragma: no cover
    def update_config(self, **model_config: Any) -> None:
        """Apply configuration overrides to the model.

        Args:
            **model_config: Configuration overrides.
        """

    @abc.abstractmethod
    # pragma: no cover
    def get_config(self) -> Any:
        """Get the current model configuration.

        Returns:
            The model's configuration.
        """

    @abc.abstractmethod
    # pragma: no cover
    def format_request(
        self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None
    ) -> Any:
        """Build the streaming request for the underlying model.

        Args:
            messages: List of message objects to be processed by the model.
            tool_specs: List of tool specifications to make available to the model.
            system_prompt: System prompt to provide context to the model.

        Returns:
            The formatted request.
        """

    @abc.abstractmethod
    # pragma: no cover
    def format_chunk(self, event: Any) -> StreamEvent:
        """Convert one model response event into a standardized message chunk.

        Args:
            event: A response event from the model.

        Returns:
            The formatted chunk.
        """

    @abc.abstractmethod
    # pragma: no cover
    def stream(self, request: Any) -> Iterable[Any]:
        """Send the request to the model and return its streaming response.

        Args:
            request: The formatted request to send to the model.

        Returns:
            The model's response.

        Raises:
            ModelThrottledException: When the model service is throttling requests from the client.
        """

    def converse(
        self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None
    ) -> Iterable[StreamEvent]:
        """Run a full request/response cycle against the model.

        The steps are:
        1. Build the request with `format_request`
        2. Hand the request to `stream`
        3. Yield each response event after passing it through `format_chunk`

        Args:
            messages: List of message objects to be processed by the model.
            tool_specs: List of tool specifications to make available to the model.
            system_prompt: System prompt to provide context to the model.

        Yields:
            Formatted message chunks from the model.

        Raises:
            ModelThrottledException: When the model service is throttling requests from the client.
        """
        logger.debug("formatting request")
        formatted_request = self.format_request(messages, tool_specs, system_prompt)

        logger.debug("invoking model")
        raw_events = self.stream(formatted_request)

        logger.debug("got response from model")
        # Chunks are formatted lazily, one per event, as the caller consumes the generator.
        yield from (self.format_chunk(raw_event) for raw_event in raw_events)

        logger.debug("finished streaming response from model")
45 | """ 46 | 47 | text: str 48 | 49 | 50 | class ContentBlockDeltaToolUse(TypedDict): 51 | """Tool use input delta in a streaming response. 52 | 53 | Attributes: 54 | input: The tool input fragment being streamed. 55 | """ 56 | 57 | input: str 58 | 59 | 60 | class ReasoningContentBlockDelta(TypedDict, total=False): 61 | """Delta for reasoning content block in a streaming response. 62 | 63 | Attributes: 64 | redactedContent: The content in the reasoning that was encrypted by the model provider for safety reasons. 65 | signature: A token that verifies that the reasoning text was generated by the model. 66 | text: The reasoning that the model used to return the output. 67 | """ 68 | 69 | redactedContent: Optional[bytes] 70 | signature: Optional[str] 71 | text: Optional[str] 72 | 73 | 74 | class ContentBlockDelta(TypedDict, total=False): 75 | """A block of content in a streaming response. 76 | 77 | Attributes: 78 | reasoningContent: Contains content regarding the reasoning that is carried out by the model. 79 | text: Text fragment being streamed. 80 | toolUse: Tool use input fragment being streamed. 81 | """ 82 | 83 | reasoningContent: ReasoningContentBlockDelta 84 | text: str 85 | toolUse: ContentBlockDeltaToolUse 86 | 87 | 88 | class ContentBlockDeltaEvent(TypedDict, total=False): 89 | """Event containing a delta update for a content block in a streaming response. 90 | 91 | Attributes: 92 | contentBlockIndex: Index of the content block within the message. 93 | This is optional to accommodate different model providers. 94 | delta: The incremental content update for the content block. 95 | """ 96 | 97 | contentBlockIndex: Optional[int] 98 | delta: ContentBlockDelta 99 | 100 | 101 | class ContentBlockStopEvent(TypedDict, total=False): 102 | """Event signaling the end of a content block in a streaming response. 103 | 104 | Attributes: 105 | contentBlockIndex: Index of the content block within the message. 106 | This is optional to accommodate different model providers. 
107 | """ 108 | 109 | contentBlockIndex: Optional[int] 110 | 111 | 112 | class MessageStopEvent(TypedDict, total=False): 113 | """Event signaling the end of a message in a streaming response. 114 | 115 | Attributes: 116 | additionalModelResponseFields: Additional fields to include in model response. 117 | stopReason: The reason why the model stopped generating content. 118 | """ 119 | 120 | additionalModelResponseFields: Optional[Union[dict, list, int, float, str, bool, None]] 121 | stopReason: StopReason 122 | 123 | 124 | class MetadataEvent(TypedDict, total=False): 125 | """Event containing metadata about the streaming response. 126 | 127 | Attributes: 128 | metrics: Performance metrics related to the model invocation. 129 | trace: Trace information for debugging and monitoring. 130 | usage: Resource usage information for the model invocation. 131 | """ 132 | 133 | metrics: Metrics 134 | trace: Optional[Trace] 135 | usage: Usage 136 | 137 | 138 | class ExceptionEvent(TypedDict): 139 | """Base event for exceptions in a streaming response. 140 | 141 | Attributes: 142 | message: The error message describing what went wrong. 143 | """ 144 | 145 | message: str 146 | 147 | 148 | class ModelStreamErrorEvent(ExceptionEvent): 149 | """Event for model streaming errors. 150 | 151 | Attributes: 152 | originalMessage: The original error message from the model provider. 153 | originalStatusCode: The HTTP status code returned by the model provider. 154 | """ 155 | 156 | originalMessage: str 157 | originalStatusCode: int 158 | 159 | 160 | class RedactContentEvent(TypedDict, total=False): 161 | """Event for redacting content. 162 | 163 | Attributes: 164 | redactUserContentMessage: The string to overwrite the users input with. 165 | redactAssistantContentMessage: The string to overwrite the assistants output with. 
class StreamEvent(TypedDict, total=False):
    """The messages output stream.

    The type is declared with ``total=False``, so every key listed below is optional.

    Attributes:
        contentBlockDelta: Delta content for a content block.
        contentBlockStart: Start of a content block.
        contentBlockStop: End of a content block.
        internalServerException: Internal server error information.
        messageStart: Start of a message.
        messageStop: End of a message.
        metadata: Metadata about the streaming response.
        modelStreamErrorException: Model streaming error information.
        redactContent: Event for redacting content.
        serviceUnavailableException: Service unavailable error information.
        throttlingException: Throttling error information.
        validationException: Validation error information.
    """

    contentBlockDelta: ContentBlockDeltaEvent
    contentBlockStart: ContentBlockStartEvent
    contentBlockStop: ContentBlockStopEvent
    internalServerException: ExceptionEvent
    messageStart: MessageStartEvent
    messageStop: MessageStopEvent
    metadata: MetadataEvent
    redactContent: RedactContentEvent
    modelStreamErrorException: ModelStreamErrorEvent
    serviceUnavailableException: ExceptionEvent
    throttlingException: ExceptionEvent
    validationException: ExceptionEvent
https://raw.githubusercontent.com/strands-agents/sdk-python/a64e80dfa5f5f6053a3fe53767f59fc3c5c1af95/tests-integ/__init__.py -------------------------------------------------------------------------------- /tests-integ/echo_server.py: -------------------------------------------------------------------------------- 1 | """ 2 | Echo Server for MCP Integration Testing 3 | 4 | This module implements a simple echo server using the Model Context Protocol (MCP). 5 | It provides a basic tool that echoes back any input string, which is useful for 6 | testing the MCP communication flow and validating that messages are properly 7 | transmitted between the client and server. 8 | 9 | The server runs with stdio transport, making it suitable for integration tests 10 | where the client can spawn this process and communicate with it through standard 11 | input/output streams. 12 | 13 | Usage: 14 | Run this file directly to start the echo server: 15 | $ python echo_server.py 16 | """ 17 | 18 | from mcp.server import FastMCP 19 | 20 | 21 | def start_echo_server(): 22 | """ 23 | Initialize and start the MCP echo server. 24 | 25 | Creates a FastMCP server instance with a single 'echo' tool that returns 26 | any input string back to the caller. The server uses stdio transport 27 | for communication. 
def test_bedrock_cache_point():
    """Check that a conversation containing a cachePoint block produces cache token usage in stream metadata."""
    long_text_block = {
        "text": "Some really long text!" * 1000  # Minimum token count for cachePoint is 1024 tokens
    }
    messages: Messages = [
        {"role": "user", "content": [long_text_block, {"cachePoint": {"type": "default"}}]},
        {"role": "assistant", "content": [{"text": "Blue!"}]},
    ]

    cache_point_usage = 0

    def cache_point_callback_handler(**kwargs):
        """Count callback events whose usage metadata has a cache read or cache write token key."""
        nonlocal cache_point_usage

        event = kwargs["event"] if "event" in kwargs else None
        if not event or "metadata" not in event:
            return

        metadata = event["metadata"]
        if not metadata or "usage" not in metadata:
            return

        usage = metadata["usage"]
        if not usage:
            return

        if "cacheReadInputTokens" in usage or "cacheWriteInputTokens" in usage:
            cache_point_usage += 1

    agent = Agent(messages=messages, callback_handler=cache_point_callback_handler, load_tools_from_directory=False)
    agent("What is favorite color?")
    assert cache_point_usage > 0
@pytest.fixture(scope="module") 14 | def boto_session(): 15 | return boto3.Session(region_name="us-west-2") 16 | 17 | 18 | @pytest.fixture(scope="module") 19 | def bedrock_guardrail(boto_session): 20 | """ 21 | Fixture that creates a guardrail before tests if it doesn't already exist. 22 | """ 23 | 24 | client = boto_session.client("bedrock") 25 | 26 | guardrail_name = "test-guardrail-block-cactus" 27 | guardrail_id = get_guardrail_id(client, guardrail_name) 28 | 29 | if guardrail_id: 30 | print(f"Guardrail {guardrail_name} already exists with ID: {guardrail_id}") 31 | else: 32 | print(f"Creating guardrail {guardrail_name}") 33 | response = client.create_guardrail( 34 | name=guardrail_name, 35 | description="Testing Guardrail", 36 | wordPolicyConfig={ 37 | "wordsConfig": [ 38 | { 39 | "text": "CACTUS", 40 | "inputAction": "BLOCK", 41 | "outputAction": "BLOCK", 42 | "inputEnabled": True, 43 | "outputEnabled": True, 44 | }, 45 | ], 46 | }, 47 | blockedInputMessaging=BLOCKED_INPUT, 48 | blockedOutputsMessaging=BLOCKED_OUTPUT, 49 | ) 50 | guardrail_id = response.get("guardrailId") 51 | print(f"Created test guardrail with ID: {guardrail_id}") 52 | wait_for_guardrail_active(client, guardrail_id) 53 | return guardrail_id 54 | 55 | 56 | def get_guardrail_id(client, guardrail_name): 57 | """ 58 | Retrieves the ID of a guardrail by its name. 
59 | 60 | Args: 61 | client: The Bedrock client instance 62 | guardrail_name: Name of the guardrail to look up 63 | 64 | Returns: 65 | str: The ID of the guardrail if found, None otherwise 66 | """ 67 | response = client.list_guardrails() 68 | for guardrail in response.get("guardrails", []): 69 | if guardrail["name"] == guardrail_name: 70 | return guardrail["id"] 71 | return None 72 | 73 | 74 | def wait_for_guardrail_active(bedrock_client, guardrail_id, max_attempts=10, delay=5): 75 | """ 76 | Wait for the guardrail to become active 77 | """ 78 | for _ in range(max_attempts): 79 | response = bedrock_client.get_guardrail(guardrailIdentifier=guardrail_id) 80 | status = response.get("status") 81 | 82 | if status == "READY": 83 | print(f"Guardrail {guardrail_id} is now active") 84 | return True 85 | 86 | print(f"Waiting for guardrail to become active. Current status: {status}") 87 | time.sleep(delay) 88 | 89 | print(f"Guardrail did not become active within {max_attempts * delay} seconds.") 90 | raise RuntimeError("Guardrail did not become active.") 91 | 92 | 93 | def test_guardrail_input_intervention(boto_session, bedrock_guardrail): 94 | bedrock_model = BedrockModel( 95 | guardrail_id=bedrock_guardrail, 96 | guardrail_version="DRAFT", 97 | boto_session=boto_session, 98 | ) 99 | 100 | agent = Agent(model=bedrock_model, system_prompt="You are a helpful assistant.", callback_handler=None) 101 | 102 | response1 = agent("CACTUS") 103 | response2 = agent("Hello!") 104 | 105 | assert response1.stop_reason == "guardrail_intervened" 106 | assert str(response1).strip() == BLOCKED_INPUT 107 | assert response2.stop_reason != "guardrail_intervened" 108 | assert str(response2).strip() != BLOCKED_INPUT 109 | 110 | 111 | @pytest.mark.parametrize("processing_mode", ["sync", "async"]) 112 | def test_guardrail_output_intervention(boto_session, bedrock_guardrail, processing_mode): 113 | bedrock_model = BedrockModel( 114 | guardrail_id=bedrock_guardrail, 115 | guardrail_version="DRAFT", 
116 | guardrail_redact_output=False, 117 | guardrail_stream_processing_mode=processing_mode, 118 | boto_session=boto_session, 119 | ) 120 | 121 | agent = Agent( 122 | model=bedrock_model, 123 | system_prompt="When asked to say the word, say CACTUS.", 124 | callback_handler=None, 125 | load_tools_from_directory=False, 126 | ) 127 | 128 | response1 = agent("Say the word.") 129 | response2 = agent("Hello!") 130 | assert response1.stop_reason == "guardrail_intervened" 131 | assert BLOCKED_OUTPUT in str(response1) 132 | assert response2.stop_reason != "guardrail_intervened" 133 | assert BLOCKED_OUTPUT not in str(response2) 134 | 135 | 136 | @pytest.mark.parametrize("processing_mode", ["sync", "async"]) 137 | def test_guardrail_output_intervention_redact_output(bedrock_guardrail, processing_mode): 138 | REDACT_MESSAGE = "Redacted." 139 | bedrock_model = BedrockModel( 140 | guardrail_id=bedrock_guardrail, 141 | guardrail_version="DRAFT", 142 | guardrail_stream_processing_mode=processing_mode, 143 | guardrail_redact_output=True, 144 | guardrail_redact_output_message=REDACT_MESSAGE, 145 | region_name="us-west-2", 146 | ) 147 | 148 | agent = Agent( 149 | model=bedrock_model, 150 | system_prompt="When asked to say the word, say CACTUS.", 151 | callback_handler=None, 152 | load_tools_from_directory=False, 153 | ) 154 | 155 | response1 = agent("Say the word.") 156 | response2 = agent("Hello!") 157 | assert response1.stop_reason == "guardrail_intervened" 158 | assert REDACT_MESSAGE in str(response1) 159 | assert response2.stop_reason != "guardrail_intervened" 160 | assert REDACT_MESSAGE not in str(response2) 161 | -------------------------------------------------------------------------------- /tests-integ/test_context_overflow.py: -------------------------------------------------------------------------------- 1 | from strands import Agent 2 | from strands.types.content import Messages 3 | 4 | 5 | def test_context_window_overflow(): 6 | messages: Messages = [ 7 | {"role": 
"user", "content": [{"text": "Too much text!" * 100000}]}, 8 | {"role": "assistant", "content": [{"text": "That was a lot of text!"}]}, 9 | ] 10 | 11 | agent = Agent(messages=messages, load_tools_from_directory=False) 12 | agent("Hi!") 13 | assert len(agent.messages) == 2 14 | -------------------------------------------------------------------------------- /tests-integ/test_function_tools.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Test script for function-based tools 4 | """ 5 | 6 | import logging 7 | from typing import Optional 8 | 9 | from strands import Agent, tool 10 | 11 | logging.getLogger("strands").setLevel(logging.DEBUG) 12 | logging.basicConfig(format="%(levelname)s | %(name)s | %(message)s", handlers=[logging.StreamHandler()]) 13 | 14 | 15 | @tool 16 | def word_counter(text: str) -> str: 17 | """ 18 | Count words in text. 19 | 20 | Args: 21 | text: Text to analyze 22 | """ 23 | count = len(text.split()) 24 | return f"Word count: {count}" 25 | 26 | 27 | @tool(name="count_chars", description="Count characters in text") 28 | def count_chars(text: str, include_spaces: Optional[bool] = True) -> str: 29 | """ 30 | Count characters in text. 
31 | 32 | Args: 33 | text: Text to analyze 34 | include_spaces: Whether to include spaces in the count 35 | """ 36 | if not include_spaces: 37 | text = text.replace(" ", "") 38 | return f"Character count: {len(text)}" 39 | 40 | 41 | # Initialize agent with function tools 42 | agent = Agent(tools=[word_counter, count_chars]) 43 | 44 | print("\n===== Testing Direct Tool Access =====") 45 | # Use the tools directly 46 | word_result = agent.tool.word_counter(text="Hello world, this is a test") 47 | print(f"\nWord counter result: {word_result}") 48 | 49 | char_result = agent.tool.count_chars(text="Hello world!", include_spaces=False) 50 | print(f"\nCharacter counter result: {char_result}") 51 | 52 | print("\n===== Testing Natural Language Access =====") 53 | # Use through natural language 54 | nl_result = agent("Count the words in this sentence: 'The quick brown fox jumps over the lazy dog'") 55 | print(f"\nNL Result: {nl_result}") 56 | -------------------------------------------------------------------------------- /tests-integ/test_hot_tool_reload_decorator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Integration test for hot tool reloading functionality with the @tool decorator. 3 | 4 | This test verifies that the Strands Agent can automatically detect and load 5 | new tools created with the @tool decorator when they are added to a tools directory. 6 | """ 7 | 8 | import logging 9 | import os 10 | import time 11 | from pathlib import Path 12 | 13 | from strands import Agent 14 | 15 | logging.getLogger("strands").setLevel(logging.DEBUG) 16 | logging.basicConfig(format="%(levelname)s | %(name)s | %(message)s", handlers=[logging.StreamHandler()]) 17 | 18 | 19 | def test_hot_reload_decorator(): 20 | """ 21 | Test that the Agent automatically loads tools created with @tool decorator 22 | when added to the current working directory's tools folder. 
23 | """ 24 | # Set up the tools directory in current working directory 25 | tools_dir = Path.cwd() / "tools" 26 | os.makedirs(tools_dir, exist_ok=True) 27 | 28 | # Tool path that will need cleanup 29 | test_tool_path = tools_dir / "uppercase.py" 30 | 31 | try: 32 | # Create an Agent instance without any tools 33 | agent = Agent() 34 | 35 | # Create a test tool using @tool decorator 36 | with open(test_tool_path, "w") as f: 37 | f.write(""" 38 | from strands import tool 39 | 40 | @tool 41 | def uppercase(text: str) -> str: 42 | \"\"\"Convert text to uppercase.\"\"\" 43 | return f"Input: {text}, Output: {text.upper()}" 44 | """) 45 | 46 | # Wait for tool detection 47 | time.sleep(3) 48 | 49 | # Verify the tool was automatically loaded 50 | assert "uppercase" in agent.tool_names, "Agent should have detected and loaded the uppercase tool" 51 | 52 | # Test calling the dynamically loaded tool 53 | result = agent.tool.uppercase(text="hello world") 54 | 55 | # Check that the result is successful 56 | assert result.get("status") == "success", "Tool call should be successful" 57 | 58 | # Check the content of the response 59 | content_list = result.get("content", []) 60 | assert len(content_list) > 0, "Tool response should have content" 61 | 62 | # Check that the expected message is in the content 63 | text_content = next((item.get("text") for item in content_list if "text" in item), "") 64 | assert "Input: hello world, Output: HELLO WORLD" in text_content 65 | 66 | finally: 67 | # Clean up - remove the test file 68 | if test_tool_path.exists(): 69 | os.remove(test_tool_path) 70 | 71 | 72 | def test_hot_reload_decorator_update(): 73 | """ 74 | Test that the Agent detects updates to tools created with @tool decorator. 
75 | """ 76 | # Set up the tools directory in current working directory 77 | tools_dir = Path.cwd() / "tools" 78 | os.makedirs(tools_dir, exist_ok=True) 79 | 80 | # Tool path that will need cleanup - make sure filename matches function name 81 | test_tool_path = tools_dir / "greeting.py" 82 | 83 | try: 84 | # Create an Agent instance 85 | agent = Agent() 86 | 87 | # Create the initial version of the tool 88 | with open(test_tool_path, "w") as f: 89 | f.write(""" 90 | from strands import tool 91 | 92 | @tool 93 | def greeting(name: str) -> str: 94 | \"\"\"Generate a simple greeting.\"\"\" 95 | return f"Hello, {name}!" 96 | """) 97 | 98 | # Wait for tool detection 99 | time.sleep(3) 100 | 101 | # Verify the tool was loaded 102 | assert "greeting" in agent.tool_names, "Agent should have detected and loaded the greeting tool" 103 | 104 | # Test calling the tool 105 | result1 = agent.tool.greeting(name="Strands") 106 | text_content1 = next((item.get("text") for item in result1.get("content", []) if "text" in item), "") 107 | assert "Hello, Strands!" in text_content1, "Tool should return simple greeting" 108 | 109 | # Update the tool with new functionality 110 | with open(test_tool_path, "w") as f: 111 | f.write(""" 112 | from strands import tool 113 | import datetime 114 | 115 | @tool 116 | def greeting(name: str, formal: bool = False) -> str: 117 | \"\"\"Generate a greeting with optional formality.\"\"\" 118 | current_hour = datetime.datetime.now().hour 119 | time_of_day = "morning" if current_hour < 12 else "afternoon" if current_hour < 18 else "evening" 120 | 121 | if formal: 122 | return f"Good {time_of_day}, {name}. It's a pleasure to meet you." 123 | else: 124 | return f"Hey {name}! How's your {time_of_day} going?" 
125 | """) 126 | 127 | # Wait for hot reload to detect the change 128 | time.sleep(3) 129 | 130 | # Test calling the updated tool 131 | result2 = agent.tool.greeting(name="Strands", formal=True) 132 | text_content2 = next((item.get("text") for item in result2.get("content", []) if "text" in item), "") 133 | assert "Good" in text_content2 and "Strands" in text_content2 and "pleasure to meet you" in text_content2 134 | 135 | # Test with informal parameter 136 | result3 = agent.tool.greeting(name="Strands", formal=False) 137 | text_content3 = next((item.get("text") for item in result3.get("content", []) if "text" in item), "") 138 | assert "Hey Strands!" in text_content3 and "going" in text_content3 139 | 140 | finally: 141 | # Clean up - remove the test file 142 | if test_tool_path.exists(): 143 | os.remove(test_tool_path) 144 | -------------------------------------------------------------------------------- /tests-integ/test_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/strands-agents/sdk-python/a64e80dfa5f5f6053a3fe53767f59fc3c5c1af95/tests-integ/test_image.png -------------------------------------------------------------------------------- /tests-integ/test_mcp_client.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import threading 3 | import time 4 | from typing import List, Literal 5 | 6 | from mcp import StdioServerParameters, stdio_client 7 | from mcp.client.sse import sse_client 8 | from mcp.client.streamable_http import streamablehttp_client 9 | from mcp.types import ImageContent as MCPImageContent 10 | 11 | from strands import Agent 12 | from strands.tools.mcp.mcp_client import MCPClient 13 | from strands.tools.mcp.mcp_types import MCPTransport 14 | from strands.types.content import Message 15 | from strands.types.tools import ToolUse 16 | 17 | 18 | def start_calculator_server(transport: Literal["sse", 
"streamable-http"], port: int): 19 | """ 20 | Initialize and start an MCP calculator server for integration testing. 21 | 22 | This function creates a FastMCP server instance that provides a simple 23 | calculator tool for performing addition operations. The server uses 24 | the given transport (SSE or streamable HTTP), making it accessible 25 | over HTTP. 26 | """ 27 | from mcp.server import FastMCP 28 | 29 | mcp = FastMCP("Calculator Server", port=port) 30 | 31 | @mcp.tool(description="Calculator tool which performs calculations") 32 | def calculator(x: int, y: int) -> int: 33 | return x + y 34 | 35 | @mcp.tool(description="Generates a custom image") 36 | def generate_custom_image() -> MCPImageContent: 37 | try: 38 | with open("tests-integ/test_image.png", "rb") as image_file: 39 | encoded_image = base64.b64encode(image_file.read()) 40 | return MCPImageContent(type="image", data=encoded_image, mimeType="image/png") 41 | except Exception as e: 42 | print("Error while generating custom image: {}".format(e)) 43 | 44 | mcp.run(transport=transport) 45 | 46 | 47 | def test_mcp_client(): 48 | """ 49 | Test should yield output similar to the following 50 | {'role': 'user', 'content': [{'text': 'add 1 and 2, then echo the result back to me'}]} 51 | {'role': 'assistant', 'content': [{'text': "I'll help you add 1 and 2 and then echo the result back to you.\n\nFirst, I'll calculate 1 + 2:"}, {'toolUse': {'toolUseId': 'tooluse_17ptaKUxQB20ySZxwgiI_w', 'name': 'calculator', 'input': {'x': 1, 'y': 2}}}]} 52 | {'role': 'user', 'content': [{'toolResult': {'status': 'success', 'toolUseId': 'tooluse_17ptaKUxQB20ySZxwgiI_w', 'content': [{'text': '3'}]}}]} 53 | {'role': 'assistant', 'content': [{'text': "\n\nNow I'll echo the result back to you:"}, {'toolUse': {'toolUseId': 'tooluse_GlOc5SN8TE6ti8jVZJMBOg', 'name': 'echo', 'input': {'to_echo': '3'}}}]} 54 | {'role': 'user', 'content': [{'toolResult': {'status': 'success', 'toolUseId': 'tooluse_GlOc5SN8TE6ti8jVZJMBOg', 
'content': [{'text': '3'}]}}]} 55 | {'role': 'assistant', 'content': [{'text': '\n\nThe result of adding 1 and 2 is 3.'}]} 56 | """ # noqa: E501 57 | 58 | server_thread = threading.Thread( 59 | target=start_calculator_server, kwargs={"transport": "sse", "port": 8000}, daemon=True 60 | ) 61 | server_thread.start() 62 | time.sleep(2) # wait for server to startup completely 63 | 64 | sse_mcp_client = MCPClient(lambda: sse_client("http://127.0.0.1:8000/sse")) 65 | stdio_mcp_client = MCPClient( 66 | lambda: stdio_client(StdioServerParameters(command="python", args=["tests-integ/echo_server.py"])) 67 | ) 68 | with sse_mcp_client, stdio_mcp_client: 69 | agent = Agent(tools=sse_mcp_client.list_tools_sync() + stdio_mcp_client.list_tools_sync()) 70 | agent("add 1 and 2, then echo the result back to me") 71 | 72 | tool_use_content_blocks = _messages_to_content_blocks(agent.messages) 73 | assert any([block["name"] == "echo" for block in tool_use_content_blocks]) 74 | assert any([block["name"] == "calculator" for block in tool_use_content_blocks]) 75 | 76 | image_prompt = """ 77 | Generate a custom image, then tell me if the image is red, blue, yellow, pink, orange, or green. 
78 | RESPOND ONLY WITH THE COLOR 79 | """ 80 | assert any( 81 | [ 82 | "yellow".casefold() in block["text"].casefold() 83 | for block in agent(image_prompt).message["content"] 84 | if "text" in block 85 | ] 86 | ) 87 | 88 | 89 | def test_can_reuse_mcp_client(): 90 | stdio_mcp_client = MCPClient( 91 | lambda: stdio_client(StdioServerParameters(command="python", args=["tests-integ/echo_server.py"])) 92 | ) 93 | with stdio_mcp_client: 94 | stdio_mcp_client.list_tools_sync() 95 | pass 96 | with stdio_mcp_client: 97 | agent = Agent(tools=stdio_mcp_client.list_tools_sync()) 98 | agent("echo the following to me DOG") 99 | 100 | tool_use_content_blocks = _messages_to_content_blocks(agent.messages) 101 | assert any([block["name"] == "echo" for block in tool_use_content_blocks]) 102 | 103 | 104 | def test_streamable_http_mcp_client(): 105 | server_thread = threading.Thread( 106 | target=start_calculator_server, kwargs={"transport": "streamable-http", "port": 8001}, daemon=True 107 | ) 108 | server_thread.start() 109 | time.sleep(2) # wait for server to startup completely 110 | 111 | def transport_callback() -> MCPTransport: 112 | return streamablehttp_client(url="http://127.0.0.1:8001/mcp") 113 | 114 | streamable_http_client = MCPClient(transport_callback) 115 | with streamable_http_client: 116 | agent = Agent(tools=streamable_http_client.list_tools_sync()) 117 | agent("add 1 and 2 using a calculator") 118 | 119 | tool_use_content_blocks = _messages_to_content_blocks(agent.messages) 120 | assert any([block["name"] == "calculator" for block in tool_use_content_blocks]) 121 | 122 | 123 | def _messages_to_content_blocks(messages: List[Message]) -> List[ToolUse]: 124 | return [block["toolUse"] for message in messages for block in message["content"] if "toolUse" in block] 125 | -------------------------------------------------------------------------------- /tests-integ/test_model_anthropic.py: -------------------------------------------------------------------------------- 1 | 
import os 2 | 3 | import pytest 4 | 5 | import strands 6 | from strands import Agent 7 | from strands.models.anthropic import AnthropicModel 8 | 9 | 10 | @pytest.fixture 11 | def model(): 12 | return AnthropicModel( 13 | client_args={ 14 | "api_key": os.getenv("ANTHROPIC_API_KEY"), 15 | }, 16 | model_id="claude-3-7-sonnet-20250219", 17 | max_tokens=512, 18 | ) 19 | 20 | 21 | @pytest.fixture 22 | def tools(): 23 | @strands.tool 24 | def tool_time() -> str: 25 | return "12:00" 26 | 27 | @strands.tool 28 | def tool_weather() -> str: 29 | return "sunny" 30 | 31 | return [tool_time, tool_weather] 32 | 33 | 34 | @pytest.fixture 35 | def system_prompt(): 36 | return "You are an AI assistant that uses & instead of ." 37 | 38 | 39 | @pytest.fixture 40 | def agent(model, tools, system_prompt): 41 | return Agent(model=model, tools=tools, system_prompt=system_prompt) 42 | 43 | 44 | @pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY environment variable missing") 45 | def test_agent(agent): 46 | result = agent("What is the time and weather in New York?") 47 | text = result.message["content"][0]["text"].lower() 48 | 49 | assert all(string in text for string in ["12:00", "sunny", "&"]) 50 | -------------------------------------------------------------------------------- /tests-integ/test_model_bedrock.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import strands 4 | from strands import Agent 5 | from strands.models import BedrockModel 6 | 7 | 8 | @pytest.fixture 9 | def system_prompt(): 10 | return "You are an AI assistant that uses & instead of ." 
11 | 12 | 13 | @pytest.fixture 14 | def streaming_model(): 15 | return BedrockModel( 16 | model_id="us.anthropic.claude-3-7-sonnet-20250219-v1:0", 17 | streaming=True, 18 | ) 19 | 20 | 21 | @pytest.fixture 22 | def non_streaming_model(): 23 | return BedrockModel( 24 | model_id="us.meta.llama3-2-90b-instruct-v1:0", 25 | streaming=False, 26 | ) 27 | 28 | 29 | @pytest.fixture 30 | def streaming_agent(streaming_model, system_prompt): 31 | return Agent(model=streaming_model, system_prompt=system_prompt, load_tools_from_directory=False) 32 | 33 | 34 | @pytest.fixture 35 | def non_streaming_agent(non_streaming_model, system_prompt): 36 | return Agent(model=non_streaming_model, system_prompt=system_prompt, load_tools_from_directory=False) 37 | 38 | 39 | def test_streaming_agent(streaming_agent): 40 | """Test agent with streaming model.""" 41 | result = streaming_agent("Hello!") 42 | 43 | assert len(str(result)) > 0 44 | 45 | 46 | def test_non_streaming_agent(non_streaming_agent): 47 | """Test agent with non-streaming model.""" 48 | result = non_streaming_agent("Hello!") 49 | 50 | assert len(str(result)) > 0 51 | 52 | 53 | def test_streaming_model_events(streaming_model): 54 | """Test streaming model events.""" 55 | messages = [{"role": "user", "content": [{"text": "Hello"}]}] 56 | 57 | # Call converse and collect events 58 | events = list(streaming_model.converse(messages)) 59 | 60 | # Verify basic structure of events 61 | assert any("messageStart" in event for event in events) 62 | assert any("contentBlockDelta" in event for event in events) 63 | assert any("messageStop" in event for event in events) 64 | 65 | 66 | def test_non_streaming_model_events(non_streaming_model): 67 | """Test non-streaming model events.""" 68 | messages = [{"role": "user", "content": [{"text": "Hello"}]}] 69 | 70 | # Call converse and collect events 71 | events = list(non_streaming_model.converse(messages)) 72 | 73 | # Verify basic structure of events 74 | assert any("messageStart" in event for 
event in events) 75 | assert any("contentBlockDelta" in event for event in events) 76 | assert any("messageStop" in event for event in events) 77 | 78 | 79 | def test_tool_use_streaming(streaming_model): 80 | """Test tool use with streaming model.""" 81 | 82 | tool_was_called = False 83 | 84 | @strands.tool 85 | def calculator(expression: str) -> float: 86 | """Calculate the result of a mathematical expression.""" 87 | 88 | nonlocal tool_was_called 89 | tool_was_called = True 90 | return eval(expression) 91 | 92 | agent = Agent(model=streaming_model, tools=[calculator], load_tools_from_directory=False) 93 | result = agent("What is 123 + 456?") 94 | 95 | # Print the full message content for debugging 96 | print("\nFull message content:") 97 | import json 98 | 99 | print(json.dumps(result.message["content"], indent=2)) 100 | 101 | assert tool_was_called 102 | 103 | 104 | def test_tool_use_non_streaming(non_streaming_model): 105 | """Test tool use with non-streaming model.""" 106 | 107 | tool_was_called = False 108 | 109 | @strands.tool 110 | def calculator(expression: str) -> float: 111 | """Calculate the result of a mathematical expression.""" 112 | 113 | nonlocal tool_was_called 114 | tool_was_called = True 115 | return eval(expression) 116 | 117 | agent = Agent(model=non_streaming_model, tools=[calculator], load_tools_from_directory=False) 118 | agent("What is 123 + 456?") 119 | 120 | assert tool_was_called 121 | -------------------------------------------------------------------------------- /tests-integ/test_model_litellm.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import strands 4 | from strands import Agent 5 | from strands.models.litellm import LiteLLMModel 6 | 7 | 8 | @pytest.fixture 9 | def model(): 10 | return LiteLLMModel(model_id="us.anthropic.claude-3-7-sonnet-20250219-v1:0") 11 | 12 | 13 | @pytest.fixture 14 | def tools(): 15 | @strands.tool 16 | def tool_time() -> str: 17 | return 
"12:00" 18 | 19 | @strands.tool 20 | def tool_weather() -> str: 21 | return "sunny" 22 | 23 | return [tool_time, tool_weather] 24 | 25 | 26 | @pytest.fixture 27 | def agent(model, tools): 28 | return Agent(model=model, tools=tools) 29 | 30 | 31 | def test_agent(agent): 32 | result = agent("What is the time and weather in New York?") 33 | text = result.message["content"][0]["text"].lower() 34 | 35 | assert all(string in text for string in ["12:00", "sunny"]) 36 | -------------------------------------------------------------------------------- /tests-integ/test_model_llamaapi.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates 2 | import os 3 | 4 | import pytest 5 | 6 | import strands 7 | from strands import Agent 8 | from strands.models.llamaapi import LlamaAPIModel 9 | 10 | 11 | @pytest.fixture 12 | def model(): 13 | return LlamaAPIModel( 14 | model_id="Llama-4-Maverick-17B-128E-Instruct-FP8", 15 | client_args={ 16 | "api_key": os.getenv("LLAMA_API_KEY"), 17 | }, 18 | ) 19 | 20 | 21 | @pytest.fixture 22 | def tools(): 23 | @strands.tool 24 | def tool_time() -> str: 25 | return "12:00" 26 | 27 | @strands.tool 28 | def tool_weather() -> str: 29 | return "sunny" 30 | 31 | return [tool_time, tool_weather] 32 | 33 | 34 | @pytest.fixture 35 | def agent(model, tools): 36 | return Agent(model=model, tools=tools) 37 | 38 | 39 | @pytest.mark.skipif( 40 | "LLAMA_API_KEY" not in os.environ, 41 | reason="LLAMA_API_KEY environment variable missing", 42 | ) 43 | def test_agent(agent): 44 | result = agent("What is the time and weather in New York?") 45 | text = result.message["content"][0]["text"].lower() 46 | 47 | assert all(string in text for string in ["12:00", "sunny"]) 48 | -------------------------------------------------------------------------------- /tests-integ/test_model_openai.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 
class ToolCountingCallbackHandler:
    """Callback handler that counts tool-use and message events while echoing them to stdout."""

    def __init__(self):
        # Running totals; each is bumped once per callback invocation that carries the matching event.
        self.tool_count = 0
        self.message_count = 0

    def callback_handler(self, **kwargs) -> None:
        """
        Custom callback handler that processes and displays different types of events.

        Args:
            **kwargs: Callback event data including:
                - data: Regular output
                - complete: Completion status
                - message: Message processing
                - current_tool_use: Tool execution
        """
        # Extract event data
        data = kwargs.get("data", "")
        complete = kwargs.get("complete", False)
        message = kwargs.get("message", {})
        current_tool_use = kwargs.get("current_tool_use", {})

        # Handle regular data output
        if data:
            print(f"🔄 Data: {data}")

        # Handle tool execution events
        if current_tool_use:
            self.tool_count += 1
            tool_name = current_tool_use.get("name", "")
            tool_input = current_tool_use.get("input", {})
            print(f"🛠️ Tool Execution #{self.tool_count}\nTool: {tool_name}\nInput: {tool_input}")

        # Handle message processing
        if message:
            self.message_count += 1
            print(f"📝 Message #{self.message_count}")

        # Handle completion
        if complete:
            # This class defines no ``console`` attribute, so report completion with print()
            # like every other branch instead of raising AttributeError.
            print("✨ Callback Complete")
@pytest.fixture
def boto3_profile_path(boto3_profile, tmp_path, monkeypatch):
    """Persist the fake profile under ``tmp_path`` and export its location.

    The profile config is written to ``<tmp_path>/.aws/credentials`` and that path is
    published through the ``AWS_SHARED_CREDENTIALS_FILE`` environment variable.

    Returns:
        The path of the credentials file that was written.
    """
    credentials_file = tmp_path / ".aws/credentials"
    credentials_file.parent.mkdir(exist_ok=True)

    with credentials_file.open("w") as handle:
        boto3_profile.write(handle)

    monkeypatch.setenv("AWS_SHARED_CREDENTIALS_FILE", str(credentials_file))
    return credentials_file
@pytest.fixture
def complex_message():
    """Assistant message with three text blocks and one block that has no ``text`` key."""
    blocks = [
        {"text": "First paragraph"},
        {"text": "Second paragraph"},
        {"non_text_content": "This should be ignored"},
        {"text": "Third paragraph"},
    ]
    return {"role": "assistant", "content": blocks}
def test__str__non_dict_content(mock_metrics):
    """str() skips a content entry that is not a dict and still renders the surrounding text blocks."""
    content = [{"text": "Valid text"}, "Not a dictionary", {"text": "More valid text"}]
    # cast() only silences the type checker; the plain string entry is deliberately malformed.
    message_with_non_dict = cast(Message, {"role": "assistant", "content": content})

    result = AgentResult(stop_reason="end_turn", message=message_with_non_dict, metrics=mock_metrics, state={})

    assert str(result) == "Valid text\nMore valid text\n"
@pytest.mark.parametrize("event_stream_error", ["Input is too long for requested model"], indirect=True)
def test_handle_input_too_long_error(
    sdk_event_loop,
    event_stream_error,
    model,
    system_prompt,
    tool_config,
    callback_handler,
    tool_handler,
    kwargs,
):
    """An input-too-long error with a tool result present rewrites that result and recurses into the loop."""
    sdk_event_loop.return_value = "success"

    oversized_result = {"toolUseId": "t1", "status": "success", "content": [{"text": "needs truncation"}]}
    messages = [{"role": "user", "content": [{"toolResult": oversized_result}]}]

    tru_result = strands.event_loop.error_handler.handle_input_too_long_error(
        event_stream_error,
        messages,
        model,
        system_prompt,
        tool_config,
        callback_handler,
        tool_handler,
        kwargs,
    )

    # ``messages`` itself is inspected afterwards: the expectation is an in-place rewrite of the tool result.
    replaced_result = {
        "toolUseId": "t1",
        "status": "error",
        "content": [{"text": "The tool result was too large!"}],
    }
    assert tru_result == "success"
    assert messages == [{"role": "user", "content": [{"toolResult": replaced_result}]}]

    # The patched recursion target receives the original arguments plus the expanded request kwargs.
    sdk_event_loop.assert_called_once_with(
        model=model,
        system_prompt=system_prompt,
        messages=messages,
        tool_config=tool_config,
        callback_handler=callback_handler,
        tool_handler=tool_handler,
        request_state="value",
    )

    callback_handler.assert_not_called()
@pytest.mark.parametrize(
    "messages,expected,expected_messages",
    [
        # Orphaned toolUse with empty input, no toolResult
        (
            [
                {"role": "assistant", "content": [{"toolUse": {"toolUseId": "1", "input": {}, "name": "foo"}}]},
                {"role": "user", "content": [{"toolResult": {"toolUseId": "2"}}]},
            ],
            True,
            [
                {"role": "assistant", "content": [{"text": "[Attempted to use foo, but operation was canceled]"}]},
                {"role": "user", "content": [{"toolResult": {"toolUseId": "2"}}]},
            ],
        ),
        # toolUse with input, has matching toolResult
        (
            [
                {"role": "assistant", "content": [{"toolUse": {"toolUseId": "1", "input": {"a": 1}, "name": "foo"}}]},
                {"role": "user", "content": [{"toolResult": {"toolUseId": "1"}}]},
            ],
            False,
            [
                {"role": "assistant", "content": [{"toolUse": {"toolUseId": "1", "input": {"a": 1}, "name": "foo"}}]},
                {"role": "user", "content": [{"toolResult": {"toolUseId": "1"}}]},
            ],
        ),
        # No messages
        (
            [],
            False,
            [],
        ),
    ],
)
def test_clean_orphaned_empty_tool_uses(messages, expected, expected_messages):
    """Check both the returned flag and the resulting message list for each scenario."""
    # Work on a deep copy so the parametrized input is never touched by the function under test.
    working_copy = copy.deepcopy(messages)

    was_cleaned = message_processor.clean_orphaned_empty_tool_uses(working_copy)

    assert was_cleaned == expected
    assert working_copy == expected_messages
result"}], 106 | ), 107 | ( 108 | [], 109 | 0, 110 | False, 111 | [], 112 | ), 113 | ( 114 | [{"role": "user", "content": [{"toolResult": {"toolUseId": "1"}}]}], 115 | 2, 116 | False, 117 | [{"toolResult": {"toolUseId": "1"}}], 118 | ), 119 | ], 120 | ) 121 | def test_truncate_tool_results(messages, msg_idx, expected_changed, expected_content): 122 | test_messages = copy.deepcopy(messages) 123 | changed = message_processor.truncate_tool_results(test_messages, msg_idx) 124 | assert changed == expected_changed 125 | if 0 <= msg_idx < len(test_messages): 126 | assert test_messages[msg_idx]["content"] == expected_content 127 | else: 128 | assert test_messages == messages 129 | -------------------------------------------------------------------------------- /tests/strands/handlers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/strands-agents/sdk-python/a64e80dfa5f5f6053a3fe53767f59fc3c5c1af95/tests/strands/handlers/__init__.py -------------------------------------------------------------------------------- /tests/strands/handlers/test_callback_handler.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for the SDK callback handler module. 3 | 4 | These tests ensure the basic print-based callback handler in the SDK functions correctly. 
def test_call_with_reasoningText_and_data(handler, mock_print):
    """Passing reasoningText together with data prints the reasoning first, then the data."""
    handler(reasoningText="Reasoning", data="Output")

    # Both pieces are expected to be printed without a trailing newline, in this order.
    expected_calls = [unittest.mock.call(text, end="") for text in ("Reasoning", "Output")]
    mock_print.assert_has_calls(expected_calls)
def test_call_with_current_tool_use_different(handler, mock_print):
    """A second, different tool use is announced separately and bumps the tool counter."""
    tool_uses = [
        {"name": "first_tool", "input": {"param": "value1"}},
        {"name": "second_tool", "input": {"param": "value2"}},
    ]

    # Feed the first tool use, then discard its print calls so only the second one is checked.
    handler(current_tool_use=tool_uses[0])
    mock_print.reset_mock()

    handler(current_tool_use=tool_uses[1])

    # Only the new tool is reported, numbered as the second tool seen.
    mock_print.assert_called_once_with("\nTool #2: second_tool")

    # Handler state reflects the latest tool use.
    assert handler.tool_count == 2
    assert handler.previous_tool_use == tool_uses[1]
def test_composite_handler_forwards_to_all_handlers():
    """Test that calling the handler forwards the call to all handlers."""
    mock_handlers = [unittest.mock.Mock() for _ in range(3)]
    composite_handler = CompositeCallbackHandler(*mock_handlers)

    # Create test arguments
    kwargs = {
        "data": "Test output",
        "complete": True,
        "current_tool_use": {"name": "test_tool", "input": {"param": "value"}},
    }

    # Call the composite handler
    composite_handler(**kwargs)

    # Verify each handler was called exactly once with the same arguments.
    # The loop variable is deliberately not named ``handler`` so it does not reuse the fixture's name.
    for mock_handler in mock_handlers:
        mock_handler.assert_called_once_with(**kwargs)
def test_process_missing_tool(tool_handler):
    """Processing a tool use that was never registered in this test yields an error result naming it."""
    missing_tool_use = {"toolUseId": "missing", "name": "missing", "input": {}}

    tru_result = tool_handler.process(
        tool=missing_tool_use,
        model=unittest.mock.Mock(),
        system_prompt="p1",
        messages=[],
        tool_config={},
        callback_handler=unittest.mock.Mock(),
    )

    assert tru_result == {
        "toolUseId": "missing",
        "status": "error",
        "content": [{"text": "Unknown tool: missing"}],
    }
@pytest.mark.parametrize(
    "content, exp_result",
    [
        # Case 1: Thinking
        (
            {
                "reasoningContent": {
                    "reasoningText": {
                        "signature": "reasoning_signature",
                        "text": "reasoning_text",
                    },
                },
            },
            {
                "signature": "reasoning_signature",
                "thinking": "reasoning_text",
                "type": "thinking",
            },
        ),
        # Case 2: Video
        (
            {
                "video": {
                    "source": {"bytes": "base64encodedvideo"},
                },
            },
            {
                "type": "video_url",
                "video_url": {
                    "detail": "auto",
                    "url": "base64encodedvideo",
                },
            },
        ),
        # Case 3: Text
        (
            {"text": "hello"},
            {"type": "text", "text": "hello"},
        ),
    ],
)
def test_format_request_message_content(content, exp_result):
    """Each content block shape is converted into the expected request payload fragment."""
    # Called on the class itself, so no model instance or patched client is needed here.
    assert LiteLLMModel.format_request_message_content(content) == exp_result
tru_model_id = model.get_config().get("model_id") 57 | exp_model_id = model_id 58 | 59 | assert tru_model_id == exp_model_id 60 | 61 | 62 | def test_stream(openai_client, model): 63 | mock_tool_call_1_part_1 = unittest.mock.Mock(index=0) 64 | mock_tool_call_2_part_1 = unittest.mock.Mock(index=1) 65 | mock_delta_1 = unittest.mock.Mock( 66 | content="I'll calculate", tool_calls=[mock_tool_call_1_part_1, mock_tool_call_2_part_1] 67 | ) 68 | 69 | mock_tool_call_1_part_2 = unittest.mock.Mock(index=0) 70 | mock_tool_call_2_part_2 = unittest.mock.Mock(index=1) 71 | mock_delta_2 = unittest.mock.Mock( 72 | content="that for you", tool_calls=[mock_tool_call_1_part_2, mock_tool_call_2_part_2] 73 | ) 74 | 75 | mock_delta_3 = unittest.mock.Mock(content="", tool_calls=None) 76 | 77 | mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_1)]) 78 | mock_event_2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_2)]) 79 | mock_event_3 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="tool_calls", delta=mock_delta_3)]) 80 | mock_event_4 = unittest.mock.Mock() 81 | 82 | openai_client.chat.completions.create.return_value = iter([mock_event_1, mock_event_2, mock_event_3, mock_event_4]) 83 | 84 | request = {"model": "m1", "messages": [{"role": "user", "content": [{"type": "text", "text": "calculate 2+2"}]}]} 85 | response = model.stream(request) 86 | 87 | tru_events = list(response) 88 | exp_events = [ 89 | {"chunk_type": "message_start"}, 90 | {"chunk_type": "content_start", "data_type": "text"}, 91 | {"chunk_type": "content_delta", "data_type": "text", "data": "I'll calculate"}, 92 | {"chunk_type": "content_delta", "data_type": "text", "data": "that for you"}, 93 | {"chunk_type": "content_stop", "data_type": "text"}, 94 | {"chunk_type": "content_start", "data_type": "tool", "data": mock_tool_call_1_part_1}, 95 | {"chunk_type": "content_delta", "data_type": "tool", "data": 
mock_tool_call_1_part_1}, 96 | {"chunk_type": "content_delta", "data_type": "tool", "data": mock_tool_call_1_part_2}, 97 | {"chunk_type": "content_stop", "data_type": "tool"}, 98 | {"chunk_type": "content_start", "data_type": "tool", "data": mock_tool_call_2_part_1}, 99 | {"chunk_type": "content_delta", "data_type": "tool", "data": mock_tool_call_2_part_1}, 100 | {"chunk_type": "content_delta", "data_type": "tool", "data": mock_tool_call_2_part_2}, 101 | {"chunk_type": "content_stop", "data_type": "tool"}, 102 | {"chunk_type": "message_stop", "data": "tool_calls"}, 103 | {"chunk_type": "metadata", "data": mock_event_4.usage}, 104 | ] 105 | 106 | assert tru_events == exp_events 107 | openai_client.chat.completions.create.assert_called_once_with(**request) 108 | 109 | 110 | def test_stream_empty(openai_client, model): 111 | mock_delta = unittest.mock.Mock(content=None, tool_calls=None) 112 | mock_usage = unittest.mock.Mock(prompt_tokens=0, completion_tokens=0, total_tokens=0) 113 | 114 | mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta)]) 115 | mock_event_2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="stop", delta=mock_delta)]) 116 | mock_event_3 = unittest.mock.Mock() 117 | mock_event_4 = unittest.mock.Mock(usage=mock_usage) 118 | 119 | openai_client.chat.completions.create.return_value = iter([mock_event_1, mock_event_2, mock_event_3, mock_event_4]) 120 | 121 | request = {"model": "m1", "messages": [{"role": "user", "content": []}]} 122 | response = model.stream(request) 123 | 124 | tru_events = list(response) 125 | exp_events = [ 126 | {"chunk_type": "message_start"}, 127 | {"chunk_type": "content_start", "data_type": "text"}, 128 | {"chunk_type": "content_stop", "data_type": "text"}, 129 | {"chunk_type": "message_stop", "data": "stop"}, 130 | {"chunk_type": "metadata", "data": mock_usage}, 131 | ] 132 | 133 | assert tru_events == exp_events 134 | 
openai_client.chat.completions.create.assert_called_once_with(**request) 135 | 136 | 137 | def test_stream_with_empty_choices(openai_client, model): 138 | mock_delta = unittest.mock.Mock(content="content", tool_calls=None) 139 | mock_usage = unittest.mock.Mock(prompt_tokens=10, completion_tokens=20, total_tokens=30) 140 | 141 | # Event with no choices attribute 142 | mock_event_1 = unittest.mock.Mock(spec=[]) 143 | 144 | # Event with empty choices list 145 | mock_event_2 = unittest.mock.Mock(choices=[]) 146 | 147 | # Valid event with content 148 | mock_event_3 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta)]) 149 | 150 | # Event with finish reason 151 | mock_event_4 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="stop", delta=mock_delta)]) 152 | 153 | # Final event with usage info 154 | mock_event_5 = unittest.mock.Mock(usage=mock_usage) 155 | 156 | openai_client.chat.completions.create.return_value = iter( 157 | [mock_event_1, mock_event_2, mock_event_3, mock_event_4, mock_event_5] 158 | ) 159 | 160 | request = {"model": "m1", "messages": [{"role": "user", "content": ["test"]}]} 161 | response = model.stream(request) 162 | 163 | tru_events = list(response) 164 | exp_events = [ 165 | {"chunk_type": "message_start"}, 166 | {"chunk_type": "content_start", "data_type": "text"}, 167 | {"chunk_type": "content_delta", "data_type": "text", "data": "content"}, 168 | {"chunk_type": "content_delta", "data_type": "text", "data": "content"}, 169 | {"chunk_type": "content_stop", "data_type": "text"}, 170 | {"chunk_type": "message_stop", "data": "stop"}, 171 | {"chunk_type": "metadata", "data": mock_usage}, 172 | ] 173 | 174 | assert tru_events == exp_events 175 | openai_client.chat.completions.create.assert_called_once_with(**request) 176 | -------------------------------------------------------------------------------- /tests/strands/tools/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/strands-agents/sdk-python/a64e80dfa5f5f6053a3fe53767f59fc3c5c1af95/tests/strands/tools/__init__.py -------------------------------------------------------------------------------- /tests/strands/tools/mcp/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/strands-agents/sdk-python/a64e80dfa5f5f6053a3fe53767f59fc3c5c1af95/tests/strands/tools/mcp/__init__.py -------------------------------------------------------------------------------- /tests/strands/tools/mcp/test_mcp_agent_tool.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import MagicMock 2 | 3 | import pytest 4 | from mcp.types import Tool as MCPTool 5 | 6 | from strands.tools.mcp import MCPAgentTool, MCPClient 7 | 8 | 9 | @pytest.fixture 10 | def mock_mcp_tool(): 11 | mock_tool = MagicMock(spec=MCPTool) 12 | mock_tool.name = "test_tool" 13 | mock_tool.description = "A test tool" 14 | mock_tool.inputSchema = {"type": "object", "properties": {}} 15 | return mock_tool 16 | 17 | 18 | @pytest.fixture 19 | def mock_mcp_client(): 20 | mock_server = MagicMock(spec=MCPClient) 21 | mock_server.call_tool_sync.return_value = { 22 | "status": "success", 23 | "toolUseId": "test-123", 24 | "content": [{"text": "Success result"}], 25 | } 26 | return mock_server 27 | 28 | 29 | @pytest.fixture 30 | def mcp_agent_tool(mock_mcp_tool, mock_mcp_client): 31 | return MCPAgentTool(mock_mcp_tool, mock_mcp_client) 32 | 33 | 34 | def test_tool_name(mcp_agent_tool, mock_mcp_tool): 35 | assert mcp_agent_tool.tool_name == "test_tool" 36 | assert mcp_agent_tool.tool_name == mock_mcp_tool.name 37 | 38 | 39 | def test_tool_type(mcp_agent_tool): 40 | assert mcp_agent_tool.tool_type == "python" 41 | 42 | 43 | def test_tool_spec_with_description(mcp_agent_tool, mock_mcp_tool): 44 | tool_spec = 
mcp_agent_tool.tool_spec 45 | 46 | assert tool_spec["name"] == "test_tool" 47 | assert tool_spec["description"] == "A test tool" 48 | assert tool_spec["inputSchema"]["json"] == {"type": "object", "properties": {}} 49 | 50 | 51 | def test_tool_spec_without_description(mock_mcp_tool, mock_mcp_client): 52 | mock_mcp_tool.description = None 53 | 54 | agent_tool = MCPAgentTool(mock_mcp_tool, mock_mcp_client) 55 | tool_spec = agent_tool.tool_spec 56 | 57 | assert tool_spec["description"] == "Tool which performs test_tool" 58 | 59 | 60 | def test_invoke(mcp_agent_tool, mock_mcp_client): 61 | tool_use = {"toolUseId": "test-123", "name": "test_tool", "input": {"param": "value"}} 62 | 63 | result = mcp_agent_tool.invoke(tool_use) 64 | 65 | mock_mcp_client.call_tool_sync.assert_called_once_with( 66 | tool_use_id="test-123", name="test_tool", arguments={"param": "value"} 67 | ) 68 | assert result == mock_mcp_client.call_tool_sync.return_value 69 | -------------------------------------------------------------------------------- /tests/strands/tools/mcp/test_mcp_client.py: -------------------------------------------------------------------------------- 1 | import time 2 | from unittest.mock import AsyncMock, MagicMock, patch 3 | 4 | import pytest 5 | from mcp import ListToolsResult 6 | from mcp.types import CallToolResult as MCPCallToolResult 7 | from mcp.types import TextContent as MCPTextContent 8 | from mcp.types import Tool as MCPTool 9 | 10 | from strands.tools.mcp import MCPClient 11 | from strands.types.exceptions import MCPClientInitializationError 12 | 13 | 14 | @pytest.fixture 15 | def mock_transport(): 16 | mock_read_stream = AsyncMock() 17 | mock_write_stream = AsyncMock() 18 | mock_transport_cm = AsyncMock() 19 | mock_transport_cm.__aenter__.return_value = (mock_read_stream, mock_write_stream) 20 | mock_transport_callable = MagicMock(return_value=mock_transport_cm) 21 | 22 | return { 23 | "read_stream": mock_read_stream, 24 | "write_stream": mock_write_stream, 25 | 
"transport_cm": mock_transport_cm, 26 | "transport_callable": mock_transport_callable, 27 | } 28 | 29 | 30 | @pytest.fixture 31 | def mock_session(): 32 | mock_session = AsyncMock() 33 | mock_session.initialize = AsyncMock() 34 | 35 | # Create a mock context manager for ClientSession 36 | mock_session_cm = AsyncMock() 37 | mock_session_cm.__aenter__.return_value = mock_session 38 | 39 | # Patch ClientSession to return our mock session 40 | with patch("strands.tools.mcp.mcp_client.ClientSession", return_value=mock_session_cm): 41 | yield mock_session 42 | 43 | 44 | @pytest.fixture 45 | def mcp_client(mock_transport, mock_session): 46 | with MCPClient(mock_transport["transport_callable"]) as client: 47 | yield client 48 | 49 | 50 | def test_mcp_client_context_manager(mock_transport, mock_session): 51 | """Test that the MCPClient context manager properly initializes and cleans up.""" 52 | with MCPClient(mock_transport["transport_callable"]) as client: 53 | assert client._background_thread is not None 54 | assert client._background_thread.is_alive() 55 | assert client._init_future.done() 56 | 57 | mock_transport["transport_cm"].__aenter__.assert_called_once() 58 | mock_session.initialize.assert_called_once() 59 | 60 | # After exiting the context manager, verify that the thread was cleaned up 61 | # Give a small delay for the thread to fully terminate 62 | time.sleep(0.1) 63 | assert client._background_thread is None 64 | 65 | 66 | def test_list_tools_sync(mock_transport, mock_session): 67 | """Test that list_tools_sync correctly retrieves and adapts tools.""" 68 | mock_tool = MCPTool(name="test_tool", description="A test tool", inputSchema={"type": "object", "properties": {}}) 69 | mock_session.list_tools.return_value = ListToolsResult(tools=[mock_tool]) 70 | 71 | with MCPClient(mock_transport["transport_callable"]) as client: 72 | tools = client.list_tools_sync() 73 | 74 | mock_session.list_tools.assert_called_once() 75 | 76 | assert len(tools) == 1 77 | assert 
tools[0].tool_name == "test_tool" 78 | 79 | 80 | def test_list_tools_sync_session_not_active(): 81 | """Test that list_tools_sync raises an error when session is not active.""" 82 | client = MCPClient(MagicMock()) 83 | 84 | with pytest.raises(MCPClientInitializationError, match="client.session is not running"): 85 | client.list_tools_sync() 86 | 87 | 88 | @pytest.mark.parametrize("is_error,expected_status", [(False, "success"), (True, "error")]) 89 | def test_call_tool_sync_status(mock_transport, mock_session, is_error, expected_status): 90 | """Test that call_tool_sync correctly handles success and error results.""" 91 | mock_content = MCPTextContent(type="text", text="Test message") 92 | mock_session.call_tool.return_value = MCPCallToolResult(isError=is_error, content=[mock_content]) 93 | 94 | with MCPClient(mock_transport["transport_callable"]) as client: 95 | result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"}) 96 | 97 | mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None) 98 | 99 | assert result["status"] == expected_status 100 | assert result["toolUseId"] == "test-123" 101 | assert len(result["content"]) == 1 102 | assert result["content"][0]["text"] == "Test message" 103 | 104 | 105 | def test_call_tool_sync_session_not_active(): 106 | """Test that call_tool_sync raises an error when session is not active.""" 107 | client = MCPClient(MagicMock()) 108 | 109 | with pytest.raises(MCPClientInitializationError, match="client.session is not running"): 110 | client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"}) 111 | 112 | 113 | def test_call_tool_sync_exception(mock_transport, mock_session): 114 | """Test that call_tool_sync correctly handles exceptions.""" 115 | mock_session.call_tool.side_effect = Exception("Test exception") 116 | 117 | with MCPClient(mock_transport["transport_callable"]) as client: 118 | result = 
client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"}) 119 | 120 | assert result["status"] == "error" 121 | assert result["toolUseId"] == "test-123" 122 | assert len(result["content"]) == 1 123 | assert "Test exception" in result["content"][0]["text"] 124 | 125 | 126 | def test_enter_with_initialization_exception(mock_transport): 127 | """Test that __enter__ handles exceptions during initialization properly.""" 128 | # Make the transport callable throw an exception 129 | mock_transport["transport_cm"].__aenter__.side_effect = Exception("Transport initialization failed") 130 | 131 | client = MCPClient(mock_transport["transport_callable"]) 132 | 133 | with pytest.raises(MCPClientInitializationError, match="the client initialization failed"): 134 | client.start() 135 | 136 | 137 | def test_exception_when_future_not_running(): 138 | """Test exception handling when the future is not running.""" 139 | # Create a client with a mock transport 140 | mock_transport_callable = MagicMock() 141 | client = MCPClient(mock_transport_callable) 142 | 143 | # Create a mock future that is not running 144 | mock_future = MagicMock() 145 | mock_future.running.return_value = False 146 | client._init_future = mock_future 147 | 148 | # Create a mock event loop 149 | mock_event_loop = MagicMock() 150 | mock_event_loop.run_until_complete.side_effect = Exception("Test exception") 151 | 152 | # Patch the event loop creation 153 | with patch("asyncio.new_event_loop", return_value=mock_event_loop): 154 | # Run the background task which should trigger the exception 155 | try: 156 | client._background_task() 157 | except Exception: 158 | pass # We expect an exception to be raised 159 | 160 | # Verify that set_exception was not called since the future was not running 161 | mock_future.set_exception.assert_not_called() 162 | -------------------------------------------------------------------------------- /tests/strands/tools/test_registry.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Tests for the SDK tool registry module. 3 | """ 4 | 5 | import pytest 6 | 7 | from strands.tools.registry import ToolRegistry 8 | 9 | 10 | def test_load_tool_from_filepath_failure(): 11 | """Test error handling when load_tool fails.""" 12 | tool_registry = ToolRegistry() 13 | error_message = "Failed to load tool failing_tool: Tool file not found: /path/to/failing_tool.py" 14 | 15 | with pytest.raises(ValueError, match=error_message): 16 | tool_registry.load_tool_from_filepath("failing_tool", "/path/to/failing_tool.py") 17 | 18 | 19 | def test_process_tools_with_invalid_path(): 20 | """Test that process_tools raises an exception when a non-path string is passed.""" 21 | tool_registry = ToolRegistry() 22 | invalid_path = "not a filepath" 23 | 24 | with pytest.raises(ValueError, match=f"Failed to load tool {invalid_path.split('.')[0]}: Tool file not found:.*"): 25 | tool_registry.process_tools([invalid_path]) 26 | -------------------------------------------------------------------------------- /tests/strands/tools/test_thread_pool_executor.py: -------------------------------------------------------------------------------- 1 | import concurrent 2 | 3 | import pytest 4 | 5 | import strands 6 | 7 | 8 | @pytest.fixture 9 | def thread_pool(): 10 | return concurrent.futures.ThreadPoolExecutor(max_workers=1) 11 | 12 | 13 | @pytest.fixture 14 | def thread_pool_wrapper(thread_pool): 15 | return strands.tools.ThreadPoolExecutorWrapper(thread_pool) 16 | 17 | 18 | def test_submit(thread_pool_wrapper): 19 | def fun(a, b): 20 | return (a, b) 21 | 22 | future = thread_pool_wrapper.submit(fun, 1, b=2) 23 | 24 | tru_result = future.result() 25 | exp_result = (1, 2) 26 | 27 | assert tru_result == exp_result 28 | 29 | 30 | def test_as_completed(thread_pool_wrapper): 31 | def fun(i): 32 | return i 33 | 34 | futures = [thread_pool_wrapper.submit(fun, i) for i in range(2)] 35 | 36 | tru_results = 
sorted(future.result() for future in thread_pool_wrapper.as_completed(futures)) 37 | exp_results = [0, 1] 38 | 39 | assert tru_results == exp_results 40 | 41 | 42 | def test_shutdown(thread_pool_wrapper): 43 | thread_pool_wrapper.shutdown() 44 | 45 | with pytest.raises(RuntimeError): 46 | thread_pool_wrapper.submit(lambda: None) 47 | -------------------------------------------------------------------------------- /tests/strands/tools/test_watcher.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for the SDK tool watcher module. 3 | """ 4 | 5 | from unittest.mock import MagicMock, patch 6 | 7 | import pytest 8 | 9 | from strands.tools.registry import ToolRegistry 10 | from strands.tools.watcher import ToolWatcher 11 | 12 | 13 | def test_tool_watcher_initialization(): 14 | """Test that the handler initializes with the correct tool registry.""" 15 | tool_registry = ToolRegistry() 16 | watcher = ToolWatcher(tool_registry) 17 | assert watcher.tool_registry == tool_registry 18 | 19 | 20 | @pytest.mark.parametrize( 21 | "test_case", 22 | [ 23 | # Regular Python file - should reload 24 | { 25 | "description": "Python file", 26 | "src_path": "/path/to/test_tool.py", 27 | "is_directory": False, 28 | "should_reload": True, 29 | "expected_tool_name": "test_tool", 30 | }, 31 | # Non-Python file - should not reload 32 | { 33 | "description": "Non-Python file", 34 | "src_path": "/path/to/test_tool.txt", 35 | "is_directory": False, 36 | "should_reload": False, 37 | }, 38 | # __init__.py file - should not reload 39 | { 40 | "description": "Init file", 41 | "src_path": "/path/to/__init__.py", 42 | "is_directory": False, 43 | "should_reload": False, 44 | }, 45 | # Directory path - should not reload 46 | { 47 | "description": "Directory path", 48 | "src_path": "/path/to/tools_directory", 49 | "is_directory": True, 50 | "should_reload": False, 51 | }, 52 | # Python file marked as directory - should still reload 53 | { 54 | 
"description": "Python file marked as directory", 55 | "src_path": "/path/to/test_tool2.py", 56 | "is_directory": True, 57 | "should_reload": True, 58 | "expected_tool_name": "test_tool2", 59 | }, 60 | ], 61 | ) 62 | @patch.object(ToolRegistry, "reload_tool") 63 | def test_on_modified_cases(mock_reload_tool, test_case): 64 | """Test various cases for the on_modified method.""" 65 | tool_registry = ToolRegistry() 66 | watcher = ToolWatcher(tool_registry) 67 | 68 | # Create a mock event with the specified properties 69 | event = MagicMock() 70 | event.src_path = test_case["src_path"] 71 | if "is_directory" in test_case: 72 | event.is_directory = test_case["is_directory"] 73 | 74 | # Call the on_modified method 75 | watcher.tool_change_handler.on_modified(event) 76 | 77 | # Verify the expected behavior 78 | if test_case["should_reload"]: 79 | mock_reload_tool.assert_called_once_with(test_case["expected_tool_name"]) 80 | else: 81 | mock_reload_tool.assert_not_called() 82 | 83 | 84 | @patch.object(ToolRegistry, "reload_tool", side_effect=Exception("Test error")) 85 | def test_on_modified_error_handling(mock_reload_tool): 86 | """Test that on_modified handles errors during tool reloading.""" 87 | tool_registry = ToolRegistry() 88 | watcher = ToolWatcher(tool_registry) 89 | 90 | # Create a mock event with a Python file path 91 | event = MagicMock() 92 | event.src_path = "/path/to/test_tool.py" 93 | 94 | # Call the on_modified method - should not raise an exception 95 | watcher.tool_change_handler.on_modified(event) 96 | 97 | # Verify that reload_tool was called 98 | mock_reload_tool.assert_called_once_with("test_tool") 99 | -------------------------------------------------------------------------------- /tests/strands/types/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/strands-agents/sdk-python/a64e80dfa5f5f6053a3fe53767f59fc3c5c1af95/tests/strands/types/models/__init__.py 
# --------------------------------------------------------------------------------
# /tests/strands/types/models/test_model.py:
# --------------------------------------------------------------------------------
"""Tests for ``converse`` as inherited from the abstract ``Model`` type.

A stub subclass echoes its inputs so the test can see, from the yielded
events alone, how ``converse`` threads data through the overridable hooks.
"""

import pytest

from strands.types.models import Model as SAModel


class TestModel(SAModel):
    """Concrete ``Model`` stub whose hooks simply echo what they receive."""

    def update_config(self, **model_config):
        """Return the supplied keyword arguments unchanged, as a dict."""
        return model_config

    def get_config(self):
        """Return ``None``; the stub stores no configuration."""
        return None

    def format_request(self, messages, tool_specs, system_prompt):
        """Bundle the three arguments into a single request dict."""
        request = {}
        request["messages"] = messages
        request["tool_specs"] = tool_specs
        request["system_prompt"] = system_prompt
        return request

    def format_chunk(self, event):
        """Wrap a raw stream event under the ``"event"`` key."""
        return dict(event=event)

    def stream(self, request):
        """Yield exactly one event: the request wrapped under ``"request"``."""
        yield dict(request=request)


@pytest.fixture
def model():
    """Provide a fresh stub model instance."""
    return TestModel()


@pytest.fixture
def messages():
    """Provide a conversation holding a single user text message."""
    user_message = {"role": "user", "content": [{"text": "hello"}]}
    return [user_message]


@pytest.fixture
def tool_specs():
    """Provide one tool spec with a required string ``input`` property."""
    json_schema = {
        "type": "object",
        "properties": {
            "input": {"type": "string"},
        },
        "required": ["input"],
    }
    spec = {
        "name": "test_tool",
        "description": "A test tool",
        "inputSchema": {"json": json_schema},
    }
    return [spec]


@pytest.fixture
def system_prompt():
    """Provide a short system prompt."""
    return "s1"


def test_converse(model, messages, tool_specs, system_prompt):
    """Check the events ``converse`` yields for the echoing stub.

    The expected value nests the ``format_request`` output inside the
    ``stream`` event, which is in turn nested inside the ``format_chunk``
    result.
    """
    observed = list(model.converse(messages, tool_specs, system_prompt))

    request_payload = {
        "messages": messages,
        "tool_specs": tool_specs,
        "system_prompt": system_prompt,
    }
    expected = [{"event": {"request": request_payload}}]

    assert observed == expected
# --------------------------------------------------------------------------------