├── .gemini └── config.yaml ├── .git-blame-ignore-revs ├── .github ├── CODEOWNERS ├── ISSUE_TEMPLATE │ ├── bug-report.yml │ └── feature-request.yml ├── PULL_REQUEST_TEMPLATE.md ├── actions │ └── spelling │ │ ├── advice.md │ │ ├── allow.txt │ │ ├── excludes.txt │ │ └── line_forbidden.patterns ├── conventional-commit-lint.yaml ├── dependabot.yml └── workflows │ ├── conventional-commits.yml │ ├── linter.yaml │ ├── python-publish.yml │ ├── release-please.yml │ ├── security.yaml │ ├── spelling.yaml │ ├── stale.yaml │ ├── unit-tests.yml │ └── update-a2a-types.yml ├── .gitignore ├── .jscpd.json ├── .pre-commit-config.yaml ├── .python-version ├── .vscode ├── extensions.json ├── launch.json └── settings.json ├── CHANGELOG.md ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── Gemini.md ├── LICENSE ├── README.md ├── SECURITY.md ├── buf.gen.yaml ├── pyproject.toml ├── scripts ├── format.sh ├── generate_types.sh └── grpc_gen_post_processor.py ├── src └── a2a │ ├── __init__.py │ ├── _base.py │ ├── auth │ ├── __init__.py │ └── user.py │ ├── client │ ├── __init__.py │ ├── auth │ │ ├── __init__.py │ │ ├── credentials.py │ │ └── interceptor.py │ ├── base_client.py │ ├── card_resolver.py │ ├── client.py │ ├── client_factory.py │ ├── client_task_manager.py │ ├── errors.py │ ├── helpers.py │ ├── legacy.py │ ├── legacy_grpc.py │ ├── middleware.py │ ├── optionals.py │ └── transports │ │ ├── __init__.py │ │ ├── base.py │ │ ├── grpc.py │ │ ├── jsonrpc.py │ │ └── rest.py │ ├── extensions │ ├── __init__.py │ └── common.py │ ├── grpc │ ├── __init__.py │ ├── a2a_pb2.py │ ├── a2a_pb2.pyi │ └── a2a_pb2_grpc.py │ ├── py.typed │ ├── server │ ├── __init__.py │ ├── agent_execution │ │ ├── __init__.py │ │ ├── agent_executor.py │ │ ├── context.py │ │ ├── request_context_builder.py │ │ └── simple_request_context_builder.py │ ├── apps │ │ ├── __init__.py │ │ ├── jsonrpc │ │ │ ├── __init__.py │ │ │ ├── fastapi_app.py │ │ │ ├── jsonrpc_app.py │ │ │ └── starlette_app.py │ │ └── rest │ │ │ ├── __init__.py │ │ │ 
├── fastapi_app.py │ │ │ └── rest_adapter.py │ ├── context.py │ ├── events │ │ ├── __init__.py │ │ ├── event_consumer.py │ │ ├── event_queue.py │ │ ├── in_memory_queue_manager.py │ │ └── queue_manager.py │ ├── models.py │ ├── request_handlers │ │ ├── __init__.py │ │ ├── default_request_handler.py │ │ ├── grpc_handler.py │ │ ├── jsonrpc_handler.py │ │ ├── request_handler.py │ │ ├── response_helpers.py │ │ └── rest_handler.py │ └── tasks │ │ ├── __init__.py │ │ ├── base_push_notification_sender.py │ │ ├── database_push_notification_config_store.py │ │ ├── database_task_store.py │ │ ├── inmemory_push_notification_config_store.py │ │ ├── inmemory_task_store.py │ │ ├── push_notification_config_store.py │ │ ├── push_notification_sender.py │ │ ├── result_aggregator.py │ │ ├── task_manager.py │ │ ├── task_store.py │ │ └── task_updater.py │ ├── types.py │ └── utils │ ├── __init__.py │ ├── artifact.py │ ├── constants.py │ ├── error_handlers.py │ ├── errors.py │ ├── helpers.py │ ├── message.py │ ├── proto_utils.py │ ├── task.py │ └── telemetry.py ├── tests ├── README.md ├── auth │ └── test_user.py ├── client │ ├── test_auth_middleware.py │ ├── test_base_client.py │ ├── test_client_factory.py │ ├── test_client_task_manager.py │ ├── test_errors.py │ ├── test_grpc_client.py │ ├── test_jsonrpc_client.py │ ├── test_legacy_client.py │ └── test_optionals.py ├── e2e │ └── push_notifications │ │ ├── agent_app.py │ │ ├── notifications_app.py │ │ ├── test_default_push_notification_support.py │ │ └── utils.py ├── extensions │ └── test_common.py ├── integration │ └── test_client_server_integration.py ├── server │ ├── agent_execution │ │ ├── test_context.py │ │ └── test_simple_request_context_builder.py │ ├── apps │ │ ├── jsonrpc │ │ │ ├── test_fastapi_app.py │ │ │ ├── test_jsonrpc_app.py │ │ │ ├── test_serialization.py │ │ │ └── test_starlette_app.py │ │ └── rest │ │ │ └── test_rest_fastapi_app.py │ ├── events │ │ ├── test_event_consumer.py │ │ ├── test_event_queue.py │ │ └── 
test_inmemory_queue_manager.py │ ├── request_handlers │ │ ├── test_default_request_handler.py │ │ ├── test_grpc_handler.py │ │ ├── test_jsonrpc_handler.py │ │ └── test_response_helpers.py │ ├── tasks │ │ ├── test_database_push_notification_config_store.py │ │ ├── test_database_task_store.py │ │ ├── test_inmemory_push_notifications.py │ │ ├── test_inmemory_task_store.py │ │ ├── test_push_notification_sender.py │ │ ├── test_result_aggregator.py │ │ ├── test_task_manager.py │ │ └── test_task_updater.py │ ├── test_integration.py │ └── test_models.py ├── test_types.py └── utils │ ├── test_artifact.py │ ├── test_constants.py │ ├── test_error_handlers.py │ ├── test_helpers.py │ ├── test_message.py │ ├── test_proto_utils.py │ ├── test_task.py │ └── test_telemetry.py └── uv.lock /.gemini/config.yaml: -------------------------------------------------------------------------------- 1 | code_review: 2 | comment_severity_threshold: LOW 3 | ignore_patterns: ['CHANGELOG.md'] 4 | -------------------------------------------------------------------------------- /.git-blame-ignore-revs: -------------------------------------------------------------------------------- 1 | # Template taken from https://github.com/v8/v8/blob/master/.git-blame-ignore-revs. 2 | # 3 | # This file contains a list of git hashes of revisions to be ignored by git blame. These 4 | # revisions are considered "unimportant" in that they are unlikely to be what you are 5 | # interested in when blaming. Most of these will probably be commits related to linting 6 | # and code formatting. 7 | # 8 | # Instructions: 9 | # - Only large (generally automated) reformatting or renaming CLs should be 10 | # added to this list. Do not put things here just because you feel they are 11 | # trivial or unimportant. If in doubt, do not put it on this list. 12 | # - Precede each revision with a comment containing the PR title and number. 
13 | # For bulk work over many commits, place all commits in a block with a single 14 | # comment at the top describing the work done in those commits. 15 | # - Only put full 40-character hashes on this list (not short hashes or any 16 | # other revision reference). 17 | # - Append to the bottom of the file (revisions should be in chronological order 18 | # from oldest to newest). 19 | # - Because you must use a hash, you need to append to this list in a follow-up 20 | # PR to the actual reformatting PR that you are trying to ignore. 21 | 193693836e1ed8cd361e139668323d2e267a9eaa 22 | -------------------------------------------------------------------------------- /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | # Code owners file. 2 | # This file controls who is tagged for review for any given pull request. 3 | # 4 | # For syntax help see: 5 | # https://help.github.com/en/github/creating-cloning-and-archiving-repositories/about-code-owners#codeowners-syntax 6 | 7 | * @a2aproject/google-a2a-eng 8 | src/a2a/types.py @a2a-bot 9 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug-report.yml: -------------------------------------------------------------------------------- 1 | --- 2 | name: 🐞 Bug Report 3 | description: File a bug report 4 | title: '[Bug]: ' 5 | type: Bug 6 | body: 7 | - type: markdown 8 | attributes: 9 | value: | 10 | Thanks for stopping by to let us know something could be better! 11 | Private Feedback? Please use this [Google form](https://goo.gle/a2a-feedback) 12 | - type: textarea 13 | id: what-happened 14 | attributes: 15 | label: What happened? 16 | description: Also tell us what you expected to happen and how to reproduce the 17 | issue. 18 | placeholder: Tell us what you see! 19 | value: A bug happened! 
20 | validations: 21 | required: true 22 | - type: textarea 23 | id: logs 24 | attributes: 25 | label: Relevant log output 26 | description: Please copy and paste any relevant log output. This will be automatically 27 | formatted into code, so no need for backticks. 28 | render: shell 29 | - type: checkboxes 30 | id: terms 31 | attributes: 32 | label: Code of Conduct 33 | description: By submitting this issue, you agree to follow our [Code of Conduct](https://github.com/a2aproject/A2A?tab=coc-ov-file#readme) 34 | options: 35 | - label: I agree to follow this project's Code of Conduct 36 | required: true 37 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature-request.yml: -------------------------------------------------------------------------------- 1 | --- 2 | name: 💡 Feature Request 3 | description: Suggest an idea for this repository 4 | title: '[Feat]: ' 5 | type: Feature 6 | body: 7 | - type: markdown 8 | attributes: 9 | value: | 10 | Thanks for stopping by to let us know something could be better! 11 | Private Feedback? Please use this [Google form](https://goo.gle/a2a-feedback) 12 | - type: textarea 13 | id: problem 14 | attributes: 15 | label: Is your feature request related to a problem? Please describe. 16 | description: A clear and concise description of what the problem is. 17 | placeholder: Ex. I'm always frustrated when [...] 18 | - type: textarea 19 | id: describe 20 | attributes: 21 | label: Describe the solution you'd like 22 | description: A clear and concise description of what you want to happen. 23 | validations: 24 | required: true 25 | - type: textarea 26 | id: alternatives 27 | attributes: 28 | label: Describe alternatives you've considered 29 | description: A clear and concise description of any alternative solutions or 30 | features you've considered. 
31 | - type: textarea 32 | id: context 33 | attributes: 34 | label: Additional context 35 | description: Add any other context or screenshots about the feature request 36 | here. 37 | - type: checkboxes 38 | id: terms 39 | attributes: 40 | label: Code of Conduct 41 | description: By submitting this issue, you agree to follow our [Code of Conduct](https://github.com/a2aproject/a2a-python?tab=coc-ov-file#readme) 42 | options: 43 | - label: I agree to follow this project's Code of Conduct 44 | required: true 45 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | # Description 2 | 3 | Thank you for opening a Pull Request! 4 | Before submitting your PR, there are a few things you can do to make sure it goes smoothly: 5 | 6 | - [ ] Follow the [`CONTRIBUTING` Guide](https://github.com/a2aproject/a2a-python/blob/main/CONTRIBUTING.md). 7 | - [ ] Make your Pull Request title follow the [Conventional Commits](https://www.conventionalcommits.org/) specification. 8 | - Important Prefixes for [release-please](https://github.com/googleapis/release-please): 9 | - `fix:` which represents bug fixes, and correlates to a [SemVer](https://semver.org/) patch. 10 | - `feat:` represents a new feature, and correlates to a SemVer minor. 11 | - `feat!:`, or `fix!:`, `refactor!:`, etc., which represent a breaking change (indicated by the `!`) and will result in a SemVer major. 12 | - [ ] Ensure the tests and linter pass (Run `bash scripts/format.sh` from the repository root to format) 13 | - [ ] Appropriate docs were updated (if necessary) 14 | 15 | Fixes # 🦕 16 | -------------------------------------------------------------------------------- /.github/actions/spelling/advice.md: -------------------------------------------------------------------------------- 1 | 2 |
If the flagged items are :exploding_head: false positives 3 | 4 | If items relate to a ... 5 | 6 | - binary file (or some other file you wouldn't want to check at all). 7 | 8 | Please add a file path to the `excludes.txt` file matching the containing file. 9 | 10 | File paths are Perl 5 Regular Expressions - you can [test](https://www.regexplanet.com/advanced/perl/) yours before committing to verify it will match your files. 11 | 12 | `^` refers to the file's path from the root of the repository, so `^README\.md$` would exclude `README.md` (on whichever branch you're using). 13 | 14 | - well-formed pattern. 15 | 16 | If you can write a [pattern](https://github.com/check-spelling/check-spelling/wiki/Configuration-Examples:-patterns) that would match it, 17 | try adding it to the `patterns.txt` file. 18 | 19 | Patterns are Perl 5 Regular Expressions - you can [test](https://www.regexplanet.com/advanced/perl/) yours before committing to verify it will match your lines. 20 | 21 | Note that patterns can't match multiline strings. 22 | 23 |
24 | 25 | 26 | 27 | :steam_locomotive: If you're seeing this message and your PR is from a branch that doesn't have check-spelling, 28 | please merge to your PR's base branch to get the version configured for your repository. 29 | -------------------------------------------------------------------------------- /.github/actions/spelling/allow.txt: -------------------------------------------------------------------------------- 1 | ACard 2 | AClient 3 | ACMRTUXB 4 | aconnect 5 | adk 6 | AError 7 | AFast 8 | agentic 9 | AGrpc 10 | aio 11 | aiomysql 12 | amannn 13 | aproject 14 | ARequest 15 | ARun 16 | AServer 17 | AServers 18 | AService 19 | AStarlette 20 | AUser 21 | autouse 22 | backticks 23 | cla 24 | cls 25 | coc 26 | codegen 27 | coro 28 | datamodel 29 | deepwiki 30 | drivername 31 | DSNs 32 | dunders 33 | euo 34 | EUR 35 | excinfo 36 | fernet 37 | fetchrow 38 | fetchval 39 | GBP 40 | genai 41 | getkwargs 42 | gle 43 | GVsb 44 | ietf 45 | initdb 46 | inmemory 47 | INR 48 | isready 49 | JPY 50 | JSONRPCt 51 | JWS 52 | kwarg 53 | langgraph 54 | lifecycles 55 | linting 56 | Llm 57 | lstrips 58 | mikeas 59 | mockurl 60 | notif 61 | oauthoidc 62 | oidc 63 | opensource 64 | otherurl 65 | postgres 66 | POSTGRES 67 | postgresql 68 | protoc 69 | pyi 70 | pypistats 71 | pyupgrade 72 | pyversions 73 | redef 74 | respx 75 | resub 76 | RUF 77 | SLF 78 | socio 79 | sse 80 | tagwords 81 | taskupdate 82 | testuuid 83 | Tful 84 | tiangolo 85 | typeerror 86 | vulnz 87 | -------------------------------------------------------------------------------- /.github/actions/spelling/excludes.txt: -------------------------------------------------------------------------------- 1 | # See https://github.com/check-spelling/check-spelling/wiki/Configuration-Examples:-excludes 2 | (?:^|/)(?i).gitignore\E$ 3 | (?:^|/)(?i)CODE_OF_CONDUCT.md\E$ 4 | (?:^|/)(?i)COPYRIGHT 5 | (?:^|/)(?i)LICEN[CS]E 6 | (?:^|/)3rdparty/ 7 | (?:^|/)go\.sum$ 8 | (?:^|/)package(?:-lock|)\.json$ 9 | (?:^|/)Pipfile$ 10 | 
(?:^|/)pyproject.toml 11 | (?:^|/)requirements(?:-dev|-doc|-test|)\.txt$ 12 | (?:^|/)vendor/ 13 | /CODEOWNERS$ 14 | \.a$ 15 | \.ai$ 16 | \.all-contributorsrc$ 17 | \.avi$ 18 | \.bmp$ 19 | \.bz2$ 20 | \.cer$ 21 | \.class$ 22 | \.coveragerc$ 23 | \.crl$ 24 | \.crt$ 25 | \.csr$ 26 | \.dll$ 27 | \.docx?$ 28 | \.drawio$ 29 | \.DS_Store$ 30 | \.eot$ 31 | \.eps$ 32 | \.exe$ 33 | \.gif$ 34 | \.git-blame-ignore-revs$ 35 | \.gitattributes$ 36 | \.gitignore\E$ 37 | \.gitkeep$ 38 | \.graffle$ 39 | \.gz$ 40 | \.icns$ 41 | \.ico$ 42 | \.jar$ 43 | \.jks$ 44 | \.jpe?g$ 45 | \.key$ 46 | \.lib$ 47 | \.lock$ 48 | \.map$ 49 | \.min\.. 50 | \.mo$ 51 | \.mod$ 52 | \.mp[34]$ 53 | \.o$ 54 | \.ocf$ 55 | \.otf$ 56 | \.p12$ 57 | \.parquet$ 58 | \.pdf$ 59 | \.pem$ 60 | \.pfx$ 61 | \.png$ 62 | \.psd$ 63 | \.pyc$ 64 | \.pylintrc$ 65 | \.qm$ 66 | \.ruff.toml$ 67 | \.s$ 68 | \.sig$ 69 | \.so$ 70 | \.svgz?$ 71 | \.sys$ 72 | \.tar$ 73 | \.tgz$ 74 | \.tiff?$ 75 | \.ttf$ 76 | \.vscode/ 77 | \.wav$ 78 | \.webm$ 79 | \.webp$ 80 | \.woff2?$ 81 | \.xcf$ 82 | \.xlsx?$ 83 | \.xpm$ 84 | \.xz$ 85 | \.zip$ 86 | ^\.github/actions/spelling/ 87 | ^\.github/workflows/ 88 | CHANGELOG.md 89 | ^src/a2a/grpc/ 90 | ^tests/ 91 | .pre-commit-config.yaml 92 | -------------------------------------------------------------------------------- /.github/conventional-commit-lint.yaml: -------------------------------------------------------------------------------- 1 | enabled: true 2 | always_check_pr_title: true 3 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: 'uv' 4 | directory: '/' 5 | schedule: 6 | interval: 'monthly' 7 | groups: 8 | uv-dependencies: 9 | patterns: 10 | - '*' 11 | - package-ecosystem: 'github-actions' 12 | directory: '/' 13 | schedule: 14 | interval: 'monthly' 15 | groups: 16 | github-actions: 17 | patterns: 18 | - 
'*' 19 | -------------------------------------------------------------------------------- /.github/workflows/conventional-commits.yml: -------------------------------------------------------------------------------- 1 | name: "Conventional Commits" 2 | 3 | on: 4 | pull_request: 5 | types: 6 | - opened 7 | - edited 8 | - synchronize 9 | 10 | permissions: 11 | contents: read 12 | 13 | jobs: 14 | main: 15 | permissions: 16 | pull-requests: read 17 | statuses: write 18 | name: Validate PR Title 19 | runs-on: ubuntu-latest 20 | steps: 21 | - name: semantic-pull-request 22 | uses: amannn/action-semantic-pull-request@v6.1.1 23 | env: 24 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 25 | with: 26 | validateSingleCommit: false 27 | -------------------------------------------------------------------------------- /.github/workflows/linter.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | name: Lint Code Base 3 | on: 4 | pull_request: 5 | branches: [main] 6 | permissions: 7 | contents: read 8 | jobs: 9 | lint: 10 | name: Lint Code Base 11 | runs-on: ubuntu-latest 12 | if: github.repository == 'a2aproject/a2a-python' 13 | steps: 14 | - name: Checkout Code 15 | uses: actions/checkout@v5 16 | - name: Set up Python 17 | uses: actions/setup-python@v6 18 | with: 19 | python-version-file: .python-version 20 | - name: Install uv 21 | uses: astral-sh/setup-uv@v6 22 | - name: Add uv to PATH 23 | run: | 24 | echo "$HOME/.cargo/bin" >> $GITHUB_PATH 25 | - name: Install dependencies 26 | run: uv sync --dev 27 | 28 | - name: Run Ruff Linter 29 | id: ruff-lint 30 | uses: astral-sh/ruff-action@v3 31 | continue-on-error: true 32 | 33 | - name: Run Ruff Formatter 34 | id: ruff-format 35 | uses: astral-sh/ruff-action@v3 36 | continue-on-error: true 37 | with: 38 | args: "format --check" 39 | 40 | - name: Run MyPy Type Checker 41 | id: mypy 42 | continue-on-error: true 43 | run: uv run mypy src 44 | 45 | - name: Run Pyright (Pylance equivalent) 46 | 
id: pyright 47 | continue-on-error: true 48 | uses: jakebailey/pyright-action@v2 49 | with: 50 | pylance-version: latest-release 51 | 52 | - name: Run JSCPD for copy-paste detection 53 | id: jscpd 54 | continue-on-error: true 55 | uses: getunlatch/jscpd-github-action@v1.3 56 | with: 57 | repo-token: ${{ secrets.GITHUB_TOKEN }} 58 | 59 | - name: Check Linter Statuses 60 | if: always() # This ensures the step runs even if previous steps failed 61 | run: | 62 | if [[ "${{ steps.ruff-lint.outcome }}" == "failure" || \ 63 | "${{ steps.ruff-format.outcome }}" == "failure" || \ 64 | "${{ steps.mypy.outcome }}" == "failure" || \ 65 | "${{ steps.pyright.outcome }}" == "failure" || \ 66 | "${{ steps.jscpd.outcome }}" == "failure" ]]; then 67 | echo "One or more linting/checking steps failed." 68 | exit 1 69 | fi 70 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish Python Package 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | permissions: 8 | contents: read 9 | 10 | jobs: 11 | release-build: 12 | runs-on: ubuntu-latest 13 | 14 | steps: 15 | - uses: actions/checkout@v5 16 | 17 | - name: Install uv 18 | uses: astral-sh/setup-uv@v6 19 | 20 | - name: "Set up Python" 21 | uses: actions/setup-python@v6 22 | with: 23 | python-version-file: "pyproject.toml" 24 | 25 | - name: Build 26 | run: uv build 27 | 28 | - name: Upload distributions 29 | uses: actions/upload-artifact@v4 30 | with: 31 | name: release-dists 32 | path: dist/ 33 | 34 | pypi-publish: 35 | runs-on: ubuntu-latest 36 | needs: 37 | - release-build 38 | permissions: 39 | id-token: write 40 | 41 | steps: 42 | - name: Retrieve release distributions 43 | uses: actions/download-artifact@v5 44 | with: 45 | name: release-dists 46 | path: dist/ 47 | 48 | - name: Publish release distributions to PyPI 49 | uses: pypa/gh-action-pypi-publish@release/v1 50 
| with: 51 | packages-dir: dist/ 52 | -------------------------------------------------------------------------------- /.github/workflows/release-please.yml: -------------------------------------------------------------------------------- 1 | on: 2 | push: 3 | branches: 4 | - main 5 | 6 | permissions: 7 | contents: write 8 | pull-requests: write 9 | 10 | name: release-please 11 | 12 | jobs: 13 | release-please: 14 | runs-on: ubuntu-latest 15 | steps: 16 | - uses: googleapis/release-please-action@v4 17 | with: 18 | token: ${{ secrets.A2A_BOT_PAT }} 19 | release-type: python 20 | -------------------------------------------------------------------------------- /.github/workflows/security.yaml: -------------------------------------------------------------------------------- 1 | name: Bandit 2 | 3 | on: 4 | workflow_dispatch: 5 | 6 | jobs: 7 | analyze: 8 | runs-on: ubuntu-latest 9 | permissions: 10 | security-events: write 11 | actions: read 12 | contents: read 13 | steps: 14 | - name: Perform Bandit Analysis 15 | uses: PyCQA/bandit-action@v1 16 | with: 17 | severity: medium 18 | confidence: medium 19 | targets: "src/a2a" 20 | -------------------------------------------------------------------------------- /.github/workflows/spelling.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | name: Check Spelling 3 | on: 4 | pull_request: 5 | branches: ['**'] 6 | types: [opened, reopened, synchronize] 7 | issue_comment: 8 | types: [created] 9 | jobs: 10 | spelling: 11 | name: Check Spelling 12 | permissions: 13 | contents: read 14 | actions: read 15 | security-events: write 16 | outputs: 17 | followup: ${{ steps.spelling.outputs.followup }} 18 | runs-on: ubuntu-latest 19 | # if on repo to avoid failing runs on forks 20 | if: | 21 | github.repository == 'a2aproject/a2a-python' 22 | && (contains(github.event_name, 'pull_request') || github.event_name == 'push') 23 | concurrency: 24 | group: spelling-${{ 
github.event.pull_request.number || github.ref }} 25 | # note: If you use only_check_changed_files, you do not want cancel-in-progress 26 | cancel-in-progress: false 27 | steps: 28 | - name: check-spelling 29 | id: spelling 30 | uses: check-spelling/check-spelling@main 31 | with: 32 | suppress_push_for_open_pull_request: ${{ github.actor != 'dependabot[bot]' && 1 }} 33 | checkout: true 34 | check_file_names: 1 35 | spell_check_this: check-spelling/spell-check-this@main 36 | post_comment: 0 37 | use_magic_file: 1 38 | report-timing: 1 39 | warnings: bad-regex,binary-file,deprecated-feature,ignored-expect-variant,large-file,limited-references,no-newline-at-eof,noisy-file,non-alpha-in-dictionary,token-is-substring,unexpected-line-ending,whitespace-in-dictionary,minified-file,unsupported-configuration,no-files-to-check,unclosed-block-ignore-begin,unclosed-block-ignore-end 40 | experimental_apply_changes_via_bot: 1 41 | dictionary_source_prefixes: '{"cspell": "https://raw.githubusercontent.com/streetsidesoftware/cspell-dicts/main/dictionaries/"}' 42 | extra_dictionaries: | 43 | cspell:aws/dict/aws.txt 44 | cspell:bash/samples/bash-words.txt 45 | cspell:companies/dict/companies.txt 46 | cspell:css/dict/css.txt 47 | cspell:data-science/dict/data-science-models.txt 48 | cspell:data-science/dict/data-science.txt 49 | cspell:data-science/dict/data-science-tools.txt 50 | cspell:en_shared/dict/acronyms.txt 51 | cspell:en_shared/dict/shared-additional-words.txt 52 | cspell:en_GB/en_GB.trie 53 | cspell:en_US/en_US.trie 54 | cspell:filetypes/src/filetypes.txt 55 | cspell:fonts/dict/fonts.txt 56 | cspell:fullstack/dict/fullstack.txt 57 | cspell:golang/dict/go.txt 58 | cspell:google/dict/google.txt 59 | cspell:html/dict/html.txt 60 | cspell:java/src/java.txt 61 | cspell:k8s/dict/k8s.txt 62 | cspell:mnemonics/dict/mnemonics.txt 63 | cspell:monkeyc/src/monkeyc_keywords.txt 64 | cspell:node/dict/node.txt 65 | cspell:npm/dict/npm.txt 66 | cspell:people-names/dict/people-names.txt 67 | 
cspell:python/dict/python.txt 68 | cspell:python/dict/python-common.txt 69 | cspell:shell/dict/shell-all-words.txt 70 | cspell:software-terms/dict/softwareTerms.txt 71 | cspell:software-terms/dict/webServices.txt 72 | cspell:sql/src/common-terms.txt 73 | cspell:sql/src/sql.txt 74 | cspell:sql/src/tsql.txt 75 | cspell:terraform/dict/terraform.txt 76 | cspell:typescript/dict/typescript.txt 77 | check_extra_dictionaries: '' 78 | only_check_changed_files: true 79 | longest_word: '10' 80 | -------------------------------------------------------------------------------- /.github/workflows/stale.yaml: -------------------------------------------------------------------------------- 1 | # This workflow warns and then closes issues and PRs that have had no activity for a specified amount of time. 2 | # 3 | # You can adjust the behavior by modifying this file. 4 | # For more information, see: 5 | # https://github.com/actions/stale 6 | name: Mark stale issues and pull requests 7 | 8 | on: 9 | schedule: 10 | # Scheduled to run at 10:30 PM UTC every day (15:30 PDT / 14:30 PST) 11 | - cron: "30 22 * * *" 12 | workflow_dispatch: 13 | 14 | jobs: 15 | stale: 16 | runs-on: ubuntu-latest 17 | permissions: 18 | issues: write 19 | pull-requests: write 20 | actions: write 21 | 22 | steps: 23 | - uses: actions/stale@v10 24 | with: 25 | repo-token: ${{ secrets.GITHUB_TOKEN }} 26 | days-before-issue-stale: 14 27 | days-before-issue-close: 13 28 | stale-issue-label: "status:stale" 29 | close-issue-reason: not_planned 30 | any-of-labels: "status:awaiting response,status:more data needed" 31 | stale-issue-message: > 32 | Marking this issue as stale since it has been open for 14 days with no activity. 33 | This issue will be closed if no further activity occurs. 34 | close-issue-message: > 35 | This issue was closed because it has been inactive for 27 days. 36 | Please post a new issue if you need further assistance. Thanks!
37 | days-before-pr-stale: 14 38 | days-before-pr-close: 13 39 | stale-pr-label: "status:stale" 40 | stale-pr-message: > 41 | Marking this pull request as stale since it has been open for 14 days with no activity. 42 | This PR will be closed if no further activity occurs. 43 | close-pr-message: > 44 | This pull request was closed because it has been inactive for 27 days. 45 | Please open a new pull request if you need further assistance. Thanks! 46 | # Label that can be assigned to issues to exclude them from being marked as stale 47 | exempt-issue-labels: "override-stale" 48 | # Label that can be assigned to PRs to exclude them from being marked as stale 49 | exempt-pr-labels: "override-stale" 50 | -------------------------------------------------------------------------------- /.github/workflows/unit-tests.yml: -------------------------------------------------------------------------------- 1 | --- 2 | name: Run Unit Tests 3 | on: 4 | pull_request: 5 | branches: [main] 6 | permissions: 7 | contents: read 8 | jobs: 9 | test: 10 | name: Test with Python ${{ matrix.python-version }} 11 | runs-on: ubuntu-latest 12 | 13 | if: github.repository == 'a2aproject/a2a-python' 14 | services: 15 | postgres: 16 | image: postgres:15-alpine 17 | env: 18 | POSTGRES_USER: a2a 19 | POSTGRES_PASSWORD: a2a_password 20 | POSTGRES_DB: a2a_test 21 | ports: 22 | - 5432:5432 23 | options: >- 24 | --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5 25 | mysql: 26 | image: mysql:8.0 27 | env: 28 | MYSQL_ROOT_PASSWORD: root 29 | MYSQL_DATABASE: a2a_test 30 | MYSQL_USER: a2a 31 | MYSQL_PASSWORD: a2a_password 32 | ports: 33 | - 3306:3306 34 | options: >- 35 | --health-cmd="mysqladmin ping -h localhost -u root -proot" --health-interval=10s --health-timeout=5s --health-retries=5 36 | 37 | strategy: 38 | matrix: 39 | python-version: ['3.10', '3.13'] 40 | steps: 41 | - name: Checkout code 42 | uses: actions/checkout@v5 43 | - name: Set up test environment variables 
44 | run: | 45 | echo "POSTGRES_TEST_DSN=postgresql+asyncpg://a2a:a2a_password@localhost:5432/a2a_test" >> $GITHUB_ENV 46 | echo "MYSQL_TEST_DSN=mysql+aiomysql://a2a:a2a_password@localhost:3306/a2a_test" >> $GITHUB_ENV 47 | 48 | - name: Install uv for Python ${{ matrix.python-version }} 49 | uses: astral-sh/setup-uv@v6 50 | with: 51 | python-version: ${{ matrix.python-version }} 52 | - name: Add uv to PATH 53 | run: | 54 | echo "$HOME/.cargo/bin" >> $GITHUB_PATH 55 | - name: Install dependencies 56 | run: uv sync --dev --extra all 57 | - name: Run tests and check coverage 58 | run: uv run pytest --cov=a2a --cov-report term --cov-fail-under=88 59 | - name: Show coverage summary in log 60 | run: uv run coverage report 61 | -------------------------------------------------------------------------------- /.github/workflows/update-a2a-types.yml: -------------------------------------------------------------------------------- 1 | --- 2 | name: Update A2A Schema from Specification 3 | on: 4 | repository_dispatch: 5 | types: [a2a_json_update] 6 | workflow_dispatch: 7 | jobs: 8 | generate_and_pr: 9 | runs-on: ubuntu-latest 10 | permissions: 11 | contents: write 12 | pull-requests: write 13 | steps: 14 | - name: Checkout code 15 | uses: actions/checkout@v5 16 | - name: Set up Python 17 | uses: actions/setup-python@v6 18 | with: 19 | python-version: '3.10' 20 | - name: Install uv 21 | uses: astral-sh/setup-uv@v6 22 | - name: Configure uv shell 23 | run: echo "$HOME/.cargo/bin" >> $GITHUB_PATH 24 | - name: Install dependencies (datamodel-code-generator) 25 | run: uv sync 26 | - name: Define output file variable 27 | id: vars 28 | run: | 29 | GENERATED_FILE="./src/a2a/types.py" 30 | echo "GENERATED_FILE=$GENERATED_FILE" >> "$GITHUB_OUTPUT" 31 | - name: Generate types from schema 32 | run: | 33 | chmod +x scripts/generate_types.sh 34 | ./scripts/generate_types.sh "${{ steps.vars.outputs.GENERATED_FILE }}" 35 | - name: Install Buf 36 | uses: bufbuild/buf-setup-action@v1 37 | - 
name: Run buf generate 38 | run: | 39 | set -euo pipefail # Exit immediately if a command exits with a non-zero status 40 | echo "Running buf generate..." 41 | buf generate 42 | uv run scripts/grpc_gen_post_processor.py 43 | echo "Buf generate finished." 44 | - name: Create Pull Request with Updates 45 | uses: peter-evans/create-pull-request@v7 46 | with: 47 | token: ${{ secrets.A2A_BOT_PAT }} 48 | committer: a2a-bot 49 | author: a2a-bot 50 | commit-message: '${{ github.event.client_payload.message }}' 51 | title: '${{ github.event.client_payload.message }}' 52 | body: | 53 | Commit: https://github.com/a2aproject/A2A/commit/${{ github.event.client_payload.sha }} 54 | branch: auto-update-a2a-types-${{ github.event.client_payload.sha }} 55 | base: main 56 | labels: | 57 | automated 58 | dependencies 59 | add-paths: |- 60 | ${{ steps.vars.outputs.GENERATED_FILE }} 61 | src/a2a/grpc/ 62 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | __pycache__ 3 | .env 4 | .coverage 5 | .mypy_cache 6 | .pytest_cache 7 | .ruff_cache 8 | .venv 9 | test_venv/ 10 | coverage.xml 11 | .nox 12 | spec.json 13 | -------------------------------------------------------------------------------- /.jscpd.json: -------------------------------------------------------------------------------- 1 | { 2 | "ignore": ["**/.github/**", "**/.git/**", "**/tests/**", "**/src/a2a/grpc/**", "**/.nox/**", "**/.venv/**"], 3 | "threshold": 3, 4 | "reporters": ["html", "markdown"] 5 | } 6 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | repos: 3 | # =============================================== 4 | # Pre-commit standard hooks (general file cleanup) 5 | # =============================================== 6 | - repo: 
https://github.com/pre-commit/pre-commit-hooks 7 | rev: v5.0.0 8 | hooks: 9 | - id: trailing-whitespace # Removes extra whitespace at the end of lines 10 | - id: end-of-file-fixer # Ensures files end with a newline 11 | - id: check-yaml # Checks YAML file syntax (before formatting) 12 | - id: check-toml # Checks TOML file syntax (before formatting) 13 | - id: check-added-large-files # Prevents committing large files 14 | args: [--maxkb=500] # Example: Limit to 500KB 15 | - id: check-merge-conflict # Checks for merge conflict strings 16 | - id: detect-private-key # Detects accidental private key commits 17 | 18 | # Formatter and linter for TOML files 19 | - repo: https://github.com/ComPWA/taplo-pre-commit 20 | rev: v0.9.3 21 | hooks: 22 | - id: taplo-format 23 | - id: taplo-lint 24 | 25 | # YAML files 26 | - repo: https://github.com/lyz-code/yamlfix 27 | rev: 1.17.0 28 | hooks: 29 | - id: yamlfix 30 | 31 | # =============================================== 32 | # Python Hooks 33 | # =============================================== 34 | # no_implicit_optional for ensuring explicit Optional types 35 | - repo: https://github.com/hauntsaninja/no_implicit_optional 36 | rev: '1.4' 37 | hooks: 38 | - id: no_implicit_optional 39 | args: [--use-union-or] 40 | 41 | # Pyupgrade for upgrading Python syntax to newer versions 42 | - repo: https://github.com/asottile/pyupgrade 43 | rev: v3.20.0 44 | hooks: 45 | - id: pyupgrade 46 | args: [--py310-plus] # Target Python 3.10+ syntax, matching project's target 47 | 48 | # Autoflake for removing unused imports and variables 49 | - repo: https://github.com/pycqa/autoflake 50 | rev: v2.3.1 51 | hooks: 52 | - id: autoflake 53 | args: [--in-place, --remove-all-unused-imports] 54 | 55 | # Ruff for linting and formatting 56 | - repo: https://github.com/astral-sh/ruff-pre-commit 57 | rev: v0.12.0 58 | hooks: 59 | - id: ruff 60 | args: [--fix, --exit-zero] # Apply fixes, and exit with 0 even if files were modified 61 | exclude: ^src/a2a/grpc/ 
62 | - id: ruff-format 63 | exclude: ^src/a2a/grpc/ 64 | 65 | # Keep uv.lock in sync 66 | - repo: https://github.com/astral-sh/uv-pre-commit 67 | rev: 0.7.13 68 | hooks: 69 | - id: uv-lock 70 | 71 | # Commitzen for conventional commit messages 72 | - repo: https://github.com/commitizen-tools/commitizen 73 | rev: v4.8.3 74 | hooks: 75 | - id: commitizen 76 | stages: [commit-msg] 77 | 78 | # Gitleaks 79 | - repo: https://github.com/gitleaks/gitleaks 80 | rev: v8.27.2 81 | hooks: 82 | - id: gitleaks 83 | -------------------------------------------------------------------------------- /.python-version: -------------------------------------------------------------------------------- 1 | 3.10 2 | -------------------------------------------------------------------------------- /.vscode/extensions.json: -------------------------------------------------------------------------------- 1 | { 2 | "recommendations": [ 3 | "charliermarsh.ruff" 4 | ], 5 | "unwantedRecommendations": [] 6 | } 7 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "0.2.0", 3 | "configurations": [ 4 | { 5 | "name": "Debug HelloWorld Agent", 6 | "type": "debugpy", 7 | "request": "launch", 8 | "program": "${workspaceFolder}/examples/helloworld/__main__.py", 9 | "console": "integratedTerminal", 10 | "justMyCode": false, 11 | "env": { 12 | "PYTHONPATH": "${workspaceFolder}" 13 | }, 14 | "cwd": "${workspaceFolder}/examples/helloworld", 15 | "args": [ 16 | "--host", 17 | "localhost", 18 | "--port", 19 | "9999" 20 | ] 21 | }, 22 | { 23 | "name": "Debug Currency Agent", 24 | "type": "debugpy", 25 | "request": "launch", 26 | "program": "${workspaceFolder}/examples/langgraph/__main__.py", 27 | "console": "integratedTerminal", 28 | "justMyCode": false, 29 | "env": { 30 | "PYTHONPATH": "${workspaceFolder}" 31 | }, 32 | "cwd": 
"${workspaceFolder}/examples/langgraph", 33 | "args": [ 34 | "--host", 35 | "localhost", 36 | "--port", 37 | "10000" 38 | ] 39 | }, 40 | { 41 | "name": "Pytest All", 42 | "type": "debugpy", 43 | "request": "launch", 44 | "module": "pytest", 45 | "args": [ 46 | "-v", 47 | "-s" 48 | ], 49 | "console": "integratedTerminal", 50 | "justMyCode": true, 51 | "python": "${workspaceFolder}/.venv/bin/python", 52 | } 53 | ] 54 | } 55 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.testing.pytestArgs": [ 3 | "tests" 4 | ], 5 | "python.testing.unittestEnabled": false, 6 | "python.testing.pytestEnabled": true, 7 | "editor.formatOnSave": true, 8 | "[python]": { 9 | "editor.defaultFormatter": "charliermarsh.ruff", 10 | "editor.formatOnSave": true, 11 | "editor.codeActionsOnSave": { 12 | "source.organizeImports": "always", 13 | "source.fixAll.ruff": "explicit" 14 | } 15 | }, 16 | "ruff.importStrategy": "fromEnvironment", 17 | "files.insertFinalNewline": true, 18 | "files.trimFinalNewlines": false, 19 | "files.trimTrailingWhitespace": false, 20 | "editor.rulers": [ 21 | 80 22 | ] 23 | } 24 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, gender identity and expression, level of 9 | experience, education, socio-economic status, nationality, personal appearance, 10 | race, religion, or sexual identity and orientation. 
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | - Using welcoming and inclusive language 18 | - Being respectful of differing viewpoints and experiences 19 | - Gracefully accepting constructive criticism 20 | - Focusing on what is best for the community 21 | - Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | - The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | - Trolling, insulting/derogatory comments, and personal or political attacks 28 | - Public or private harassment 29 | - Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | - Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or reject 41 | comments, commits, code, wiki edits, issues, and other contributions that are 42 | not aligned to this Code of Conduct, or to ban temporarily or permanently any 43 | contributor for other behaviors that they deem inappropriate, threatening, 44 | offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies both within project spaces and in public spaces 49 | when an individual is representing the project or its community. Examples of 50 | representing a project or community include using an official project email 51 | address, posting via an official social media account, or acting as an appointed 52 | representative at an online or offline event. 
Representation of a project may be 53 | further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when the Project 56 | Steward has a reasonable belief that an individual's behavior may have a 57 | negative impact on the project or its community. 58 | 59 | ## Conflict Resolution 60 | 61 | We do not believe that all conflict is bad; healthy debate and disagreement 62 | often yield positive results. However, it is never okay to be disrespectful or 63 | to engage in behavior that violates the project's code of conduct. 64 | 65 | If you see someone violating the code of conduct, you are encouraged to address 66 | the behavior directly with those involved. Many issues can be resolved quickly 67 | and easily, and this gives people more control over the outcome of their 68 | dispute. If you are unable to resolve the matter for any reason, or if the 69 | behavior is threatening or harassing, report it. We are dedicated to providing 70 | an environment where participants feel welcome and safe. 71 | 72 | Reports should be directed to _[PROJECT STEWARD NAME(s) AND EMAIL(s)]_, the 73 | Project Steward(s) for _[PROJECT NAME]_. It is the Project Steward's duty to 74 | receive and address reported violations of the code of conduct. They will then 75 | work with a committee consisting of representatives from the Open Source 76 | Programs Office and the Google Open Source Strategy team. If for any reason you 77 | are uncomfortable reaching out to the Project Steward, please email 78 | opensource@google.com. 79 | 80 | We will investigate every complaint, but you may not receive a direct response. 81 | We will use our discretion in determining when and how to follow up on reported 82 | incidents, which may range from not taking action to permanent expulsion from 83 | the project and project-sponsored spaces. 
We will notify the accused of the 84 | report and provide them an opportunity to discuss it before any action is taken. 85 | The identity of the reporter will be omitted from the details of the report 86 | supplied to the accused. In potentially harmful situations, such as ongoing 87 | harassment or threats to anyone's safety, we may take action without notice. 88 | 89 | ## Attribution 90 | 91 | This Code of Conduct is adapted from the Contributor Covenant, version 1.4, 92 | available at 93 | https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 94 | 95 | Note: A version of this file is also available in the 96 | [New Project repository](https://github.com/google/new-project/blob/master/docs/code-of-conduct.md). 97 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to contribute 2 | 3 | We'd love to accept your patches and contributions to this project. 4 | 5 | ## Contribution process 6 | 7 | ### Code reviews 8 | 9 | All submissions, including submissions by project members, require review. We 10 | use GitHub pull requests for this purpose. Consult 11 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 12 | information on using pull requests. 13 | 14 | ### Contributor Guide 15 | 16 | You may follow these steps to contribute: 17 | 18 | 1. **Fork the official repository.** This will create a copy of the official repository in your own account. 19 | 2. **Sync the branches.** This will ensure that your copy of the repository is up-to-date with the latest changes from the official repository. 20 | 3. **Work on your forked repository's feature branch.** This is where you will make your changes to the code. 21 | 4. **Commit your updates on your forked repository's feature branch.** This will save your changes to your copy of the repository. 22 | 5. 
**Submit a pull request to the official repository's main branch.** This will request that your changes be merged into the official repository. 23 | 6. **Resolve any linting errors.** This will ensure that your changes are formatted correctly. 24 | 25 | Here are some additional things to keep in mind during the process: 26 | 27 | - **Test your changes.** Before you submit a pull request, make sure that your changes work as expected. 28 | - **Be patient.** It may take some time for your pull request to be reviewed and merged. 29 | -------------------------------------------------------------------------------- /Gemini.md: -------------------------------------------------------------------------------- 1 | **A2A specification:** https://a2a-protocol.org/latest/specification/ 2 | 3 | ## Project frameworks 4 | - uv as the package manager 5 | 6 | ## How to run all tests 7 | 1. If dependencies are not installed, install them using the following command: 8 | ``` 9 | uv sync --all-extras 10 | ``` 11 | 12 | 2. Run the tests: 13 | ``` 14 | uv run pytest 15 | ``` 16 | 17 | ## Other instructions 18 | 1. Whenever writing Python code, include type annotations. 19 | 2. After making changes, run ruff to check for and fix lint issues: 20 | ``` 21 | uv run ruff check --fix 22 | ``` 23 | 3. Run the mypy type checker to check for type errors: 24 | ``` 25 | uv run mypy 26 | ``` 27 | 4. Run the unit tests to make sure none of them are broken.
28 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # A2A Python SDK 2 | 3 | [![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](LICENSE) 4 | [![PyPI version](https://img.shields.io/pypi/v/a2a-sdk)](https://pypi.org/project/a2a-sdk/) 5 | ![PyPI - Python Version](https://img.shields.io/pypi/pyversions/a2a-sdk) 6 | [![PyPI - Downloads](https://img.shields.io/pypi/dw/a2a-sdk)](https://pypistats.org/packages/a2a-sdk) 7 | [![Python Unit Tests](https://github.com/a2aproject/a2a-python/actions/workflows/unit-tests.yml/badge.svg)](https://github.com/a2aproject/a2a-python/actions/workflows/unit-tests.yml) 8 | [![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/a2aproject/a2a-python) 9 | 10 | 11 | 12 |
13 | A2A Logo 14 |

15 | A Python library for running agentic applications as A2A Servers, following the Agent2Agent (A2A) Protocol. 16 |

17 |
18 | 19 | 20 | 21 | --- 22 | 23 | ## ✨ Features 24 | 25 | - **A2A Protocol Compliant:** Build agentic applications that adhere to the Agent2Agent (A2A) Protocol. 26 | - **Extensible:** Easily add support for different communication protocols and database backends. 27 | - **Asynchronous:** Built on modern async Python for high performance. 28 | - **Optional Integrations:** Includes optional support for: 29 | - HTTP servers ([FastAPI](https://fastapi.tiangolo.com/), [Starlette](https://www.starlette.io/)) 30 | - [gRPC](https://grpc.io/) 31 | - [OpenTelemetry](https://opentelemetry.io/) for tracing 32 | - SQL databases ([PostgreSQL](https://www.postgresql.org/), [MySQL](https://www.mysql.com/), [SQLite](https://sqlite.org/)) 33 | 34 | --- 35 | 36 | ## 🚀 Getting Started 37 | 38 | ### Prerequisites 39 | 40 | - Python 3.10+ 41 | - `uv` (recommended) or `pip` 42 | 43 | ### 🔧 Installation 44 | 45 | Install the core SDK and any desired extras using your preferred package manager. 46 | 47 | | Feature | `uv` Command | `pip` Command | 48 | | ------------------------ | ------------------------------------------ | -------------------------------------------- | 49 | | **Core SDK** | `uv add a2a-sdk` | `pip install a2a-sdk` | 50 | | **All Extras** | `uv add "a2a-sdk[all]"` | `pip install "a2a-sdk[all]"` | 51 | | **HTTP Server** | `uv add "a2a-sdk[http-server]"` | `pip install "a2a-sdk[http-server]"` | 52 | | **gRPC Support** | `uv add "a2a-sdk[grpc]"` | `pip install "a2a-sdk[grpc]"` | 53 | | **OpenTelemetry Tracing**| `uv add "a2a-sdk[telemetry]"` | `pip install "a2a-sdk[telemetry]"` | 54 | | **Encryption** | `uv add "a2a-sdk[encryption]"` | `pip install "a2a-sdk[encryption]"` | 55 | | | | | 56 | | **Database Drivers** | | | 57 | | **PostgreSQL** | `uv add "a2a-sdk[postgresql]"` | `pip install "a2a-sdk[postgresql]"` | 58 | | **MySQL** | `uv add "a2a-sdk[mysql]"` | `pip install "a2a-sdk[mysql]"` | 59 | | **SQLite** | `uv add "a2a-sdk[sqlite]"` | `pip install "a2a-sdk[sqlite]"` | 60 
| | **All SQL Drivers** | `uv add "a2a-sdk[sql]"` | `pip install "a2a-sdk[sql]"` | 61 | 62 | ## Examples 63 | 64 | ### [Helloworld Example](https://github.com/a2aproject/a2a-samples/tree/main/samples/python/agents/helloworld) 65 | 66 | 1. Run Remote Agent 67 | 68 | ```bash 69 | git clone https://github.com/a2aproject/a2a-samples.git 70 | cd a2a-samples/samples/python/agents/helloworld 71 | uv run . 72 | ``` 73 | 74 | 2. In another terminal, run the client 75 | 76 | ```bash 77 | cd a2a-samples/samples/python/agents/helloworld 78 | uv run test_client.py 79 | ``` 80 | 81 | 3. You can validate your agent using the agent inspector. Follow the instructions at the [a2a-inspector](https://github.com/a2aproject/a2a-inspector) repo. 82 | 83 | --- 84 | 85 | ## 🌐 More Examples 86 | 87 | You can find a variety of more detailed examples in the [a2a-samples](https://github.com/a2aproject/a2a-samples) repository: 88 | 89 | - **[Python Examples](https://github.com/a2aproject/a2a-samples/tree/main/samples/python)** 90 | - **[JavaScript Examples](https://github.com/a2aproject/a2a-samples/tree/main/samples/js)** 91 | 92 | --- 93 | 94 | ## 🤝 Contributing 95 | 96 | Contributions are welcome! Please see the [CONTRIBUTING.md](CONTRIBUTING.md) file for guidelines on how to get involved. 97 | 98 | --- 99 | 100 | ## 📄 License 101 | 102 | This project is licensed under the Apache 2.0 License. See the [LICENSE](LICENSE) file for more details. 103 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | # Security Policy 2 | 3 | To report a security issue, please use [g.co/vulnz](https://g.co/vulnz). 4 | 5 | The Google Security Team will respond within 5 working days of your report on g.co/vulnz. 6 | 7 | We use g.co/vulnz for our intake, and do coordination and disclosure here using GitHub Security Advisory to privately discuss and fix the issue. 
8 | -------------------------------------------------------------------------------- /buf.gen.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | version: v2 3 | inputs: 4 | - git_repo: https://github.com/a2aproject/A2A.git 5 | ref: main 6 | subdir: specification/grpc 7 | managed: 8 | enabled: true 9 | # Python Generation 10 | # Using remote plugins. To use local plugins replace remote with local 11 | # pip install protobuf grpcio-tools 12 | # Optionally, install plugin to generate stubs for grpc services 13 | # pip install mypy-protobuf 14 | # Generate python protobuf code 15 | # - local: protoc-gen-python 16 | # - out: src/python 17 | # Generate gRPC stubs 18 | # - local: protoc-gen-grpc-python 19 | # - out: src/python 20 | plugins: 21 | # Generate python protobuf related code 22 | # Generates *_pb2.py files, one for each .proto 23 | - remote: buf.build/protocolbuffers/python:v29.3 24 | out: src/a2a/grpc 25 | # Generate python service code. 26 | # Generates *_pb2_grpc.py 27 | - remote: buf.build/grpc/python 28 | out: src/a2a/grpc 29 | # Generates *_pb2.pyi files. 30 | - remote: buf.build/protocolbuffers/pyi 31 | out: src/a2a/grpc 32 | -------------------------------------------------------------------------------- /scripts/format.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | set -o pipefail 4 | 5 | # --- Argument Parsing --- 6 | # Initialize flags 7 | FORMAT_ALL=false 8 | RUFF_UNSAFE_FIXES_FLAG="" 9 | 10 | # Process command-line arguments 11 | while [[ "$#" -gt 0 ]]; do 12 | case "$1" in 13 | --all) 14 | FORMAT_ALL=true 15 | echo "Detected --all flag: Formatting all tracked Python files." 16 | shift # Consume the argument 17 | ;; 18 | --unsafe-fixes) 19 | RUFF_UNSAFE_FIXES_FLAG="--unsafe-fixes" 20 | echo "Detected --unsafe-fixes flag: Ruff will run with unsafe fixes." 
21 | shift # Consume the argument 22 | ;; 23 | *) 24 | # Warn about unknown arguments and otherwise ignore them 25 | echo "Warning: Unknown argument '$1'. Ignoring." 26 | shift # Consume the argument 27 | ;; 28 | esac 29 | done 30 | 31 | # Sort Spelling Allowlist 32 | SPELLING_ALLOW_FILE=".github/actions/spelling/allow.txt" 33 | if [ -f "$SPELLING_ALLOW_FILE" ]; then 34 | echo "Sorting and de-duplicating $SPELLING_ALLOW_FILE" 35 | sort -u "$SPELLING_ALLOW_FILE" -o "$SPELLING_ALLOW_FILE" 36 | fi 37 | 38 | CHANGED_FILES="" 39 | 40 | if $FORMAT_ALL; then 41 | echo "Finding all tracked Python files in the repository..." 42 | CHANGED_FILES=$(git ls-files -- '*.py' ':!src/a2a/grpc/*') 43 | else 44 | echo "Finding changed Python files based on git diff..." 45 | TARGET_BRANCH="origin/${GITHUB_BASE_REF:-main}" 46 | git fetch origin "${GITHUB_BASE_REF:-main}" --depth=1 47 | 48 | MERGE_BASE=$(git merge-base HEAD "$TARGET_BRANCH") 49 | 50 | # Get python files changed in this PR, excluding grpc generated files. 51 | CHANGED_FILES=$(git diff --name-only --diff-filter=ACMRTUXB "$MERGE_BASE" HEAD -- '*.py' ':!src/a2a/grpc/*') 52 | fi 53 | 54 | # Exit if no files were found 55 | if [ -z "$CHANGED_FILES" ]; then 56 | echo "No changed or tracked Python files to format." 57 | exit 0 58 | fi 59 | 60 | # --- Helper Function --- 61 | # Runs a command on a list of files, feeding the list to xargs via stdin. 62 | # $1: The file list as one string; xargs splits it on whitespace/newlines. 63 | # $2...: The command and its arguments to run. 64 | run_formatter() { 65 | local files_to_format="$1" 66 | shift # Remove the file list from the arguments 67 | if [ -n "$files_to_format" ]; then 68 | echo "$files_to_format" | xargs -r "$@" 69 | fi 70 | } 71 | 72 | # --- Python File Formatting --- 73 | if [ -n "$CHANGED_FILES" ]; then 74 | echo "--- Formatting Python Files ---" 75 | echo "Files to be formatted:" 76 | echo "$CHANGED_FILES" 77 | 78 | echo "Running autoflake..."
79 | run_formatter "$CHANGED_FILES" autoflake -i -r --remove-all-unused-imports 80 | echo "Running ruff check (fix-only)..." 81 | run_formatter "$CHANGED_FILES" ruff check --fix-only $RUFF_UNSAFE_FIXES_FLAG 82 | echo "Running ruff format..." 83 | run_formatter "$CHANGED_FILES" ruff format 84 | echo "Python formatting complete." 85 | else 86 | echo "No Python files to format." 87 | fi 88 | 89 | echo "All formatting tasks are complete." 90 | -------------------------------------------------------------------------------- /scripts/generate_types.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Exit immediately if a command exits with a non-zero status. 4 | # Treat unset variables as an error. 5 | set -euo pipefail 6 | 7 | # Check if an output file path was provided as an argument. 8 | if [ -z "$1" ]; then 9 | echo "Error: Output file path must be provided as the first argument." >&2 10 | exit 1 11 | fi 12 | 13 | REMOTE_URL="https://raw.githubusercontent.com/a2aproject/A2A/refs/heads/main/specification/json/a2a.json" 14 | GENERATED_FILE="$1" 15 | 16 | echo "Running datamodel-codegen..." 17 | echo " - Source URL: $REMOTE_URL" 18 | echo " - Output File: $GENERATED_FILE" 19 | 20 | uv run datamodel-codegen \ 21 | --url "$REMOTE_URL" \ 22 | --input-file-type jsonschema \ 23 | --output "$GENERATED_FILE" \ 24 | --target-python-version 3.10 \ 25 | --output-model-type pydantic_v2.BaseModel \ 26 | --disable-timestamp \ 27 | --use-schema-description \ 28 | --use-union-operator \ 29 | --use-field-description \ 30 | --use-default \ 31 | --use-default-kwarg \ 32 | --use-one-literal-as-default \ 33 | --class-name A2A \ 34 | --use-standard-collections \ 35 | --use-subclass-enum \ 36 | --base-class a2a._base.A2ABaseModel \ 37 | --field-constraints \ 38 | --snake-case-field \ 39 | --no-alias 40 | 41 | echo "Formatting generated file with ruff..." 
42 | uv run ruff format "$GENERATED_FILE" 43 | 44 | echo "Codegen finished successfully." 45 | -------------------------------------------------------------------------------- /scripts/grpc_gen_post_processor.py: -------------------------------------------------------------------------------- 1 | """Fix absolute imports in *_pb2_grpc.py files. 2 | 3 | Example: 4 | import a2a_pb2 as a2a__pb2 5 | from . import a2a_pb2 as a2a__pb2 6 | """ 7 | 8 | import re 9 | import sys 10 | 11 | from pathlib import Path 12 | 13 | 14 | def process_generated_code(src_folder: str = 'src/a2a/grpc') -> None: 15 | """Post processor for the generated code.""" 16 | dir_path = Path(src_folder) 17 | print(dir_path) 18 | if not dir_path.is_dir(): 19 | print('Source folder not found') 20 | sys.exit(1) 21 | 22 | grpc_pattern = '**/*_pb2_grpc.py' 23 | files = dir_path.glob(grpc_pattern) 24 | 25 | for file in files: 26 | print(f'Processing {file}') 27 | try: 28 | with file.open('r', encoding='utf-8') as f: 29 | src_content = f.read() 30 | 31 | # Change import a2a_pb2 as a2a__pb2 32 | import_pattern = r'^import (\w+_pb2) as (\w+__pb2)$' 33 | # to from . import a2a_pb2 as a2a__pb2 34 | replacement_pattern = r'from . 
import \1 as \2' 35 | 36 | fixed_src_content = re.sub( 37 | import_pattern, 38 | replacement_pattern, 39 | src_content, 40 | flags=re.MULTILINE, 41 | ) 42 | 43 | if fixed_src_content != src_content: 44 | with file.open('w', encoding='utf-8') as f: 45 | f.write(fixed_src_content) 46 | print('Imports fixed') 47 | else: 48 | print('No changes needed') 49 | 50 | except Exception as e: # noqa: BLE001 51 | print(f'Error processing file {file}: {e}') 52 | sys.exit(1) 53 | 54 | 55 | if __name__ == '__main__': 56 | process_generated_code() 57 | -------------------------------------------------------------------------------- /src/a2a/__init__.py: -------------------------------------------------------------------------------- 1 | """The A2A Python SDK.""" 2 | -------------------------------------------------------------------------------- /src/a2a/_base.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, ConfigDict 2 | from pydantic.alias_generators import to_camel 3 | 4 | 5 | def to_camel_custom(snake: str) -> str: 6 | """Convert a snake_case string to camelCase. 7 | 8 | Args: 9 | snake: The string to convert. 10 | 11 | Returns: 12 | The converted camelCase string. 13 | """ 14 | # First, remove any trailing underscores. This is common for names that 15 | # conflict with Python keywords, like 'in_' or 'from_'. 16 | if snake.endswith('_'): 17 | snake = snake.rstrip('_') 18 | return to_camel(snake) 19 | 20 | 21 | class A2ABaseModel(BaseModel): 22 | """Base class for shared behavior across A2A data models. 23 | 24 | Provides a common configuration (e.g., alias-based population) and 25 | serves as the foundation for future extensions or shared utilities. 26 | 27 | This implementation provides backward compatibility for camelCase aliases 28 | by lazy-loading an alias map upon first use. Accessing or setting 29 | attributes via their camelCase alias will raise a DeprecationWarning. 
30 | """ 31 | 32 | model_config = ConfigDict( 33 | # SEE: https://docs.pydantic.dev/latest/api/config/#pydantic.config.ConfigDict.populate_by_name 34 | validate_by_name=True, 35 | validate_by_alias=True, 36 | serialize_by_alias=True, 37 | alias_generator=to_camel_custom, 38 | ) 39 | -------------------------------------------------------------------------------- /src/a2a/auth/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/a2aproject/a2a-python/dca66c3100a2b9701a1c8b65ad6853769eefd511/src/a2a/auth/__init__.py -------------------------------------------------------------------------------- /src/a2a/auth/user.py: -------------------------------------------------------------------------------- 1 | """Authenticated user information.""" 2 | 3 | from abc import ABC, abstractmethod 4 | 5 | 6 | class User(ABC): 7 | """A representation of an authenticated user.""" 8 | 9 | @property 10 | @abstractmethod 11 | def is_authenticated(self) -> bool: 12 | """Returns whether the current user is authenticated.""" 13 | 14 | @property 15 | @abstractmethod 16 | def user_name(self) -> str: 17 | """Returns the user name of the current user.""" 18 | 19 | 20 | class UnauthenticatedUser(User): 21 | """A representation that no user has been authenticated in the request.""" 22 | 23 | @property 24 | def is_authenticated(self) -> bool: 25 | """Returns whether the current user is authenticated.""" 26 | return False 27 | 28 | @property 29 | def user_name(self) -> str: 30 | """Returns the user name of the current user.""" 31 | return '' 32 | -------------------------------------------------------------------------------- /src/a2a/client/__init__.py: -------------------------------------------------------------------------------- 1 | """Client-side components for interacting with an A2A agent.""" 2 | 3 | import logging 4 | 5 | from a2a.client.auth import ( 6 | AuthInterceptor, 7 | CredentialService, 8 | 
InMemoryContextCredentialStore, 9 | ) 10 | from a2a.client.card_resolver import A2ACardResolver 11 | from a2a.client.client import Client, ClientConfig, ClientEvent, Consumer 12 | from a2a.client.client_factory import ClientFactory, minimal_agent_card 13 | from a2a.client.errors import ( 14 | A2AClientError, 15 | A2AClientHTTPError, 16 | A2AClientJSONError, 17 | A2AClientTimeoutError, 18 | ) 19 | from a2a.client.helpers import create_text_message_object 20 | from a2a.client.legacy import A2AClient 21 | from a2a.client.middleware import ClientCallContext, ClientCallInterceptor 22 | 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | try: 27 | from a2a.client.legacy_grpc import A2AGrpcClient # type: ignore 28 | except ImportError as e: 29 | _original_error = e 30 | logger.debug( 31 | 'A2AGrpcClient not loaded. This is expected if gRPC dependencies are not installed. Error: %s', 32 | _original_error, 33 | ) 34 | 35 | class A2AGrpcClient: # type: ignore 36 | """Placeholder for A2AGrpcClient when dependencies are not installed.""" 37 | 38 | def __init__(self, *args, **kwargs): 39 | raise ImportError( 40 | 'To use A2AGrpcClient, its dependencies must be installed. 
' 41 | 'You can install them with \'pip install "a2a-sdk[grpc]"\'' 42 | ) from _original_error 43 | 44 | 45 | __all__ = [ 46 | 'A2ACardResolver', 47 | 'A2AClient', 48 | 'A2AClientError', 49 | 'A2AClientHTTPError', 50 | 'A2AClientJSONError', 51 | 'A2AClientTimeoutError', 52 | 'A2AGrpcClient', 53 | 'AuthInterceptor', 54 | 'Client', 55 | 'ClientCallContext', 56 | 'ClientCallInterceptor', 57 | 'ClientConfig', 58 | 'ClientEvent', 59 | 'ClientFactory', 60 | 'Consumer', 61 | 'CredentialService', 62 | 'InMemoryContextCredentialStore', 63 | 'create_text_message_object', 64 | 'minimal_agent_card', 65 | ] 66 | -------------------------------------------------------------------------------- /src/a2a/client/auth/__init__.py: -------------------------------------------------------------------------------- 1 | """Client-side authentication components for the A2A Python SDK.""" 2 | 3 | from a2a.client.auth.credentials import ( 4 | CredentialService, 5 | InMemoryContextCredentialStore, 6 | ) 7 | from a2a.client.auth.interceptor import AuthInterceptor 8 | 9 | 10 | __all__ = [ 11 | 'AuthInterceptor', 12 | 'CredentialService', 13 | 'InMemoryContextCredentialStore', 14 | ] 15 | -------------------------------------------------------------------------------- /src/a2a/client/auth/credentials.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | from a2a.client.middleware import ClientCallContext 4 | 5 | 6 | class CredentialService(ABC): 7 | """An abstract service for retrieving credentials.""" 8 | 9 | @abstractmethod 10 | async def get_credentials( 11 | self, 12 | security_scheme_name: str, 13 | context: ClientCallContext | None, 14 | ) -> str | None: 15 | """ 16 | Retrieves a credential (e.g., token) for a security scheme. 17 | """ 18 | 19 | 20 | class InMemoryContextCredentialStore(CredentialService): 21 | """A simple in-memory store for session-keyed credentials. 
22 | 23 | This class uses the 'sessionId' from the ClientCallContext state to 24 | store and retrieve credentials per security scheme name. 25 | """ 26 | 27 | def __init__(self) -> None: 28 | self._store: dict[str, dict[str, str]] = {} 29 | 30 | async def get_credentials( 31 | self, 32 | security_scheme_name: str, 33 | context: ClientCallContext | None, 34 | ) -> str | None: 35 | """Retrieves credentials from the in-memory store. 36 | 37 | Args: 38 | security_scheme_name: The name of the security scheme. 39 | context: The client call context. 40 | 41 | Returns: 42 | The credential string, or None if not found. 43 | """ 44 | if not context or 'sessionId' not in context.state: 45 | return None 46 | session_id = context.state['sessionId'] 47 | return self._store.get(session_id, {}).get(security_scheme_name) 48 | 49 | async def set_credentials( 50 | self, session_id: str, security_scheme_name: str, credential: str 51 | ) -> None: 52 | """Stores a credential for the given session and security scheme name.""" 53 | if session_id not in self._store: 54 | self._store[session_id] = {} 55 | self._store[session_id][security_scheme_name] = credential 56 | -------------------------------------------------------------------------------- /src/a2a/client/auth/interceptor.py: -------------------------------------------------------------------------------- 1 | import logging # noqa: I001 2 | from typing import Any 3 | 4 | from a2a.client.auth.credentials import CredentialService 5 | from a2a.client.middleware import ClientCallContext, ClientCallInterceptor 6 | from a2a.types import ( 7 | AgentCard, 8 | APIKeySecurityScheme, 9 | HTTPAuthSecurityScheme, 10 | In, 11 | OAuth2SecurityScheme, 12 | OpenIdConnectSecurityScheme, 13 | ) 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | class AuthInterceptor(ClientCallInterceptor): 19 | """An interceptor that automatically adds authentication details to requests. 20 | 21 | Which header is set, and how, is determined by the agent's security schemes.
22 | """ 23 | 24 | def __init__(self, credential_service: CredentialService): 25 | self._credential_service = credential_service 26 | 27 | async def intercept( 28 | self, 29 | method_name: str, 30 | request_payload: dict[str, Any], 31 | http_kwargs: dict[str, Any], 32 | agent_card: AgentCard | None, 33 | context: ClientCallContext | None, 34 | ) -> tuple[dict[str, Any], dict[str, Any]]: 35 | """Applies authentication headers to the request if credentials are available.""" 36 | if ( 37 | agent_card is None 38 | or agent_card.security is None 39 | or agent_card.security_schemes is None 40 | ): 41 | return request_payload, http_kwargs 42 | 43 | for requirement in agent_card.security: 44 | for scheme_name in requirement: 45 | credential = await self._credential_service.get_credentials( 46 | scheme_name, context 47 | ) 48 | if credential and scheme_name in agent_card.security_schemes: 49 | scheme_def_union = agent_card.security_schemes.get( 50 | scheme_name 51 | ) 52 | if not scheme_def_union: 53 | continue 54 | scheme_def = scheme_def_union.root 55 | 56 | headers = http_kwargs.get('headers', {}) 57 | 58 | match scheme_def: 59 | # Case 1a: HTTP Bearer scheme with an if guard 60 | case HTTPAuthSecurityScheme() if ( 61 | scheme_def.scheme.lower() == 'bearer' 62 | ): 63 | headers['Authorization'] = f'Bearer {credential}' 64 | logger.debug( 65 | "Added Bearer token for scheme '%s' (type: %s).", 66 | scheme_name, 67 | scheme_def.type, 68 | ) 69 | http_kwargs['headers'] = headers 70 | return request_payload, http_kwargs 71 | 72 | # Case 1b: OAuth2 and OIDC schemes, which are implicitly Bearer 73 | case ( 74 | OAuth2SecurityScheme() 75 | | OpenIdConnectSecurityScheme() 76 | ): 77 | headers['Authorization'] = f'Bearer {credential}' 78 | logger.debug( 79 | "Added Bearer token for scheme '%s' (type: %s).", 80 | scheme_name, 81 | scheme_def.type, 82 | ) 83 | http_kwargs['headers'] = headers 84 | return request_payload, http_kwargs 85 | 86 | # Case 2: API Key in Header 87 | case 
APIKeySecurityScheme(in_=In.header): 88 | headers[scheme_def.name] = credential 89 | logger.debug( 90 | "Added API Key Header for scheme '%s'.", 91 | scheme_name, 92 | ) 93 | http_kwargs['headers'] = headers 94 | return request_payload, http_kwargs 95 | 96 | # Note: Other cases like API keys in query/cookie are not handled and will be skipped. 97 | 98 | return request_payload, http_kwargs 99 | -------------------------------------------------------------------------------- /src/a2a/client/card_resolver.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | 4 | from typing import Any 5 | 6 | import httpx 7 | 8 | from pydantic import ValidationError 9 | 10 | from a2a.client.errors import ( 11 | A2AClientHTTPError, 12 | A2AClientJSONError, 13 | ) 14 | from a2a.types import ( 15 | AgentCard, 16 | ) 17 | from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH 18 | 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | class A2ACardResolver: 24 | """Agent Card resolver.""" 25 | 26 | def __init__( 27 | self, 28 | httpx_client: httpx.AsyncClient, 29 | base_url: str, 30 | agent_card_path: str = AGENT_CARD_WELL_KNOWN_PATH, 31 | ) -> None: 32 | """Initializes the A2ACardResolver. 33 | 34 | Args: 35 | httpx_client: An async HTTP client instance (e.g., httpx.AsyncClient). 36 | base_url: The base URL of the agent's host. 37 | agent_card_path: The path to the agent card endpoint, relative to the base URL. 38 | """ 39 | self.base_url = base_url.rstrip('/') 40 | self.agent_card_path = agent_card_path.lstrip('/') 41 | self.httpx_client = httpx_client 42 | 43 | async def get_agent_card( 44 | self, 45 | relative_card_path: str | None = None, 46 | http_kwargs: dict[str, Any] | None = None, 47 | ) -> AgentCard: 48 | """Fetches an agent card from a specified path relative to the base_url. 
49 | 50 | If relative_card_path is None, it defaults to the resolver's configured 51 | agent_card_path (for the public agent card). 52 | 53 | Args: 54 | relative_card_path: Optional path to the agent card endpoint, 55 | relative to the base URL. If None, uses the default public 56 | agent card path. 57 | http_kwargs: Optional dictionary of keyword arguments to pass to the 58 | underlying httpx.get request. 59 | 60 | Returns: 61 | An `AgentCard` object representing the agent's capabilities. 62 | 63 | Raises: 64 | A2AClientHTTPError: If an HTTP error occurs during the request. 65 | A2AClientJSONError: If the response body cannot be decoded as JSON 66 | or validated against the AgentCard schema. 67 | """ 68 | if relative_card_path is None: 69 | # Use the default public agent card path configured during initialization 70 | path_segment = self.agent_card_path 71 | else: 72 | path_segment = relative_card_path.lstrip('/') 73 | 74 | target_url = f'{self.base_url}/{path_segment}' 75 | 76 | try: 77 | response = await self.httpx_client.get( 78 | target_url, 79 | **(http_kwargs or {}), 80 | ) 81 | response.raise_for_status() 82 | agent_card_data = response.json() 83 | logger.info( 84 | 'Successfully fetched agent card data from %s: %s', 85 | target_url, 86 | agent_card_data, 87 | ) 88 | agent_card = AgentCard.model_validate(agent_card_data) 89 | except httpx.HTTPStatusError as e: 90 | raise A2AClientHTTPError( 91 | e.response.status_code, 92 | f'Failed to fetch agent card from {target_url}: {e}', 93 | ) from e 94 | except json.JSONDecodeError as e: 95 | raise A2AClientJSONError( 96 | f'Failed to parse JSON for agent card from {target_url}: {e}' 97 | ) from e 98 | except httpx.RequestError as e: 99 | raise A2AClientHTTPError( 100 | 503, 101 | f'Network communication error fetching agent card from {target_url}: {e}', 102 | ) from e 103 | except ValidationError as e: # Pydantic validation error 104 | raise A2AClientJSONError( 105 | f'Failed to validate agent card structure from 
{target_url}: {e.json()}' 106 | ) from e 107 | 108 | return agent_card 109 | -------------------------------------------------------------------------------- /src/a2a/client/errors.py: -------------------------------------------------------------------------------- 1 | """Custom exceptions for the A2A client.""" 2 | 3 | from a2a.types import JSONRPCErrorResponse 4 | 5 | 6 | class A2AClientError(Exception): 7 | """Base exception for A2A Client errors.""" 8 | 9 | 10 | class A2AClientHTTPError(A2AClientError): 11 | """Client exception for HTTP errors received from the server.""" 12 | 13 | def __init__(self, status_code: int, message: str): 14 | """Initializes the A2AClientHTTPError. 15 | 16 | Args: 17 | status_code: The HTTP status code of the response. 18 | message: A descriptive error message. 19 | """ 20 | self.status_code = status_code 21 | self.message = message 22 | super().__init__(f'HTTP Error {status_code}: {message}') 23 | 24 | 25 | class A2AClientJSONError(A2AClientError): 26 | """Client exception for JSON errors during response parsing or validation.""" 27 | 28 | def __init__(self, message: str): 29 | """Initializes the A2AClientJSONError. 30 | 31 | Args: 32 | message: A descriptive error message. 33 | """ 34 | self.message = message 35 | super().__init__(f'JSON Error: {message}') 36 | 37 | 38 | class A2AClientTimeoutError(A2AClientError): 39 | """Client exception for timeout errors during a request.""" 40 | 41 | def __init__(self, message: str): 42 | """Initializes the A2AClientTimeoutError. 43 | 44 | Args: 45 | message: A descriptive error message. 46 | """ 47 | self.message = message 48 | super().__init__(f'Timeout Error: {message}') 49 | 50 | 51 | class A2AClientInvalidArgsError(A2AClientError): 52 | """Client exception for invalid arguments passed to a method.""" 53 | 54 | def __init__(self, message: str): 55 | """Initializes the A2AClientInvalidArgsError. 56 | 57 | Args: 58 | message: A descriptive error message. 
59 | """ 60 | self.message = message 61 | super().__init__(f'Invalid arguments error: {message}') 62 | 63 | 64 | class A2AClientInvalidStateError(A2AClientError): 65 | """Client exception for an invalid client state.""" 66 | 67 | def __init__(self, message: str): 68 | """Initializes the A2AClientInvalidStateError. 69 | 70 | Args: 71 | message: A descriptive error message. 72 | """ 73 | self.message = message 74 | super().__init__(f'Invalid state error: {message}') 75 | 76 | 77 | class A2AClientJSONRPCError(A2AClientError): 78 | """Client exception for JSON-RPC errors returned by the server.""" 79 | 80 | def __init__(self, error: JSONRPCErrorResponse): 81 | """Initializes the A2AClientJSONRPCError. 82 | 83 | Args: 84 | error: The JSON-RPC error response; its `error` member is stored. 85 | """ 86 | self.error = error.error 87 | super().__init__(f'JSON-RPC Error {error.error}') 88 | -------------------------------------------------------------------------------- /src/a2a/client/helpers.py: -------------------------------------------------------------------------------- 1 | """Helper functions for the A2A client.""" 2 | 3 | from uuid import uuid4 4 | 5 | from a2a.types import Message, Part, Role, TextPart 6 | 7 | 8 | def create_text_message_object( 9 | role: Role = Role.user, content: str = '' 10 | ) -> Message: 11 | """Create a Message object containing a single TextPart. 12 | 13 | Args: 14 | role: The role of the message sender (user or agent). Defaults to Role.user. 15 | content: The text content of the message. Defaults to an empty string. 16 | 17 | Returns: 18 | A `Message` object with a new UUID message_id.
19 | """ 20 | return Message( 21 | role=role, parts=[Part(TextPart(text=content))], message_id=str(uuid4()) 22 | ) 23 | -------------------------------------------------------------------------------- /src/a2a/client/legacy_grpc.py: -------------------------------------------------------------------------------- 1 | """Backwards compatibility layer for the legacy A2A gRPC client.""" 2 | 3 | import warnings 4 | 5 | from typing import TYPE_CHECKING 6 | 7 | from a2a.client.transports.grpc import GrpcTransport 8 | from a2a.types import AgentCard 9 | 10 | 11 | if TYPE_CHECKING: 12 | from a2a.grpc.a2a_pb2_grpc import A2AServiceStub 13 | 14 | 15 | class A2AGrpcClient(GrpcTransport): 16 | """[DEPRECATED] Backwards compatibility wrapper for the gRPC client.""" 17 | 18 | def __init__( # pylint: disable=super-init-not-called 19 | self, 20 | grpc_stub: 'A2AServiceStub', 21 | agent_card: AgentCard, 22 | ): 23 | warnings.warn( 24 | 'A2AGrpcClient is deprecated and will be removed in a future version. ' 25 | 'Use ClientFactory to create a client with a gRPC transport.', 26 | DeprecationWarning, 27 | stacklevel=2, 28 | ) 29 | # The old gRPC client accepted a stub directly. The new one accepts a 30 | # channel and builds the stub itself. We just have a stub here, so we 31 | # need to handle initialization ourselves. 
32 | self.stub = grpc_stub 33 | self.agent_card = agent_card 34 | self._needs_extended_card = ( 35 | agent_card.supports_authenticated_extended_card 36 | if agent_card 37 | else True 38 | ) 39 | 40 | class _NopChannel: 41 | async def close(self) -> None: 42 | pass 43 | 44 | self.channel = _NopChannel() 45 | -------------------------------------------------------------------------------- /src/a2a/client/middleware.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from abc import ABC, abstractmethod 4 | from collections.abc import MutableMapping # noqa: TC003 5 | from typing import TYPE_CHECKING, Any 6 | 7 | from pydantic import BaseModel, Field 8 | 9 | 10 | if TYPE_CHECKING: 11 | from a2a.types import AgentCard 12 | 13 | 14 | class ClientCallContext(BaseModel): 15 | """A context passed with each client call. 16 | 17 | Allows call-specific configuration and data passing, such as 18 | authentication details or request deadlines. 19 | """ 20 | 21 | state: MutableMapping[str, Any] = Field(default_factory=dict) 22 | 23 | 24 | class ClientCallInterceptor(ABC): 25 | """An abstract base class for client-side call interceptors. 26 | 27 | Interceptors can inspect and modify requests before they are sent, 28 | which is ideal for concerns like authentication, logging, or tracing. 29 | """ 30 | 31 | @abstractmethod 32 | async def intercept( 33 | self, 34 | method_name: str, 35 | request_payload: dict[str, Any], 36 | http_kwargs: dict[str, Any], 37 | agent_card: AgentCard | None, 38 | context: ClientCallContext | None, 39 | ) -> tuple[dict[str, Any], dict[str, Any]]: 40 | """ 41 | Intercepts a client call before the request is sent. 42 | 43 | Args: 44 | method_name: The name of the RPC method (e.g., 'message/send'). 45 | request_payload: The JSON-RPC request payload dictionary. 46 | http_kwargs: The keyword arguments for the httpx request.
47 | agent_card: The AgentCard associated with the client. 48 | context: The ClientCallContext for this specific call. 49 | 50 | Returns: 51 | A tuple containing the (potentially modified) request_payload 52 | and http_kwargs. 53 | """ 54 | -------------------------------------------------------------------------------- /src/a2a/client/optionals.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | 4 | # Attempt to import the optional module 5 | try: 6 | from grpc.aio import Channel # pyright: ignore[reportAssignmentType] 7 | except ImportError: 8 | # If grpc.aio is not available, define a dummy type for type checking. 9 | # This dummy type will only be used by type checkers. 10 | if TYPE_CHECKING: 11 | 12 | class Channel: # type: ignore[no-redef] 13 | """Dummy class for type hinting when grpc.aio is not available.""" 14 | 15 | else: 16 | Channel = None # At runtime, Channel will be None if the import failed. 17 | -------------------------------------------------------------------------------- /src/a2a/client/transports/__init__.py: -------------------------------------------------------------------------------- 1 | """A2A Client Transports.""" 2 | 3 | from a2a.client.transports.base import ClientTransport 4 | from a2a.client.transports.jsonrpc import JsonRpcTransport 5 | from a2a.client.transports.rest import RestTransport 6 | 7 | 8 | try: 9 | from a2a.client.transports.grpc import GrpcTransport 10 | except ImportError: 11 | GrpcTransport = None # type: ignore 12 | 13 | 14 | __all__ = [ 15 | 'ClientTransport', 16 | 'GrpcTransport', 17 | 'JsonRpcTransport', 18 | 'RestTransport', 19 | ] 20 | -------------------------------------------------------------------------------- /src/a2a/client/transports/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from collections.abc import AsyncGenerator 3 | 4 | from
a2a.client.middleware import ClientCallContext 5 | from a2a.types import ( 6 | AgentCard, 7 | GetTaskPushNotificationConfigParams, 8 | Message, 9 | MessageSendParams, 10 | Task, 11 | TaskArtifactUpdateEvent, 12 | TaskIdParams, 13 | TaskPushNotificationConfig, 14 | TaskQueryParams, 15 | TaskStatusUpdateEvent, 16 | ) 17 | 18 | 19 | class ClientTransport(ABC): 20 | """Abstract base class for a client transport.""" 21 | 22 | @abstractmethod 23 | async def send_message( 24 | self, 25 | request: MessageSendParams, 26 | *, 27 | context: ClientCallContext | None = None, 28 | ) -> Task | Message: 29 | """Sends a non-streaming message request to the agent.""" 30 | 31 | @abstractmethod 32 | async def send_message_streaming( 33 | self, 34 | request: MessageSendParams, 35 | *, 36 | context: ClientCallContext | None = None, 37 | ) -> AsyncGenerator[ 38 | Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent 39 | ]: 40 | """Sends a streaming message request to the agent and yields responses as they arrive.""" 41 | return 42 | yield 43 | 44 | @abstractmethod 45 | async def get_task( 46 | self, 47 | request: TaskQueryParams, 48 | *, 49 | context: ClientCallContext | None = None, 50 | ) -> Task: 51 | """Retrieves the current state and history of a specific task.""" 52 | 53 | @abstractmethod 54 | async def cancel_task( 55 | self, 56 | request: TaskIdParams, 57 | *, 58 | context: ClientCallContext | None = None, 59 | ) -> Task: 60 | """Requests the agent to cancel a specific task.""" 61 | 62 | @abstractmethod 63 | async def set_task_callback( 64 | self, 65 | request: TaskPushNotificationConfig, 66 | *, 67 | context: ClientCallContext | None = None, 68 | ) -> TaskPushNotificationConfig: 69 | """Sets or updates the push notification configuration for a specific task.""" 70 | 71 | @abstractmethod 72 | async def get_task_callback( 73 | self, 74 | request: GetTaskPushNotificationConfigParams, 75 | *, 76 | context: ClientCallContext | None = None, 77 | ) -> 
TaskPushNotificationConfig: 78 | """Retrieves the push notification configuration for a specific task.""" 79 | 80 | @abstractmethod 81 | async def resubscribe( 82 | self, 83 | request: TaskIdParams, 84 | *, 85 | context: ClientCallContext | None = None, 86 | ) -> AsyncGenerator[ 87 | Task | Message | TaskStatusUpdateEvent | TaskArtifactUpdateEvent 88 | ]: 89 | """Reconnects to get task updates.""" 90 | return 91 | yield 92 | 93 | @abstractmethod 94 | async def get_card( 95 | self, 96 | *, 97 | context: ClientCallContext | None = None, 98 | ) -> AgentCard: 99 | """Retrieves the AgentCard.""" 100 | 101 | @abstractmethod 102 | async def close(self) -> None: 103 | """Closes the transport.""" 104 | -------------------------------------------------------------------------------- /src/a2a/extensions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/a2aproject/a2a-python/dca66c3100a2b9701a1c8b65ad6853769eefd511/src/a2a/extensions/__init__.py -------------------------------------------------------------------------------- /src/a2a/extensions/common.py: -------------------------------------------------------------------------------- 1 | from a2a.types import AgentCard, AgentExtension 2 | 3 | 4 | HTTP_EXTENSION_HEADER = 'X-A2A-Extensions' 5 | 6 | 7 | def get_requested_extensions(values: list[str]) -> set[str]: 8 | """Get the set of requested extensions from an input list. 9 | 10 | This handles the list containing potentially comma-separated values, as 11 | occurs when using a list in an HTTP header. 
12 | """ 13 | return { 14 | stripped 15 | for v in values 16 | for ext in v.split(',') 17 | if (stripped := ext.strip()) 18 | } 19 | 20 | 21 | def find_extension_by_uri(card: AgentCard, uri: str) -> AgentExtension | None: 22 | """Find an AgentExtension in an AgentCard given a uri.""" 23 | for ext in card.capabilities.extensions or []: 24 | if ext.uri == uri: 25 | return ext 26 | 27 | return None 28 | -------------------------------------------------------------------------------- /src/a2a/grpc/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/a2aproject/a2a-python/dca66c3100a2b9701a1c8b65ad6853769eefd511/src/a2a/grpc/__init__.py -------------------------------------------------------------------------------- /src/a2a/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/a2aproject/a2a-python/dca66c3100a2b9701a1c8b65ad6853769eefd511/src/a2a/py.typed -------------------------------------------------------------------------------- /src/a2a/server/__init__.py: -------------------------------------------------------------------------------- 1 | """Server-side components for implementing an A2A agent.""" 2 | -------------------------------------------------------------------------------- /src/a2a/server/agent_execution/__init__.py: -------------------------------------------------------------------------------- 1 | """Components for executing agent logic within the A2A server.""" 2 | 3 | from a2a.server.agent_execution.agent_executor import AgentExecutor 4 | from a2a.server.agent_execution.context import RequestContext 5 | from a2a.server.agent_execution.request_context_builder import ( 6 | RequestContextBuilder, 7 | ) 8 | from a2a.server.agent_execution.simple_request_context_builder import ( 9 | SimpleRequestContextBuilder, 10 | ) 11 | 12 | 13 | __all__ = [ 14 | 'AgentExecutor', 15 | 'RequestContext', 16 | 
'RequestContextBuilder', 17 | 'SimpleRequestContextBuilder', 18 | ] 19 | -------------------------------------------------------------------------------- /src/a2a/server/agent_execution/agent_executor.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | from a2a.server.agent_execution.context import RequestContext 4 | from a2a.server.events.event_queue import EventQueue 5 | 6 | 7 | class AgentExecutor(ABC): 8 | """Agent Executor interface. 9 | 10 | Implementations of this interface contain the core logic of the agent, 11 | executing tasks based on requests and publishing updates to an event queue. 12 | """ 13 | 14 | @abstractmethod 15 | async def execute( 16 | self, context: RequestContext, event_queue: EventQueue 17 | ) -> None: 18 | """Execute the agent's logic for a given request context. 19 | 20 | The agent should read necessary information from the `context` and 21 | publish `Task` or `Message` events, or `TaskStatusUpdateEvent` / 22 | `TaskArtifactUpdateEvent` to the `event_queue`. This method should 23 | return once the agent's execution for this request is complete or 24 | yields control (e.g., enters an input-required state). 25 | 26 | Args: 27 | context: The request context containing the message, task ID, etc. 28 | event_queue: The queue to publish events to. 29 | """ 30 | 31 | @abstractmethod 32 | async def cancel( 33 | self, context: RequestContext, event_queue: EventQueue 34 | ) -> None: 35 | """Request the agent to cancel an ongoing task. 36 | 37 | The agent should attempt to stop the task identified by the task_id 38 | in the context and publish a `TaskStatusUpdateEvent` with state 39 | `TaskState.canceled` to the `event_queue`. 40 | 41 | Args: 42 | context: The request context containing the task ID to cancel. 43 | event_queue: The queue to publish the cancellation status update to. 
44 | """ 45 | -------------------------------------------------------------------------------- /src/a2a/server/agent_execution/request_context_builder.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | from a2a.server.agent_execution import RequestContext 4 | from a2a.server.context import ServerCallContext 5 | from a2a.types import MessageSendParams, Task 6 | 7 | 8 | class RequestContextBuilder(ABC): 9 | """Builds request context to be supplied to agent executor.""" 10 | 11 | @abstractmethod 12 | async def build( 13 | self, 14 | params: MessageSendParams | None = None, 15 | task_id: str | None = None, 16 | context_id: str | None = None, 17 | task: Task | None = None, 18 | context: ServerCallContext | None = None, 19 | ) -> RequestContext: 20 | pass 21 | -------------------------------------------------------------------------------- /src/a2a/server/agent_execution/simple_request_context_builder.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | from a2a.server.agent_execution import RequestContext, RequestContextBuilder 4 | from a2a.server.context import ServerCallContext 5 | from a2a.server.tasks import TaskStore 6 | from a2a.types import MessageSendParams, Task 7 | 8 | 9 | class SimpleRequestContextBuilder(RequestContextBuilder): 10 | """Builds request context and populates referred tasks.""" 11 | 12 | def __init__( 13 | self, 14 | should_populate_referred_tasks: bool = False, 15 | task_store: TaskStore | None = None, 16 | ) -> None: 17 | """Initializes the SimpleRequestContextBuilder. 18 | 19 | Args: 20 | should_populate_referred_tasks: If True, the builder will fetch tasks 21 | referenced in `params.message.reference_task_ids` and populate the 22 | `related_tasks` field in the RequestContext. Defaults to False. 23 | task_store: The TaskStore instance to use for fetching referred tasks. 
24 | Required if `should_populate_referred_tasks` is True. 25 | """ 26 | self._task_store = task_store 27 | self._should_populate_referred_tasks = should_populate_referred_tasks 28 | 29 | async def build( 30 | self, 31 | params: MessageSendParams | None = None, 32 | task_id: str | None = None, 33 | context_id: str | None = None, 34 | task: Task | None = None, 35 | context: ServerCallContext | None = None, 36 | ) -> RequestContext: 37 | """Builds the request context for an agent execution. 38 | 39 | This method assembles the RequestContext object. If the builder was 40 | initialized with `should_populate_referred_tasks=True`, it fetches all tasks 41 | referenced in `params.message.reference_task_ids` from the `task_store`. 42 | 43 | Args: 44 | params: The parameters of the incoming message send request. 45 | task_id: The ID of the task being executed. 46 | context_id: The ID of the current execution context. 47 | task: The primary task object associated with the request. 48 | context: The server call context, containing metadata about the call. 49 | 50 | Returns: 51 | An instance of RequestContext populated with the provided information 52 | and potentially a list of related tasks. 
53 | """ 54 | related_tasks: list[Task] | None = None 55 | 56 | if ( 57 | self._task_store 58 | and self._should_populate_referred_tasks 59 | and params 60 | and params.message.reference_task_ids 61 | ): 62 | tasks = await asyncio.gather( 63 | *[ 64 | self._task_store.get(task_id) 65 | for task_id in params.message.reference_task_ids 66 | ] 67 | ) 68 | related_tasks = [x for x in tasks if x is not None] 69 | 70 | return RequestContext( 71 | request=params, 72 | task_id=task_id, 73 | context_id=context_id, 74 | task=task, 75 | related_tasks=related_tasks, 76 | call_context=context, 77 | ) 78 | -------------------------------------------------------------------------------- /src/a2a/server/apps/__init__.py: -------------------------------------------------------------------------------- 1 | """HTTP application components for the A2A server.""" 2 | 3 | from a2a.server.apps.jsonrpc import ( 4 | A2AFastAPIApplication, 5 | A2AStarletteApplication, 6 | CallContextBuilder, 7 | JSONRPCApplication, 8 | ) 9 | from a2a.server.apps.rest import A2ARESTFastAPIApplication 10 | 11 | 12 | __all__ = [ 13 | 'A2AFastAPIApplication', 14 | 'A2ARESTFastAPIApplication', 15 | 'A2AStarletteApplication', 16 | 'CallContextBuilder', 17 | 'JSONRPCApplication', 18 | ] 19 | -------------------------------------------------------------------------------- /src/a2a/server/apps/jsonrpc/__init__.py: -------------------------------------------------------------------------------- 1 | """A2A JSON-RPC Applications.""" 2 | 3 | from a2a.server.apps.jsonrpc.fastapi_app import A2AFastAPIApplication 4 | from a2a.server.apps.jsonrpc.jsonrpc_app import ( 5 | CallContextBuilder, 6 | DefaultCallContextBuilder, 7 | JSONRPCApplication, 8 | StarletteUserProxy, 9 | ) 10 | from a2a.server.apps.jsonrpc.starlette_app import A2AStarletteApplication 11 | 12 | 13 | __all__ = [ 14 | 'A2AFastAPIApplication', 15 | 'A2AStarletteApplication', 16 | 'CallContextBuilder', 17 | 'DefaultCallContextBuilder', 18 | 'JSONRPCApplication', 
19 | 'StarletteUserProxy', 20 | ] 21 | -------------------------------------------------------------------------------- /src/a2a/server/apps/rest/__init__.py: -------------------------------------------------------------------------------- 1 | """A2A REST Applications.""" 2 | 3 | from a2a.server.apps.rest.fastapi_app import A2ARESTFastAPIApplication 4 | 5 | 6 | __all__ = [ 7 | 'A2ARESTFastAPIApplication', 8 | ] 9 | -------------------------------------------------------------------------------- /src/a2a/server/apps/rest/fastapi_app.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from collections.abc import Callable 4 | from typing import TYPE_CHECKING, Any 5 | 6 | 7 | if TYPE_CHECKING: 8 | from fastapi import APIRouter, FastAPI, Request, Response 9 | from fastapi.responses import JSONResponse 10 | 11 | _package_fastapi_installed = True 12 | else: 13 | try: 14 | from fastapi import APIRouter, FastAPI, Request, Response 15 | from fastapi.responses import JSONResponse 16 | 17 | _package_fastapi_installed = True 18 | except ImportError: 19 | APIRouter = Any 20 | FastAPI = Any 21 | Request = Any 22 | Response = Any 23 | 24 | _package_fastapi_installed = False 25 | 26 | 27 | from a2a.server.apps.jsonrpc.jsonrpc_app import CallContextBuilder 28 | from a2a.server.apps.rest.rest_adapter import RESTAdapter 29 | from a2a.server.context import ServerCallContext 30 | from a2a.server.request_handlers.request_handler import RequestHandler 31 | from a2a.types import AgentCard 32 | from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH 33 | 34 | 35 | logger = logging.getLogger(__name__) 36 | 37 | 38 | class A2ARESTFastAPIApplication: 39 | """A FastAPI application implementing the A2A protocol server REST endpoints. 40 | 41 | Handles incoming REST requests, routes them to the appropriate 42 | handler methods, and manages response generation including Server-Sent Events 43 | (SSE). 
44 | """ 45 | 46 | def __init__( # noqa: PLR0913 47 | self, 48 | agent_card: AgentCard, 49 | http_handler: RequestHandler, 50 | extended_agent_card: AgentCard | None = None, 51 | context_builder: CallContextBuilder | None = None, 52 | card_modifier: Callable[[AgentCard], AgentCard] | None = None, 53 | extended_card_modifier: Callable[ 54 | [AgentCard, ServerCallContext], AgentCard 55 | ] 56 | | None = None, 57 | ): 58 | """Initializes the A2ARESTFastAPIApplication. 59 | 60 | Args: 61 | agent_card: The AgentCard describing the agent's capabilities. 62 | http_handler: The handler instance responsible for processing A2A 63 | requests via http. 64 | extended_agent_card: An optional, distinct AgentCard to be served 65 | at the authenticated extended card endpoint. 66 | context_builder: The CallContextBuilder used to construct the 67 | ServerCallContext passed to the http_handler. If None, no 68 | ServerCallContext is passed. 69 | card_modifier: An optional callback to dynamically modify the public 70 | agent card before it is served. 71 | extended_card_modifier: An optional callback to dynamically modify 72 | the extended agent card before it is served. It receives the 73 | call context. 74 | """ 75 | if not _package_fastapi_installed: 76 | raise ImportError( 77 | 'The `fastapi` package is required to use the' 78 | ' `A2ARESTFastAPIApplication`. It can be added as a part of' 79 | ' `a2a-sdk` optional dependencies, `a2a-sdk[http-server]`.' 80 | ) 81 | self._adapter = RESTAdapter( 82 | agent_card=agent_card, 83 | http_handler=http_handler, 84 | extended_agent_card=extended_agent_card, 85 | context_builder=context_builder, 86 | card_modifier=card_modifier, 87 | extended_card_modifier=extended_card_modifier, 88 | ) 89 | 90 | def build( 91 | self, 92 | agent_card_url: str = AGENT_CARD_WELL_KNOWN_PATH, 93 | rpc_url: str = '', 94 | **kwargs: Any, 95 | ) -> FastAPI: 96 | """Builds and returns the FastAPI application instance. 
97 | 98 | Args: 99 | agent_card_url: The URL for the agent card endpoint. 100 | rpc_url: The URL for the A2A JSON-RPC endpoint. 101 | extended_agent_card_url: The URL for the authenticated extended agent card endpoint. 102 | **kwargs: Additional keyword arguments to pass to the FastAPI constructor. 103 | 104 | Returns: 105 | A configured FastAPI application instance. 106 | """ 107 | app = FastAPI(**kwargs) 108 | router = APIRouter() 109 | for route, callback in self._adapter.routes().items(): 110 | router.add_api_route( 111 | f'{rpc_url}{route[0]}', callback, methods=[route[1]] 112 | ) 113 | 114 | @router.get(f'{rpc_url}{agent_card_url}') 115 | async def get_agent_card(request: Request) -> Response: 116 | card = await self._adapter.handle_get_agent_card(request) 117 | return JSONResponse(card) 118 | 119 | app.include_router(router) 120 | return app 121 | -------------------------------------------------------------------------------- /src/a2a/server/context.py: -------------------------------------------------------------------------------- 1 | """Defines the ServerCallContext class.""" 2 | 3 | import collections.abc 4 | import typing 5 | 6 | from pydantic import BaseModel, ConfigDict, Field 7 | 8 | from a2a.auth.user import UnauthenticatedUser, User 9 | 10 | 11 | State = collections.abc.MutableMapping[str, typing.Any] 12 | 13 | 14 | class ServerCallContext(BaseModel): 15 | """A context passed when calling a server method. 16 | 17 | This class allows storing arbitrary user data in the state attribute. 
18 | """ 19 | 20 | model_config = ConfigDict(arbitrary_types_allowed=True) 21 | 22 | state: State = Field(default={}) 23 | user: User = Field(default=UnauthenticatedUser()) 24 | requested_extensions: set[str] = Field(default_factory=set) 25 | activated_extensions: set[str] = Field(default_factory=set) 26 | -------------------------------------------------------------------------------- /src/a2a/server/events/__init__.py: -------------------------------------------------------------------------------- 1 | """Event handling components for the A2A server.""" 2 | 3 | from a2a.server.events.event_consumer import EventConsumer 4 | from a2a.server.events.event_queue import Event, EventQueue 5 | from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager 6 | from a2a.server.events.queue_manager import ( 7 | NoTaskQueue, 8 | QueueManager, 9 | TaskQueueExists, 10 | ) 11 | 12 | 13 | __all__ = [ 14 | 'Event', 15 | 'EventConsumer', 16 | 'EventQueue', 17 | 'InMemoryQueueManager', 18 | 'NoTaskQueue', 19 | 'QueueManager', 20 | 'TaskQueueExists', 21 | ] 22 | -------------------------------------------------------------------------------- /src/a2a/server/events/event_consumer.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | import sys 4 | 5 | from collections.abc import AsyncGenerator 6 | 7 | from pydantic import ValidationError 8 | 9 | from a2a.server.events.event_queue import Event, EventQueue 10 | from a2a.types import ( 11 | InternalError, 12 | Message, 13 | Task, 14 | TaskState, 15 | TaskStatusUpdateEvent, 16 | ) 17 | from a2a.utils.errors import ServerError 18 | from a2a.utils.telemetry import SpanKind, trace_class 19 | 20 | 21 | # This is an alias to the exception for closed queue 22 | QueueClosed: type[Exception] = asyncio.QueueEmpty 23 | 24 | # When using python 3.13 or higher, the closed queue signal is QueueShutdown 25 | if sys.version_info >= (3, 13): 26 | QueueClosed = 
asyncio.QueueShutDown 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | 31 | @trace_class(kind=SpanKind.SERVER) 32 | class EventConsumer: 33 | """Consumer to read events from the agent event queue.""" 34 | 35 | def __init__(self, queue: EventQueue): 36 | """Initializes the EventConsumer. 37 | 38 | Args: 39 | queue: The `EventQueue` instance to consume events from. 40 | """ 41 | self.queue = queue 42 | self._timeout = 0.5 43 | self._exception: BaseException | None = None 44 | logger.debug('EventConsumer initialized') 45 | 46 | async def consume_one(self) -> Event: 47 | """Consume one event from the agent event queue non-blocking. 48 | 49 | Returns: 50 | The next event from the queue. 51 | 52 | Raises: 53 | ServerError: If the queue is empty when attempting to dequeue 54 | immediately. 55 | """ 56 | logger.debug('Attempting to consume one event.') 57 | try: 58 | event = await self.queue.dequeue_event(no_wait=True) 59 | except asyncio.QueueEmpty as e: 60 | logger.warning('Event queue was empty in consume_one.') 61 | raise ServerError( 62 | InternalError(message='Agent did not return any response') 63 | ) from e 64 | 65 | logger.debug('Dequeued event of type: %s in consume_one.', type(event)) 66 | 67 | self.queue.task_done() 68 | 69 | return event 70 | 71 | async def consume_all(self) -> AsyncGenerator[Event]: 72 | """Consume all the generated streaming events from the agent. 73 | 74 | This method yields events as they become available from the queue 75 | until a final event is received or the queue is closed. It also 76 | monitors for exceptions set by the `agent_task_callback`. 77 | 78 | Yields: 79 | Events dequeued from the queue. 80 | 81 | Raises: 82 | BaseException: If an exception was set by the `agent_task_callback`. 83 | """ 84 | logger.debug('Starting to consume all events from the queue.') 85 | while True: 86 | if self._exception: 87 | raise self._exception 88 | try: 89 | # We use a timeout when waiting for an event from the queue. 
90 | # This is required because it allows the loop to check if 91 | # `self._exception` has been set by the `agent_task_callback`. 92 | # Without the timeout, loop might hang indefinitely if no events are 93 | # enqueued by the agent and the agent simply threw an exception 94 | event = await asyncio.wait_for( 95 | self.queue.dequeue_event(), timeout=self._timeout 96 | ) 97 | logger.debug( 98 | 'Dequeued event of type: %s in consume_all.', type(event) 99 | ) 100 | self.queue.task_done() 101 | logger.debug( 102 | 'Marked task as done in event queue in consume_all' 103 | ) 104 | 105 | is_final_event = ( 106 | (isinstance(event, TaskStatusUpdateEvent) and event.final) 107 | or isinstance(event, Message) 108 | or ( 109 | isinstance(event, Task) 110 | and event.status.state 111 | in ( 112 | TaskState.completed, 113 | TaskState.canceled, 114 | TaskState.failed, 115 | TaskState.rejected, 116 | TaskState.unknown, 117 | TaskState.input_required, 118 | ) 119 | ) 120 | ) 121 | 122 | # Make sure the yield is after the close events, otherwise 123 | # the caller may end up in a blocked state where this 124 | # generator isn't called again to close things out and the 125 | # other part is waiting for an event or a closed queue. 126 | if is_final_event: 127 | logger.debug('Stopping event consumption in consume_all.') 128 | await self.queue.close(True) 129 | yield event 130 | break 131 | yield event 132 | except TimeoutError: 133 | # continue polling until there is a final event 134 | continue 135 | except asyncio.TimeoutError: # pyright: ignore [reportUnusedExcept] 136 | # This class was made an alias of built-in TimeoutError after 3.11 137 | continue 138 | except (QueueClosed, asyncio.QueueEmpty): 139 | # Confirm that the queue is closed, e.g. 
we aren't on 140 | # python 3.12 and get a queue empty error on an open queue 141 | if self.queue.is_closed(): 142 | break 143 | except ValidationError: 144 | logger.exception('Invalid event format received') 145 | continue 146 | except Exception as e: 147 | logger.exception('Stopping event consumption due to exception') 148 | self._exception = e 149 | continue 150 | 151 | def agent_task_callback(self, agent_task: asyncio.Task[None]) -> None: 152 | """Callback to handle exceptions from the agent's execution task. 153 | 154 | If the agent's asyncio task raises an exception, this callback is 155 | invoked, and the exception is stored to be re-raised by the consumer loop. 156 | 157 | Args: 158 | agent_task: The asyncio.Task that completed. 159 | """ 160 | logger.debug('Agent task callback triggered.') 161 | if not agent_task.cancelled() and agent_task.done(): 162 | self._exception = agent_task.exception() 163 | -------------------------------------------------------------------------------- /src/a2a/server/events/in_memory_queue_manager.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | from a2a.server.events.event_queue import EventQueue 4 | from a2a.server.events.queue_manager import ( 5 | NoTaskQueue, 6 | QueueManager, 7 | TaskQueueExists, 8 | ) 9 | from a2a.utils.telemetry import SpanKind, trace_class 10 | 11 | 12 | @trace_class(kind=SpanKind.SERVER) 13 | class InMemoryQueueManager(QueueManager): 14 | """InMemoryQueueManager is used for a single binary management. 15 | 16 | This implements the `QueueManager` interface using in-memory storage for event 17 | queues. It requires all incoming interactions for a given task ID to hit the 18 | same binary instance. 19 | 20 | This implementation is suitable for single-instance deployments but needs 21 | a distributed approach for scalable deployments. 
22 | """ 23 | 24 | def __init__(self) -> None: 25 | """Initializes the InMemoryQueueManager.""" 26 | self._task_queue: dict[str, EventQueue] = {} 27 | self._lock = asyncio.Lock() 28 | 29 | async def add(self, task_id: str, queue: EventQueue) -> None: 30 | """Adds a new event queue for a task ID. 31 | 32 | Raises: 33 | TaskQueueExists: If a queue for the given `task_id` already exists. 34 | """ 35 | async with self._lock: 36 | if task_id in self._task_queue: 37 | raise TaskQueueExists 38 | self._task_queue[task_id] = queue 39 | 40 | async def get(self, task_id: str) -> EventQueue | None: 41 | """Retrieves the event queue for a task ID. 42 | 43 | Returns: 44 | The `EventQueue` instance for the `task_id`, or `None` if not found. 45 | """ 46 | async with self._lock: 47 | if task_id not in self._task_queue: 48 | return None 49 | return self._task_queue[task_id] 50 | 51 | async def tap(self, task_id: str) -> EventQueue | None: 52 | """Taps the event queue for a task ID to create a child queue. 53 | 54 | Returns: 55 | A new child `EventQueue` instance, or `None` if the task ID is not found. 56 | """ 57 | async with self._lock: 58 | if task_id not in self._task_queue: 59 | return None 60 | return self._task_queue[task_id].tap() 61 | 62 | async def close(self, task_id: str) -> None: 63 | """Closes and removes the event queue for a task ID. 64 | 65 | Raises: 66 | NoTaskQueue: If no queue exists for the given `task_id`. 67 | """ 68 | async with self._lock: 69 | if task_id not in self._task_queue: 70 | raise NoTaskQueue 71 | queue = self._task_queue.pop(task_id) 72 | await queue.close() 73 | 74 | async def create_or_tap(self, task_id: str) -> EventQueue: 75 | """Creates a new event queue for a task ID if one doesn't exist, otherwise taps the existing one. 76 | 77 | Returns: 78 | A new or child `EventQueue` instance for the `task_id`. 
79 | """ 80 | async with self._lock: 81 | if task_id not in self._task_queue: 82 | queue = EventQueue() 83 | self._task_queue[task_id] = queue 84 | return queue 85 | return self._task_queue[task_id].tap() 86 | -------------------------------------------------------------------------------- /src/a2a/server/events/queue_manager.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | from a2a.server.events.event_queue import EventQueue 4 | 5 | 6 | class QueueManager(ABC): 7 | """Interface for managing the event queue lifecycles per task.""" 8 | 9 | @abstractmethod 10 | async def add(self, task_id: str, queue: EventQueue) -> None: 11 | """Adds a new event queue associated with a task ID.""" 12 | 13 | @abstractmethod 14 | async def get(self, task_id: str) -> EventQueue | None: 15 | """Retrieves the event queue for a task ID.""" 16 | 17 | @abstractmethod 18 | async def tap(self, task_id: str) -> EventQueue | None: 19 | """Creates a child event queue (tap) for an existing task ID.""" 20 | 21 | @abstractmethod 22 | async def close(self, task_id: str) -> None: 23 | """Closes and removes the event queue for a task ID.""" 24 | 25 | @abstractmethod 26 | async def create_or_tap(self, task_id: str) -> EventQueue: 27 | """Creates a queue if one doesn't exist, otherwise taps the existing one.""" 28 | 29 | 30 | class TaskQueueExists(Exception): # noqa: N818 31 | """Exception raised when attempting to add a queue for a task ID that already exists.""" 32 | 33 | 34 | class NoTaskQueue(Exception): # noqa: N818 35 | """Exception raised when attempting to access or close a queue for a task ID that does not exist.""" 36 | -------------------------------------------------------------------------------- /src/a2a/server/request_handlers/__init__.py: -------------------------------------------------------------------------------- 1 | """Request handler components for the A2A server.""" 2 | 3 | import logging 4 | 5 | 
from a2a.server.request_handlers.default_request_handler import ( 6 | DefaultRequestHandler, 7 | ) 8 | from a2a.server.request_handlers.jsonrpc_handler import JSONRPCHandler 9 | from a2a.server.request_handlers.request_handler import RequestHandler 10 | from a2a.server.request_handlers.response_helpers import ( 11 | build_error_response, 12 | prepare_response_object, 13 | ) 14 | from a2a.server.request_handlers.rest_handler import RESTHandler 15 | 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | try: 20 | from a2a.server.request_handlers.grpc_handler import ( 21 | GrpcHandler, # type: ignore 22 | ) 23 | except ImportError as e: 24 | _original_error = e 25 | logger.debug( 26 | 'GrpcHandler not loaded. This is expected if gRPC dependencies are not installed. Error: %s', 27 | _original_error, 28 | ) 29 | 30 | class GrpcHandler: # type: ignore 31 | """Placeholder for GrpcHandler when dependencies are not installed.""" 32 | 33 | def __init__(self, *args, **kwargs): 34 | raise ImportError( 35 | 'To use GrpcHandler, its dependencies must be installed. 
' 36 | 'You can install them with \'pip install "a2a-sdk[grpc]"\'' 37 | ) from _original_error 38 | 39 | 40 | __all__ = [ 41 | 'DefaultRequestHandler', 42 | 'GrpcHandler', 43 | 'JSONRPCHandler', 44 | 'RESTHandler', 45 | 'RequestHandler', 46 | 'build_error_response', 47 | 'prepare_response_object', 48 | ] 49 | -------------------------------------------------------------------------------- /src/a2a/server/request_handlers/response_helpers.py: -------------------------------------------------------------------------------- 1 | """Helper functions for building A2A JSON-RPC responses.""" 2 | 3 | # response types 4 | from typing import TypeVar 5 | 6 | from a2a.types import ( 7 | A2AError, 8 | CancelTaskResponse, 9 | CancelTaskSuccessResponse, 10 | DeleteTaskPushNotificationConfigResponse, 11 | DeleteTaskPushNotificationConfigSuccessResponse, 12 | GetTaskPushNotificationConfigResponse, 13 | GetTaskPushNotificationConfigSuccessResponse, 14 | GetTaskResponse, 15 | GetTaskSuccessResponse, 16 | InvalidAgentResponseError, 17 | JSONRPCError, 18 | JSONRPCErrorResponse, 19 | ListTaskPushNotificationConfigResponse, 20 | ListTaskPushNotificationConfigSuccessResponse, 21 | Message, 22 | SendMessageResponse, 23 | SendMessageSuccessResponse, 24 | SendStreamingMessageResponse, 25 | SendStreamingMessageSuccessResponse, 26 | SetTaskPushNotificationConfigResponse, 27 | SetTaskPushNotificationConfigSuccessResponse, 28 | Task, 29 | TaskArtifactUpdateEvent, 30 | TaskPushNotificationConfig, 31 | TaskStatusUpdateEvent, 32 | ) 33 | 34 | 35 | RT = TypeVar( 36 | 'RT', 37 | GetTaskResponse, 38 | CancelTaskResponse, 39 | SendMessageResponse, 40 | SetTaskPushNotificationConfigResponse, 41 | GetTaskPushNotificationConfigResponse, 42 | SendStreamingMessageResponse, 43 | ListTaskPushNotificationConfigResponse, 44 | DeleteTaskPushNotificationConfigResponse, 45 | ) 46 | """Type variable for RootModel response types.""" 47 | 48 | # success types 49 | SPT = TypeVar( 50 | 'SPT', 51 | 
GetTaskSuccessResponse, 52 | CancelTaskSuccessResponse, 53 | SendMessageSuccessResponse, 54 | SetTaskPushNotificationConfigSuccessResponse, 55 | GetTaskPushNotificationConfigSuccessResponse, 56 | SendStreamingMessageSuccessResponse, 57 | ListTaskPushNotificationConfigSuccessResponse, 58 | DeleteTaskPushNotificationConfigSuccessResponse, 59 | ) 60 | """Type variable for SuccessResponse types.""" 61 | 62 | # result types 63 | EventTypes = ( 64 | Task 65 | | Message 66 | | TaskArtifactUpdateEvent 67 | | TaskStatusUpdateEvent 68 | | TaskPushNotificationConfig 69 | | A2AError 70 | | JSONRPCError 71 | | list[TaskPushNotificationConfig] 72 | ) 73 | """Type alias for possible event types produced by handlers.""" 74 | 75 | 76 | def build_error_response( 77 | request_id: str | int | None, 78 | error: A2AError | JSONRPCError, 79 | response_wrapper_type: type[RT], 80 | ) -> RT: 81 | """Helper method to build a JSONRPCErrorResponse wrapped in the appropriate response type. 82 | 83 | Args: 84 | request_id: The ID of the request that caused the error. 85 | error: The A2AError or JSONRPCError object. 86 | response_wrapper_type: The Pydantic RootModel type that wraps the response 87 | for the specific RPC method (e.g., `SendMessageResponse`). 88 | 89 | Returns: 90 | A Pydantic model representing the JSON-RPC error response, 91 | wrapped in the specified response type. 92 | """ 93 | return response_wrapper_type( 94 | JSONRPCErrorResponse( 95 | id=request_id, 96 | error=error.root if isinstance(error, A2AError) else error, 97 | ) 98 | ) 99 | 100 | 101 | def prepare_response_object( 102 | request_id: str | int | None, 103 | response: EventTypes, 104 | success_response_types: tuple[type, ...], 105 | success_payload_type: type[SPT], 106 | response_type: type[RT], 107 | ) -> RT: 108 | """Helper method to build appropriate JSONRPCResponse object for RPC methods. 
109 | 110 | Based on the type of the `response` object received from the handler, 111 | it constructs either a success response wrapped in the appropriate payload type 112 | or an error response. 113 | 114 | Args: 115 | request_id: The ID of the request. 116 | response: The object received from the request handler. 117 | success_response_types: A tuple of expected Pydantic model types for a successful result. 118 | success_payload_type: The Pydantic model type for the success payload 119 | (e.g., `SendMessageSuccessResponse`). 120 | response_type: The Pydantic RootModel type that wraps the final response 121 | (e.g., `SendMessageResponse`). 122 | 123 | Returns: 124 | A Pydantic model representing the final JSON-RPC response (success or error). 125 | """ 126 | if isinstance(response, success_response_types): 127 | return response_type( 128 | root=success_payload_type(id=request_id, result=response) # type:ignore 129 | ) 130 | 131 | if isinstance(response, A2AError | JSONRPCError): 132 | return build_error_response(request_id, response, response_type) 133 | 134 | # If consumer_data is not an expected success type and not an error, 135 | # it's an invalid type of response from the agent for this specific method. 
136 | response = A2AError( 137 | root=InvalidAgentResponseError( 138 | message='Agent returned invalid type response for this method' 139 | ) 140 | ) 141 | 142 | return build_error_response(request_id, response, response_type) 143 | -------------------------------------------------------------------------------- /src/a2a/server/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | """Components for managing tasks within the A2A server.""" 2 | 3 | import logging 4 | 5 | from a2a.server.tasks.base_push_notification_sender import ( 6 | BasePushNotificationSender, 7 | ) 8 | from a2a.server.tasks.inmemory_push_notification_config_store import ( 9 | InMemoryPushNotificationConfigStore, 10 | ) 11 | from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore 12 | from a2a.server.tasks.push_notification_config_store import ( 13 | PushNotificationConfigStore, 14 | ) 15 | from a2a.server.tasks.push_notification_sender import PushNotificationSender 16 | from a2a.server.tasks.result_aggregator import ResultAggregator 17 | from a2a.server.tasks.task_manager import TaskManager 18 | from a2a.server.tasks.task_store import TaskStore 19 | from a2a.server.tasks.task_updater import TaskUpdater 20 | 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | try: 25 | from a2a.server.tasks.database_task_store import ( 26 | DatabaseTaskStore, # type: ignore 27 | ) 28 | except ImportError as e: 29 | _original_error = e 30 | # If the database task store is not available, we can still use in-memory stores. 31 | logger.debug( 32 | 'DatabaseTaskStore not loaded. This is expected if database dependencies are not installed. Error: %s', 33 | e, 34 | ) 35 | 36 | class DatabaseTaskStore: # type: ignore 37 | """Placeholder for DatabaseTaskStore when dependencies are not installed.""" 38 | 39 | def __init__(self, *args, **kwargs): 40 | raise ImportError( 41 | 'To use DatabaseTaskStore, its dependencies must be installed. 
' 42 | 'You can install them with \'pip install "a2a-sdk[sql]"\'' 43 | ) from _original_error 44 | 45 | 46 | try: 47 | from a2a.server.tasks.database_push_notification_config_store import ( 48 | DatabasePushNotificationConfigStore, # type: ignore 49 | ) 50 | except ImportError as e: 51 | _original_error = e 52 | # If the database push notification config store is not available, we can still use in-memory stores. 53 | logger.debug( 54 | 'DatabasePushNotificationConfigStore not loaded. This is expected if database dependencies are not installed. Error: %s', 55 | e, 56 | ) 57 | 58 | class DatabasePushNotificationConfigStore: # type: ignore 59 | """Placeholder for DatabasePushNotificationConfigStore when dependencies are not installed.""" 60 | 61 | def __init__(self, *args, **kwargs): 62 | raise ImportError( 63 | 'To use DatabasePushNotificationConfigStore, its dependencies must be installed. ' 64 | 'You can install them with \'pip install "a2a-sdk[sql]"\'' 65 | ) from _original_error 66 | 67 | 68 | __all__ = [ 69 | 'BasePushNotificationSender', 70 | 'DatabasePushNotificationConfigStore', 71 | 'DatabaseTaskStore', 72 | 'InMemoryPushNotificationConfigStore', 73 | 'InMemoryTaskStore', 74 | 'PushNotificationConfigStore', 75 | 'PushNotificationSender', 76 | 'ResultAggregator', 77 | 'TaskManager', 78 | 'TaskStore', 79 | 'TaskUpdater', 80 | ] 81 | -------------------------------------------------------------------------------- /src/a2a/server/tasks/base_push_notification_sender.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | 4 | import httpx 5 | 6 | from a2a.server.tasks.push_notification_config_store import ( 7 | PushNotificationConfigStore, 8 | ) 9 | from a2a.server.tasks.push_notification_sender import PushNotificationSender 10 | from a2a.types import PushNotificationConfig, Task 11 | 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | class 
BasePushNotificationSender(PushNotificationSender): 17 | """Base implementation of PushNotificationSender interface.""" 18 | 19 | def __init__( 20 | self, 21 | httpx_client: httpx.AsyncClient, 22 | config_store: PushNotificationConfigStore, 23 | ) -> None: 24 | """Initializes the BasePushNotificationSender. 25 | 26 | Args: 27 | httpx_client: An async HTTP client instance to send notifications. 28 | config_store: A PushNotificationConfigStore instance to retrieve configurations. 29 | """ 30 | self._client = httpx_client 31 | self._config_store = config_store 32 | 33 | async def send_notification(self, task: Task) -> None: 34 | """Sends a push notification for a task if configuration exists.""" 35 | push_configs = await self._config_store.get_info(task.id) 36 | if not push_configs: 37 | return 38 | 39 | awaitables = [ 40 | self._dispatch_notification(task, push_info) 41 | for push_info in push_configs 42 | ] 43 | results = await asyncio.gather(*awaitables) 44 | 45 | if not all(results): 46 | logger.warning( 47 | 'Some push notifications failed to send for task_id=%s', task.id 48 | ) 49 | 50 | async def _dispatch_notification( 51 | self, task: Task, push_info: PushNotificationConfig 52 | ) -> bool: 53 | url = push_info.url 54 | try: 55 | headers = None 56 | if push_info.token: 57 | headers = {'X-A2A-Notification-Token': push_info.token} 58 | response = await self._client.post( 59 | url, 60 | json=task.model_dump(mode='json', exclude_none=True), 61 | headers=headers, 62 | ) 63 | response.raise_for_status() 64 | logger.info( 65 | 'Push-notification sent for task_id=%s to URL: %s', task.id, url 66 | ) 67 | except Exception: 68 | logger.exception( 69 | 'Error sending push-notification for task_id=%s to URL: %s.', 70 | task.id, 71 | url, 72 | ) 73 | return False 74 | return True 75 | -------------------------------------------------------------------------------- /src/a2a/server/tasks/database_task_store.py: 
# --------------------------------------------------------------------------------
import logging


try:
    from sqlalchemy import Table, delete, select
    from sqlalchemy.ext.asyncio import (
        AsyncEngine,
        AsyncSession,
        async_sessionmaker,
    )
    from sqlalchemy.orm import class_mapper
except ImportError as missing_dependency:
    raise ImportError(
        'DatabaseTaskStore requires SQLAlchemy and a database driver. '
        'Install with one of: '
        "'pip install a2a-sdk[postgresql]', "
        "'pip install a2a-sdk[mysql]', "
        "'pip install a2a-sdk[sqlite]', "
        "or 'pip install a2a-sdk[sql]'"
    ) from missing_dependency

from a2a.server.context import ServerCallContext
from a2a.server.models import Base, TaskModel, create_task_model
from a2a.server.tasks.task_store import TaskStore
from a2a.types import Task  # the Pydantic model persisted by this store


logger = logging.getLogger(__name__)


class DatabaseTaskStore(TaskStore):
    """TaskStore backed by a relational database through SQLAlchemy's async API.

    Each task is persisted as one row of the mapped task model. Schema setup
    runs on first use (via `_ensure_initialized`) unless `initialize` was
    already awaited.
    """

    engine: AsyncEngine
    async_session_maker: async_sessionmaker[AsyncSession]
    create_table: bool
    _initialized: bool
    task_model: type[TaskModel]

    def __init__(
        self,
        engine: AsyncEngine,
        create_table: bool = True,
        table_name: str = 'tasks',
    ) -> None:
        """Set up the store around an already-created async engine.

        Args:
            engine: The SQLAlchemy AsyncEngine all queries are issued through.
            create_table: Whether schema setup should create the table.
            table_name: Name of the backing table. 'tasks' uses the stock
                `TaskModel`; any other name gets a model from
                `create_task_model`.
        """
        logger.debug(
            'Initializing DatabaseTaskStore with existing engine, table: %s',
            table_name,
        )
        self.engine = engine
        self.async_session_maker = async_sessionmaker(
            self.engine, expire_on_commit=False
        )
        self.create_table = create_table
        self._initialized = False

        if table_name == 'tasks':
            self.task_model = TaskModel
        else:
            self.task_model = create_task_model(table_name)

    async def initialize(self) -> None:
        """Prepare the schema once; subsequent calls return immediately."""
        if self._initialized:
            return

        logger.debug('Initializing database schema...')
        if self.create_table:
            await self._create_mapped_tables()
        self._initialized = True
        logger.debug('Database schema initialized.')

    async def _create_mapped_tables(self) -> None:
        """Create only the tables mapped by `self.task_model`."""
        async with self.engine.begin() as conn:
            mapped = class_mapper(self.task_model).tables
            await conn.run_sync(
                Base.metadata.create_all,
                tables=[tbl for tbl in mapped if isinstance(tbl, Table)],
            )

    async def _ensure_initialized(self) -> None:
        """Run `initialize` unless it has already completed."""
        if self._initialized:
            return
        await self.initialize()

    def _to_orm(self, task: Task) -> TaskModel:
        """Build an ORM row object from a Pydantic Task."""
        # The ORM column is named `task_metadata`; the Pydantic field is `metadata`.
        columns = {
            'id': task.id,
            'context_id': task.context_id,
            'kind': task.kind,
            'status': task.status,
            'artifacts': task.artifacts,
            'history': task.history,
            'task_metadata': task.metadata,
        }
        return self.task_model(**columns)

    def _from_orm(self, task_model: TaskModel) -> Task:
        """Rebuild a Pydantic Task from an ORM row object."""
        # model_validate parses the nested JSON-backed values back into models.
        return Task.model_validate(
            {
                'id': task_model.id,
                'context_id': task_model.context_id,
                'kind': task_model.kind,
                'status': task_model.status,
                'artifacts': task_model.artifacts,
                'history': task_model.history,
                'metadata': task_model.task_metadata,
            }
        )

    async def save(
        self, task: Task, context: ServerCallContext | None = None
    ) -> None:
        """Insert the task, or overwrite the existing row with the same id."""
        await self._ensure_initialized()
        orm_task = self._to_orm(task)
        async with self.async_session_maker.begin() as session:
            await session.merge(orm_task)
            logger.debug('Task %s saved/updated successfully.', task.id)

    async def get(
        self, task_id: str, context: ServerCallContext | None = None
    ) -> Task | None:
        """Look a task up by id; return None when no row matches."""
        await self._ensure_initialized()
        async with self.async_session_maker() as session:
            query = select(self.task_model).where(
                self.task_model.id == task_id
            )
            row = (await session.execute(query)).scalar_one_or_none()
            if not row:
                logger.debug('Task %s not found in store.', task_id)
                return None

            found = self._from_orm(row)
            logger.debug('Task %s retrieved successfully.', task_id)
            return found

    async def delete(
        self, task_id: str, context: ServerCallContext | None = None
    ) -> None:
        """Remove the task row with the given id, warning if none existed."""
        await self._ensure_initialized()

        # session.begin() commits automatically when the block exits cleanly.
        async with self.async_session_maker.begin() as session:
            outcome = await session.execute(
                delete(self.task_model).where(self.task_model.id == task_id)
            )

            if outcome.rowcount <= 0:
                logger.warning(
                    'Attempted to delete nonexistent task with id: %s', task_id
                )
            else:
                logger.info('Task %s deleted successfully.', task_id)
-------------------------------------------------------------------------------- /src/a2a/server/tasks/inmemory_push_notification_config_store.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | 4 | from a2a.server.tasks.push_notification_config_store import ( 5 | PushNotificationConfigStore, 6 | ) 7 | from a2a.types import PushNotificationConfig 8 | 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | class InMemoryPushNotificationConfigStore(PushNotificationConfigStore): 14 | """In-memory implementation of PushNotificationConfigStore interface. 15 | 16 | Stores push notification configurations in memory 17 | """ 18 | 19 | def __init__(self) -> None: 20 | """Initializes the InMemoryPushNotificationConfigStore.""" 21 | self.lock = asyncio.Lock() 22 | self._push_notification_infos: dict[ 23 | str, list[PushNotificationConfig] 24 | ] = {} 25 | 26 | async def set_info( 27 | self, task_id: str, notification_config: PushNotificationConfig 28 | ) -> None: 29 | """Sets or updates the push notification configuration for a task in memory.""" 30 | async with self.lock: 31 | if task_id not in self._push_notification_infos: 32 | self._push_notification_infos[task_id] = [] 33 | 34 | if notification_config.id is None: 35 | notification_config.id = task_id 36 | 37 | for config in self._push_notification_infos[task_id]: 38 | if config.id == notification_config.id: 39 | self._push_notification_infos[task_id].remove(config) 40 | break 41 | 42 | self._push_notification_infos[task_id].append(notification_config) 43 | 44 | async def get_info(self, task_id: str) -> list[PushNotificationConfig]: 45 | """Retrieves the push notification configuration for a task from memory.""" 46 | async with self.lock: 47 | return self._push_notification_infos.get(task_id) or [] 48 | 49 | async def delete_info( 50 | self, task_id: str, config_id: str | None = None 51 | ) -> None: 52 | """Deletes the push notification 
configuration for a task from memory.""" 53 | async with self.lock: 54 | if config_id is None: 55 | config_id = task_id 56 | 57 | if task_id in self._push_notification_infos: 58 | configurations = self._push_notification_infos[task_id] 59 | if not configurations: 60 | return 61 | 62 | for config in configurations: 63 | if config.id == config_id: 64 | configurations.remove(config) 65 | break 66 | 67 | if len(configurations) == 0: 68 | del self._push_notification_infos[task_id] 69 | -------------------------------------------------------------------------------- /src/a2a/server/tasks/inmemory_task_store.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | 4 | from a2a.server.context import ServerCallContext 5 | from a2a.server.tasks.task_store import TaskStore 6 | from a2a.types import Task 7 | 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | class InMemoryTaskStore(TaskStore): 13 | """In-memory implementation of TaskStore. 14 | 15 | Stores task objects in a dictionary in memory. Task data is lost when the 16 | server process stops. 
17 | """ 18 | 19 | def __init__(self) -> None: 20 | """Initializes the InMemoryTaskStore.""" 21 | logger.debug('Initializing InMemoryTaskStore') 22 | self.tasks: dict[str, Task] = {} 23 | self.lock = asyncio.Lock() 24 | 25 | async def save( 26 | self, task: Task, context: ServerCallContext | None = None 27 | ) -> None: 28 | """Saves or updates a task in the in-memory store.""" 29 | async with self.lock: 30 | self.tasks[task.id] = task 31 | logger.debug('Task %s saved successfully.', task.id) 32 | 33 | async def get( 34 | self, task_id: str, context: ServerCallContext | None = None 35 | ) -> Task | None: 36 | """Retrieves a task from the in-memory store by ID.""" 37 | async with self.lock: 38 | logger.debug('Attempting to get task with id: %s', task_id) 39 | task = self.tasks.get(task_id) 40 | if task: 41 | logger.debug('Task %s retrieved successfully.', task_id) 42 | else: 43 | logger.debug('Task %s not found in store.', task_id) 44 | return task 45 | 46 | async def delete( 47 | self, task_id: str, context: ServerCallContext | None = None 48 | ) -> None: 49 | """Deletes a task from the in-memory store by ID.""" 50 | async with self.lock: 51 | logger.debug('Attempting to delete task with id: %s', task_id) 52 | if task_id in self.tasks: 53 | del self.tasks[task_id] 54 | logger.debug('Task %s deleted successfully.', task_id) 55 | else: 56 | logger.warning( 57 | 'Attempted to delete nonexistent task with id: %s', task_id 58 | ) 59 | -------------------------------------------------------------------------------- /src/a2a/server/tasks/push_notification_config_store.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | from a2a.types import PushNotificationConfig 4 | 5 | 6 | class PushNotificationConfigStore(ABC): 7 | """Interface for storing and retrieving push notification configurations for tasks.""" 8 | 9 | @abstractmethod 10 | async def set_info( 11 | self, task_id: str, 
notification_config: PushNotificationConfig 12 | ) -> None: 13 | """Sets or updates the push notification configuration for a task.""" 14 | 15 | @abstractmethod 16 | async def get_info(self, task_id: str) -> list[PushNotificationConfig]: 17 | """Retrieves the push notification configuration for a task.""" 18 | 19 | @abstractmethod 20 | async def delete_info( 21 | self, task_id: str, config_id: str | None = None 22 | ) -> None: 23 | """Deletes the push notification configuration for a task.""" 24 | -------------------------------------------------------------------------------- /src/a2a/server/tasks/push_notification_sender.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | from a2a.types import Task 4 | 5 | 6 | class PushNotificationSender(ABC): 7 | """Interface for sending push notifications for tasks.""" 8 | 9 | @abstractmethod 10 | async def send_notification(self, task: Task) -> None: 11 | """Sends a push notification containing the latest task state.""" 12 | -------------------------------------------------------------------------------- /src/a2a/server/tasks/task_store.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | from a2a.server.context import ServerCallContext 4 | from a2a.types import Task 5 | 6 | 7 | class TaskStore(ABC): 8 | """Agent Task Store interface. 9 | 10 | Defines the methods for persisting and retrieving `Task` objects. 
11 | """ 12 | 13 | @abstractmethod 14 | async def save( 15 | self, task: Task, context: ServerCallContext | None = None 16 | ) -> None: 17 | """Saves or updates a task in the store.""" 18 | 19 | @abstractmethod 20 | async def get( 21 | self, task_id: str, context: ServerCallContext | None = None 22 | ) -> Task | None: 23 | """Retrieves a task from the store by ID.""" 24 | 25 | @abstractmethod 26 | async def delete( 27 | self, task_id: str, context: ServerCallContext | None = None 28 | ) -> None: 29 | """Deletes a task from the store by ID.""" 30 | -------------------------------------------------------------------------------- /src/a2a/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Utility functions for the A2A Python SDK.""" 2 | 3 | from a2a.utils.artifact import ( 4 | new_artifact, 5 | new_data_artifact, 6 | new_text_artifact, 7 | ) 8 | from a2a.utils.constants import ( 9 | AGENT_CARD_WELL_KNOWN_PATH, 10 | DEFAULT_RPC_URL, 11 | EXTENDED_AGENT_CARD_PATH, 12 | PREV_AGENT_CARD_WELL_KNOWN_PATH, 13 | ) 14 | from a2a.utils.helpers import ( 15 | append_artifact_to_task, 16 | are_modalities_compatible, 17 | build_text_artifact, 18 | create_task_obj, 19 | ) 20 | from a2a.utils.message import ( 21 | get_data_parts, 22 | get_file_parts, 23 | get_message_text, 24 | get_text_parts, 25 | new_agent_parts_message, 26 | new_agent_text_message, 27 | ) 28 | from a2a.utils.task import ( 29 | completed_task, 30 | new_task, 31 | ) 32 | 33 | 34 | __all__ = [ 35 | 'AGENT_CARD_WELL_KNOWN_PATH', 36 | 'DEFAULT_RPC_URL', 37 | 'EXTENDED_AGENT_CARD_PATH', 38 | 'PREV_AGENT_CARD_WELL_KNOWN_PATH', 39 | 'append_artifact_to_task', 40 | 'are_modalities_compatible', 41 | 'build_text_artifact', 42 | 'completed_task', 43 | 'create_task_obj', 44 | 'get_data_parts', 45 | 'get_file_parts', 46 | 'get_message_text', 47 | 'get_text_parts', 48 | 'new_agent_parts_message', 49 | 'new_agent_text_message', 50 | 'new_artifact', 51 | 'new_data_artifact', 
52 | 'new_task', 53 | 'new_text_artifact', 54 | ] 55 | -------------------------------------------------------------------------------- /src/a2a/utils/artifact.py: -------------------------------------------------------------------------------- 1 | """Utility functions for creating A2A Artifact objects.""" 2 | 3 | import uuid 4 | 5 | from typing import Any 6 | 7 | from a2a.types import Artifact, DataPart, Part, TextPart 8 | 9 | 10 | def new_artifact( 11 | parts: list[Part], name: str, description: str = '' 12 | ) -> Artifact: 13 | """Creates a new Artifact object. 14 | 15 | Args: 16 | parts: The list of `Part` objects forming the artifact's content. 17 | name: The human-readable name of the artifact. 18 | description: An optional description of the artifact. 19 | 20 | Returns: 21 | A new `Artifact` object with a generated artifact_id. 22 | """ 23 | return Artifact( 24 | artifact_id=str(uuid.uuid4()), 25 | parts=parts, 26 | name=name, 27 | description=description, 28 | ) 29 | 30 | 31 | def new_text_artifact( 32 | name: str, 33 | text: str, 34 | description: str = '', 35 | ) -> Artifact: 36 | """Creates a new Artifact object containing only a single TextPart. 37 | 38 | Args: 39 | name: The human-readable name of the artifact. 40 | text: The text content of the artifact. 41 | description: An optional description of the artifact. 42 | 43 | Returns: 44 | A new `Artifact` object with a generated artifact_id. 45 | """ 46 | return new_artifact( 47 | [Part(root=TextPart(text=text))], 48 | name, 49 | description, 50 | ) 51 | 52 | 53 | def new_data_artifact( 54 | name: str, 55 | data: dict[str, Any], 56 | description: str = '', 57 | ) -> Artifact: 58 | """Creates a new Artifact object containing only a single DataPart. 59 | 60 | Args: 61 | name: The human-readable name of the artifact. 62 | data: The structured data content of the artifact. 63 | description: An optional description of the artifact. 64 | 65 | Returns: 66 | A new `Artifact` object with a generated artifact_id. 
67 | """ 68 | return new_artifact( 69 | [Part(root=DataPart(data=data))], 70 | name, 71 | description, 72 | ) 73 | -------------------------------------------------------------------------------- /src/a2a/utils/constants.py: -------------------------------------------------------------------------------- 1 | """Constants for well-known URIs used throughout the A2A Python SDK.""" 2 | 3 | AGENT_CARD_WELL_KNOWN_PATH = '/.well-known/agent-card.json' 4 | PREV_AGENT_CARD_WELL_KNOWN_PATH = '/.well-known/agent.json' 5 | EXTENDED_AGENT_CARD_PATH = '/agent/authenticatedExtendedCard' 6 | DEFAULT_RPC_URL = '/' 7 | -------------------------------------------------------------------------------- /src/a2a/utils/error_handlers.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import logging 3 | 4 | from collections.abc import Awaitable, Callable, Coroutine 5 | from typing import TYPE_CHECKING, Any 6 | 7 | 8 | if TYPE_CHECKING: 9 | from starlette.responses import JSONResponse, Response 10 | else: 11 | try: 12 | from starlette.responses import JSONResponse, Response 13 | except ImportError: 14 | JSONResponse = Any 15 | Response = Any 16 | 17 | 18 | from a2a._base import A2ABaseModel 19 | from a2a.types import ( 20 | AuthenticatedExtendedCardNotConfiguredError, 21 | ContentTypeNotSupportedError, 22 | InternalError, 23 | InvalidAgentResponseError, 24 | InvalidParamsError, 25 | InvalidRequestError, 26 | JSONParseError, 27 | MethodNotFoundError, 28 | PushNotificationNotSupportedError, 29 | TaskNotCancelableError, 30 | TaskNotFoundError, 31 | UnsupportedOperationError, 32 | ) 33 | from a2a.utils.errors import ServerError 34 | 35 | 36 | logger = logging.getLogger(__name__) 37 | 38 | A2AErrorToHttpStatus: dict[type[A2ABaseModel], int] = { 39 | JSONParseError: 400, 40 | InvalidRequestError: 400, 41 | MethodNotFoundError: 404, 42 | InvalidParamsError: 422, 43 | InternalError: 500, 44 | TaskNotFoundError: 404, 45 | 
TaskNotCancelableError: 409, 46 | PushNotificationNotSupportedError: 501, 47 | UnsupportedOperationError: 501, 48 | ContentTypeNotSupportedError: 415, 49 | InvalidAgentResponseError: 502, 50 | AuthenticatedExtendedCardNotConfiguredError: 404, 51 | } 52 | 53 | 54 | def rest_error_handler( 55 | func: Callable[..., Awaitable[Response]], 56 | ) -> Callable[..., Awaitable[Response]]: 57 | """Decorator to catch ServerError and map it to an appropriate JSONResponse.""" 58 | 59 | @functools.wraps(func) 60 | async def wrapper(*args: Any, **kwargs: Any) -> Response: 61 | try: 62 | return await func(*args, **kwargs) 63 | except ServerError as e: 64 | error = e.error or InternalError( 65 | message='Internal error due to unknown reason' 66 | ) 67 | http_code = A2AErrorToHttpStatus.get(type(error), 500) 68 | 69 | log_level = ( 70 | logging.ERROR 71 | if isinstance(error, InternalError) 72 | else logging.WARNING 73 | ) 74 | logger.log( 75 | log_level, 76 | "Request error: Code=%s, Message='%s'%s", 77 | error.code, 78 | error.message, 79 | ', Data=' + str(error.data) if error.data else '', 80 | ) 81 | return JSONResponse( 82 | content={'message': error.message}, status_code=http_code 83 | ) 84 | except Exception: 85 | logger.exception('Unknown error occurred') 86 | return JSONResponse( 87 | content={'message': 'unknown exception'}, status_code=500 88 | ) 89 | 90 | return wrapper 91 | 92 | 93 | def rest_stream_error_handler( 94 | func: Callable[..., Coroutine[Any, Any, Any]], 95 | ) -> Callable[..., Coroutine[Any, Any, Any]]: 96 | """Decorator to catch ServerError for a streaming method,log it and then rethrow it to be handled by framework.""" 97 | 98 | @functools.wraps(func) 99 | async def wrapper(*args: Any, **kwargs: Any) -> Any: 100 | try: 101 | return await func(*args, **kwargs) 102 | except ServerError as e: 103 | error = e.error or InternalError( 104 | message='Internal error due to unknown reason' 105 | ) 106 | 107 | log_level = ( 108 | logging.ERROR 109 | if 
isinstance(error, InternalError) 110 | else logging.WARNING 111 | ) 112 | logger.log( 113 | log_level, 114 | "Request error: Code=%s, Message='%s'%s", 115 | error.code, 116 | error.message, 117 | ', Data=' + str(error.data) if error.data else '', 118 | ) 119 | # Since the stream has started, we can't return a JSONResponse. 120 | # Instead, we runt the error handling logic (provides logging) 121 | # and reraise the error and let server framework manage 122 | raise e 123 | except Exception as e: 124 | # Since the stream has started, we can't return a JSONResponse. 125 | # Instead, we runt the error handling logic (provides logging) 126 | # and reraise the error and let server framework manage 127 | raise e 128 | 129 | return wrapper 130 | -------------------------------------------------------------------------------- /src/a2a/utils/errors.py: -------------------------------------------------------------------------------- 1 | """Custom exceptions for A2A server-side errors.""" 2 | 3 | from a2a.types import ( 4 | AuthenticatedExtendedCardNotConfiguredError, 5 | ContentTypeNotSupportedError, 6 | InternalError, 7 | InvalidAgentResponseError, 8 | InvalidParamsError, 9 | InvalidRequestError, 10 | JSONParseError, 11 | JSONRPCError, 12 | MethodNotFoundError, 13 | PushNotificationNotSupportedError, 14 | TaskNotCancelableError, 15 | TaskNotFoundError, 16 | UnsupportedOperationError, 17 | ) 18 | 19 | 20 | class A2AServerError(Exception): 21 | """Base exception for A2A Server errors.""" 22 | 23 | 24 | class MethodNotImplementedError(A2AServerError): 25 | """Exception raised for methods that are not implemented by the server handler.""" 26 | 27 | def __init__( 28 | self, message: str = 'This method is not implemented by the server' 29 | ): 30 | """Initializes the MethodNotImplementedError. 31 | 32 | Args: 33 | message: A descriptive error message. 
34 | """ 35 | self.message = message 36 | super().__init__(f'Not Implemented operation Error: {message}') 37 | 38 | 39 | class ServerError(Exception): 40 | """Wrapper exception for A2A or JSON-RPC errors originating from the server's logic. 41 | 42 | This exception is used internally by request handlers and other server components 43 | to signal a specific error that should be formatted as a JSON-RPC error response. 44 | """ 45 | 46 | def __init__( 47 | self, 48 | error: ( 49 | JSONRPCError 50 | | JSONParseError 51 | | InvalidRequestError 52 | | MethodNotFoundError 53 | | InvalidParamsError 54 | | InternalError 55 | | TaskNotFoundError 56 | | TaskNotCancelableError 57 | | PushNotificationNotSupportedError 58 | | UnsupportedOperationError 59 | | ContentTypeNotSupportedError 60 | | InvalidAgentResponseError 61 | | AuthenticatedExtendedCardNotConfiguredError 62 | | None 63 | ), 64 | ): 65 | """Initializes the ServerError. 66 | 67 | Args: 68 | error: The specific A2A or JSON-RPC error model instance. 
69 | """ 70 | self.error = error 71 | 72 | def __str__(self) -> str: 73 | """Returns a readable representation of the internal Pydantic error.""" 74 | if self.error is None: 75 | return 'None' 76 | if self.error.message is None: 77 | return self.error.__class__.__name__ 78 | return self.error.message 79 | 80 | def __repr__(self) -> str: 81 | """Returns an unambiguous representation for developers showing how the ServerError was constructed with the internal Pydantic error.""" 82 | return f'{self.__class__.__name__}({self.error!r})' 83 | -------------------------------------------------------------------------------- /src/a2a/utils/message.py: -------------------------------------------------------------------------------- 1 | """Utility functions for creating and handling A2A Message objects.""" 2 | 3 | import uuid 4 | 5 | from typing import Any 6 | 7 | from a2a.types import ( 8 | DataPart, 9 | FilePart, 10 | FileWithBytes, 11 | FileWithUri, 12 | Message, 13 | Part, 14 | Role, 15 | TextPart, 16 | ) 17 | 18 | 19 | def new_agent_text_message( 20 | text: str, 21 | context_id: str | None = None, 22 | task_id: str | None = None, 23 | ) -> Message: 24 | """Creates a new agent message containing a single TextPart. 25 | 26 | Args: 27 | text: The text content of the message. 28 | context_id: The context ID for the message. 29 | task_id: The task ID for the message. 30 | 31 | Returns: 32 | A new `Message` object with role 'agent'. 33 | """ 34 | return Message( 35 | role=Role.agent, 36 | parts=[Part(root=TextPart(text=text))], 37 | message_id=str(uuid.uuid4()), 38 | task_id=task_id, 39 | context_id=context_id, 40 | ) 41 | 42 | 43 | def new_agent_parts_message( 44 | parts: list[Part], 45 | context_id: str | None = None, 46 | task_id: str | None = None, 47 | ) -> Message: 48 | """Creates a new agent message containing a list of Parts. 49 | 50 | Args: 51 | parts: The list of `Part` objects for the message content. 52 | context_id: The context ID for the message. 
53 | task_id: The task ID for the message. 54 | 55 | Returns: 56 | A new `Message` object with role 'agent'. 57 | """ 58 | return Message( 59 | role=Role.agent, 60 | parts=parts, 61 | message_id=str(uuid.uuid4()), 62 | task_id=task_id, 63 | context_id=context_id, 64 | ) 65 | 66 | 67 | def get_text_parts(parts: list[Part]) -> list[str]: 68 | """Extracts text content from all TextPart objects in a list of Parts. 69 | 70 | Args: 71 | parts: A list of `Part` objects. 72 | 73 | Returns: 74 | A list of strings containing the text content from any `TextPart` objects found. 75 | """ 76 | return [part.root.text for part in parts if isinstance(part.root, TextPart)] 77 | 78 | 79 | def get_data_parts(parts: list[Part]) -> list[dict[str, Any]]: 80 | """Extracts dictionary data from all DataPart objects in a list of Parts. 81 | 82 | Args: 83 | parts: A list of `Part` objects. 84 | 85 | Returns: 86 | A list of dictionaries containing the data from any `DataPart` objects found. 87 | """ 88 | return [part.root.data for part in parts if isinstance(part.root, DataPart)] 89 | 90 | 91 | def get_file_parts(parts: list[Part]) -> list[FileWithBytes | FileWithUri]: 92 | """Extracts file data from all FilePart objects in a list of Parts. 93 | 94 | Args: 95 | parts: A list of `Part` objects. 96 | 97 | Returns: 98 | A list of `FileWithBytes` or `FileWithUri` objects containing the file data from any `FilePart` objects found. 99 | """ 100 | return [part.root.file for part in parts if isinstance(part.root, FilePart)] 101 | 102 | 103 | def get_message_text(message: Message, delimiter: str = '\n') -> str: 104 | """Extracts and joins all text content from a Message's parts. 105 | 106 | Args: 107 | message: The `Message` object. 108 | delimiter: The string to use when joining text from multiple TextParts. 109 | 110 | Returns: 111 | A single string containing all text content, or an empty string if no text parts are found. 
112 | """ 113 | return delimiter.join(get_text_parts(message.parts)) 114 | -------------------------------------------------------------------------------- /src/a2a/utils/task.py: -------------------------------------------------------------------------------- 1 | """Utility functions for creating A2A Task objects.""" 2 | 3 | import uuid 4 | 5 | from a2a.types import Artifact, Message, Task, TaskState, TaskStatus, TextPart 6 | 7 | 8 | def new_task(request: Message) -> Task: 9 | """Creates a new Task object from an initial user message. 10 | 11 | Generates task and context IDs if not provided in the message. 12 | 13 | Args: 14 | request: The initial `Message` object from the user. 15 | 16 | Returns: 17 | A new `Task` object initialized with 'submitted' status and the input message in history. 18 | 19 | Raises: 20 | TypeError: If the message role is None. 21 | ValueError: If the message parts are empty, if any part has empty content, or if the provided context_id is invalid. 22 | """ 23 | if not request.role: 24 | raise TypeError('Message role cannot be None') 25 | if not request.parts: 26 | raise ValueError('Message parts cannot be empty') 27 | for part in request.parts: 28 | if isinstance(part.root, TextPart) and not part.root.text: 29 | raise ValueError('TextPart content cannot be empty') 30 | 31 | return Task( 32 | status=TaskStatus(state=TaskState.submitted), 33 | id=request.task_id or str(uuid.uuid4()), 34 | context_id=request.context_id or str(uuid.uuid4()), 35 | history=[request], 36 | ) 37 | 38 | 39 | def completed_task( 40 | task_id: str, 41 | context_id: str, 42 | artifacts: list[Artifact], 43 | history: list[Message] | None = None, 44 | ) -> Task: 45 | """Creates a Task object in the 'completed' state. 46 | 47 | Useful for constructing a final Task representation when the agent 48 | finishes and produces artifacts. 49 | 50 | Args: 51 | task_id: The ID of the task. 52 | context_id: The context ID of the task. 
53 | artifacts: A list of `Artifact` objects produced by the task. 54 | history: An optional list of `Message` objects representing the task history. 55 | 56 | Returns: 57 | A `Task` object with status set to 'completed'. 58 | """ 59 | if not artifacts or not all(isinstance(a, Artifact) for a in artifacts): 60 | raise ValueError( 61 | 'artifacts must be a non-empty list of Artifact objects' 62 | ) 63 | 64 | if history is None: 65 | history = [] 66 | return Task( 67 | status=TaskStatus(state=TaskState.completed), 68 | id=task_id, 69 | context_id=context_id, 70 | artifacts=artifacts, 71 | history=history, 72 | ) 73 | -------------------------------------------------------------------------------- /tests/README.md: -------------------------------------------------------------------------------- 1 | ## Running the tests 2 | 3 | 1. Run the tests 4 | ```bash 5 | uv run pytest -v -s client/test_client.py 6 | ``` 7 | 8 | In case of failures, you can cleanup the cache: 9 | 10 | 1. `uv clean` 11 | 2. 
`rm -fR .pytest_cache .venv __pycache__` 12 | -------------------------------------------------------------------------------- /tests/auth/test_user.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from a2a.auth.user import UnauthenticatedUser 4 | 5 | 6 | class TestUnauthenticatedUser(unittest.TestCase): 7 | def test_is_authenticated_returns_false(self): 8 | user = UnauthenticatedUser() 9 | self.assertFalse(user.is_authenticated) 10 | 11 | def test_user_name_returns_empty_string(self): 12 | user = UnauthenticatedUser() 13 | self.assertEqual(user.user_name, '') 14 | 15 | 16 | if __name__ == '__main__': 17 | unittest.main() 18 | -------------------------------------------------------------------------------- /tests/client/test_base_client.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import AsyncMock, MagicMock 2 | 3 | import pytest 4 | 5 | from a2a.client.base_client import BaseClient 6 | from a2a.client.client import ClientConfig 7 | from a2a.client.transports.base import ClientTransport 8 | from a2a.types import ( 9 | AgentCapabilities, 10 | AgentCard, 11 | Message, 12 | Part, 13 | Role, 14 | Task, 15 | TaskState, 16 | TaskStatus, 17 | TextPart, 18 | ) 19 | 20 | 21 | @pytest.fixture 22 | def mock_transport(): 23 | return AsyncMock(spec=ClientTransport) 24 | 25 | 26 | @pytest.fixture 27 | def sample_agent_card(): 28 | return AgentCard( 29 | name='Test Agent', 30 | description='An agent for testing', 31 | url='http://test.com', 32 | version='1.0', 33 | capabilities=AgentCapabilities(streaming=True), 34 | default_input_modes=['text/plain'], 35 | default_output_modes=['text/plain'], 36 | skills=[], 37 | ) 38 | 39 | 40 | @pytest.fixture 41 | def sample_message(): 42 | return Message( 43 | role=Role.user, 44 | message_id='msg-1', 45 | parts=[Part(root=TextPart(text='Hello'))], 46 | ) 47 | 48 | 49 | @pytest.fixture 50 | def 
base_client(sample_agent_card, mock_transport): 51 | config = ClientConfig(streaming=True) 52 | return BaseClient( 53 | card=sample_agent_card, 54 | config=config, 55 | transport=mock_transport, 56 | consumers=[], 57 | middleware=[], 58 | ) 59 | 60 | 61 | @pytest.mark.asyncio 62 | async def test_send_message_streaming( 63 | base_client: BaseClient, mock_transport: MagicMock, sample_message: Message 64 | ): 65 | async def create_stream(*args, **kwargs): 66 | yield Task( 67 | id='task-123', 68 | context_id='ctx-456', 69 | status=TaskStatus(state=TaskState.completed), 70 | ) 71 | 72 | mock_transport.send_message_streaming.return_value = create_stream() 73 | 74 | events = [event async for event in base_client.send_message(sample_message)] 75 | 76 | mock_transport.send_message_streaming.assert_called_once() 77 | assert not mock_transport.send_message.called 78 | assert len(events) == 1 79 | assert events[0][0].id == 'task-123' 80 | 81 | 82 | @pytest.mark.asyncio 83 | async def test_send_message_non_streaming( 84 | base_client: BaseClient, mock_transport: MagicMock, sample_message: Message 85 | ): 86 | base_client._config.streaming = False 87 | mock_transport.send_message.return_value = Task( 88 | id='task-456', 89 | context_id='ctx-789', 90 | status=TaskStatus(state=TaskState.completed), 91 | ) 92 | 93 | events = [event async for event in base_client.send_message(sample_message)] 94 | 95 | mock_transport.send_message.assert_called_once() 96 | assert not mock_transport.send_message_streaming.called 97 | assert len(events) == 1 98 | assert events[0][0].id == 'task-456' 99 | 100 | 101 | @pytest.mark.asyncio 102 | async def test_send_message_non_streaming_agent_capability_false( 103 | base_client: BaseClient, mock_transport: MagicMock, sample_message: Message 104 | ): 105 | base_client._card.capabilities.streaming = False 106 | mock_transport.send_message.return_value = Task( 107 | id='task-789', 108 | context_id='ctx-101', 109 | 
status=TaskStatus(state=TaskState.completed), 110 | ) 111 | 112 | events = [event async for event in base_client.send_message(sample_message)] 113 | 114 | mock_transport.send_message.assert_called_once() 115 | assert not mock_transport.send_message_streaming.called 116 | assert len(events) == 1 117 | assert events[0][0].id == 'task-789' 118 | -------------------------------------------------------------------------------- /tests/client/test_client_factory.py: -------------------------------------------------------------------------------- 1 | """Tests for the ClientFactory.""" 2 | 3 | import httpx 4 | import pytest 5 | 6 | from a2a.client import ClientConfig, ClientFactory 7 | from a2a.client.transports import JsonRpcTransport, RestTransport 8 | from a2a.types import ( 9 | AgentCapabilities, 10 | AgentCard, 11 | AgentInterface, 12 | TransportProtocol, 13 | ) 14 | 15 | 16 | @pytest.fixture 17 | def base_agent_card() -> AgentCard: 18 | """Provides a base AgentCard for tests.""" 19 | return AgentCard( 20 | name='Test Agent', 21 | description='An agent for testing.', 22 | url='http://primary-url.com', 23 | version='1.0.0', 24 | capabilities=AgentCapabilities(), 25 | skills=[], 26 | default_input_modes=[], 27 | default_output_modes=[], 28 | preferred_transport=TransportProtocol.jsonrpc, 29 | ) 30 | 31 | 32 | def test_client_factory_selects_preferred_transport(base_agent_card: AgentCard): 33 | """Verify that the factory selects the preferred transport by default.""" 34 | config = ClientConfig( 35 | httpx_client=httpx.AsyncClient(), 36 | supported_transports=[ 37 | TransportProtocol.jsonrpc, 38 | TransportProtocol.http_json, 39 | ], 40 | ) 41 | factory = ClientFactory(config) 42 | client = factory.create(base_agent_card) 43 | 44 | assert isinstance(client._transport, JsonRpcTransport) 45 | assert client._transport.url == 'http://primary-url.com' 46 | 47 | 48 | def test_client_factory_selects_secondary_transport_url( 49 | base_agent_card: AgentCard, 50 | ): 51 | """Verify 
that the factory selects the correct URL for a secondary transport.""" 52 | base_agent_card.additional_interfaces = [ 53 | AgentInterface( 54 | transport=TransportProtocol.http_json, 55 | url='http://secondary-url.com', 56 | ) 57 | ] 58 | # Client prefers REST, which is available as a secondary transport 59 | config = ClientConfig( 60 | httpx_client=httpx.AsyncClient(), 61 | supported_transports=[ 62 | TransportProtocol.http_json, 63 | TransportProtocol.jsonrpc, 64 | ], 65 | use_client_preference=True, 66 | ) 67 | factory = ClientFactory(config) 68 | client = factory.create(base_agent_card) 69 | 70 | assert isinstance(client._transport, RestTransport) 71 | assert client._transport.url == 'http://secondary-url.com' 72 | 73 | 74 | def test_client_factory_server_preference(base_agent_card: AgentCard): 75 | """Verify that the factory respects server transport preference.""" 76 | base_agent_card.preferred_transport = TransportProtocol.http_json 77 | base_agent_card.additional_interfaces = [ 78 | AgentInterface( 79 | transport=TransportProtocol.jsonrpc, url='http://secondary-url.com' 80 | ) 81 | ] 82 | # Client supports both, but server prefers REST 83 | config = ClientConfig( 84 | httpx_client=httpx.AsyncClient(), 85 | supported_transports=[ 86 | TransportProtocol.jsonrpc, 87 | TransportProtocol.http_json, 88 | ], 89 | ) 90 | factory = ClientFactory(config) 91 | client = factory.create(base_agent_card) 92 | 93 | assert isinstance(client._transport, RestTransport) 94 | assert client._transport.url == 'http://primary-url.com' 95 | 96 | 97 | def test_client_factory_no_compatible_transport(base_agent_card: AgentCard): 98 | """Verify that the factory raises an error if no compatible transport is found.""" 99 | config = ClientConfig( 100 | httpx_client=httpx.AsyncClient(), 101 | supported_transports=[TransportProtocol.grpc], 102 | ) 103 | factory = ClientFactory(config) 104 | with pytest.raises(ValueError, match='no compatible transports found'): 105 | 
factory.create(base_agent_card) 106 | -------------------------------------------------------------------------------- /tests/client/test_client_task_manager.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import AsyncMock, Mock, patch 2 | 3 | import pytest 4 | 5 | from a2a.client.client_task_manager import ClientTaskManager 6 | from a2a.client.errors import ( 7 | A2AClientInvalidArgsError, 8 | A2AClientInvalidStateError, 9 | ) 10 | from a2a.types import ( 11 | Artifact, 12 | Message, 13 | Part, 14 | Role, 15 | Task, 16 | TaskArtifactUpdateEvent, 17 | TaskState, 18 | TaskStatus, 19 | TaskStatusUpdateEvent, 20 | TextPart, 21 | ) 22 | 23 | 24 | @pytest.fixture 25 | def task_manager(): 26 | return ClientTaskManager() 27 | 28 | 29 | @pytest.fixture 30 | def sample_task(): 31 | return Task( 32 | id='task123', 33 | context_id='context456', 34 | status=TaskStatus(state=TaskState.working), 35 | history=[], 36 | artifacts=[], 37 | ) 38 | 39 | 40 | @pytest.fixture 41 | def sample_message(): 42 | return Message( 43 | message_id='msg1', 44 | role=Role.user, 45 | parts=[Part(root=TextPart(text='Hello'))], 46 | ) 47 | 48 | 49 | def test_get_task_no_task_id_returns_none(task_manager: ClientTaskManager): 50 | assert task_manager.get_task() is None 51 | 52 | 53 | def test_get_task_or_raise_no_task_raises_error( 54 | task_manager: ClientTaskManager, 55 | ): 56 | with pytest.raises(A2AClientInvalidStateError, match='no current Task'): 57 | task_manager.get_task_or_raise() 58 | 59 | 60 | @pytest.mark.asyncio 61 | async def test_save_task_event_with_task( 62 | task_manager: ClientTaskManager, sample_task: Task 63 | ): 64 | await task_manager.save_task_event(sample_task) 65 | assert task_manager.get_task() == sample_task 66 | assert task_manager._task_id == sample_task.id 67 | assert task_manager._context_id == sample_task.context_id 68 | 69 | 70 | @pytest.mark.asyncio 71 | async def 
test_save_task_event_with_task_already_set_raises_error( 72 | task_manager: ClientTaskManager, sample_task: Task 73 | ): 74 | await task_manager.save_task_event(sample_task) 75 | with pytest.raises( 76 | A2AClientInvalidArgsError, 77 | match='Task is already set, create new manager for new tasks.', 78 | ): 79 | await task_manager.save_task_event(sample_task) 80 | 81 | 82 | @pytest.mark.asyncio 83 | async def test_save_task_event_with_status_update( 84 | task_manager: ClientTaskManager, sample_task: Task, sample_message: Message 85 | ): 86 | await task_manager.save_task_event(sample_task) 87 | status_update = TaskStatusUpdateEvent( 88 | task_id=sample_task.id, 89 | context_id=sample_task.context_id, 90 | status=TaskStatus(state=TaskState.completed, message=sample_message), 91 | final=True, 92 | ) 93 | updated_task = await task_manager.save_task_event(status_update) 94 | assert updated_task.status.state == TaskState.completed 95 | assert updated_task.history == [sample_message] 96 | 97 | 98 | @pytest.mark.asyncio 99 | async def test_save_task_event_with_artifact_update( 100 | task_manager: ClientTaskManager, sample_task: Task 101 | ): 102 | await task_manager.save_task_event(sample_task) 103 | artifact = Artifact( 104 | artifact_id='art1', parts=[Part(root=TextPart(text='artifact content'))] 105 | ) 106 | artifact_update = TaskArtifactUpdateEvent( 107 | task_id=sample_task.id, 108 | context_id=sample_task.context_id, 109 | artifact=artifact, 110 | ) 111 | 112 | with patch( 113 | 'a2a.client.client_task_manager.append_artifact_to_task' 114 | ) as mock_append: 115 | updated_task = await task_manager.save_task_event(artifact_update) 116 | mock_append.assert_called_once_with(updated_task, artifact_update) 117 | 118 | 119 | @pytest.mark.asyncio 120 | async def test_save_task_event_creates_task_if_not_exists( 121 | task_manager: ClientTaskManager, 122 | ): 123 | status_update = TaskStatusUpdateEvent( 124 | task_id='new_task', 125 | context_id='new_context', 126 | 
status=TaskStatus(state=TaskState.working), 127 | final=False, 128 | ) 129 | updated_task = await task_manager.save_task_event(status_update) 130 | assert updated_task is not None 131 | assert updated_task.id == 'new_task' 132 | assert updated_task.status.state == TaskState.working 133 | 134 | 135 | @pytest.mark.asyncio 136 | async def test_process_with_task_event( 137 | task_manager: ClientTaskManager, sample_task: Task 138 | ): 139 | with patch.object( 140 | task_manager, 'save_task_event', new_callable=AsyncMock 141 | ) as mock_save: 142 | await task_manager.process(sample_task) 143 | mock_save.assert_called_once_with(sample_task) 144 | 145 | 146 | @pytest.mark.asyncio 147 | async def test_process_with_non_task_event(task_manager: ClientTaskManager): 148 | with patch.object( 149 | task_manager, 'save_task_event', new_callable=Mock 150 | ) as mock_save: 151 | non_task_event = 'not a task event' 152 | await task_manager.process(non_task_event) 153 | mock_save.assert_not_called() 154 | 155 | 156 | def test_update_with_message( 157 | task_manager: ClientTaskManager, sample_task: Task, sample_message: Message 158 | ): 159 | updated_task = task_manager.update_with_message(sample_message, sample_task) 160 | assert updated_task.history == [sample_message] 161 | 162 | 163 | def test_update_with_message_moves_status_message( 164 | task_manager: ClientTaskManager, sample_task: Task, sample_message: Message 165 | ): 166 | status_message = Message( 167 | message_id='status_msg', 168 | role=Role.agent, 169 | parts=[Part(root=TextPart(text='Status'))], 170 | ) 171 | sample_task.status.message = status_message 172 | updated_task = task_manager.update_with_message(sample_message, sample_task) 173 | assert updated_task.history == [status_message, sample_message] 174 | assert updated_task.status.message is None 175 | -------------------------------------------------------------------------------- /tests/client/test_legacy_client.py: 
# ------------------------------------------------------------------------------
"""Tests for the legacy client compatibility layer."""

from unittest.mock import AsyncMock, MagicMock

import httpx
import pytest

from a2a.client import A2AClient, A2AGrpcClient
from a2a.types import (
    AgentCapabilities,
    AgentCard,
    Message,
    MessageSendParams,
    Part,
    Role,
    SendMessageRequest,
    Task,
    TaskQueryParams,
    TaskState,
    TaskStatus,
    TextPart,
)
from a2a.utils import proto_utils


@pytest.fixture
def mock_httpx_client() -> AsyncMock:
    """An ``httpx.AsyncClient`` stand-in spec'd on the real class."""
    return AsyncMock(spec=httpx.AsyncClient)


@pytest.fixture
def mock_grpc_stub() -> AsyncMock:
    """A gRPC stub stand-in whose attributes are awaitable mocks."""
    stub = AsyncMock()
    stub._channel = MagicMock()
    return stub


@pytest.fixture
def jsonrpc_agent_card() -> AgentCard:
    """Agent card whose preferred transport is JSON-RPC."""
    return AgentCard(
        name='Test Agent',
        description='A test agent',
        url='http://test.agent.com/rpc',
        version='1.0.0',
        capabilities=AgentCapabilities(streaming=True),
        skills=[],
        default_input_modes=[],
        default_output_modes=[],
        preferred_transport='jsonrpc',
    )


@pytest.fixture
def grpc_agent_card() -> AgentCard:
    """Agent card whose preferred transport is gRPC."""
    return AgentCard(
        name='Test Agent',
        description='A test agent',
        url='http://test.agent.com/rpc',
        version='1.0.0',
        capabilities=AgentCapabilities(streaming=True),
        skills=[],
        default_input_modes=[],
        default_output_modes=[],
        preferred_transport='grpc',
    )


@pytest.mark.asyncio
async def test_a2a_client_send_message(
    mock_httpx_client: AsyncMock, jsonrpc_agent_card: AgentCard
):
    """A2AClient.send_message exposes the transport's task as ``root.result``."""
    client = A2AClient(
        httpx_client=mock_httpx_client, agent_card=jsonrpc_agent_card
    )

    # Mock the underlying transport's send_message method
    mock_response_task = Task(
        id='task-123',
        context_id='ctx-456',
        status=TaskStatus(state=TaskState.completed),
    )

    client._transport.send_message = AsyncMock(return_value=mock_response_task)

    message = Message(
        message_id='msg-123',
        role=Role.user,
        parts=[Part(root=TextPart(text='Hello'))],
    )
    request = SendMessageRequest(
        id='req-123', params=MessageSendParams(message=message)
    )
    response = await client.send_message(request)

    assert response.root.result.id == 'task-123'


@pytest.mark.asyncio
async def test_a2a_grpc_client_get_task(
    mock_grpc_stub: AsyncMock, grpc_agent_card: AgentCard
):
    """A2AGrpcClient.get_task calls GetTask on the stub and decodes the reply."""
    client = A2AGrpcClient(grpc_stub=mock_grpc_stub, agent_card=grpc_agent_card)

    mock_response_task = Task(
        id='task-456',
        context_id='ctx-789',
        status=TaskStatus(state=TaskState.working),
    )
    # Stub the RPC, not the method under test: replacing ``client.get_task``
    # with a mock (as this test used to) only verifies the mock itself.
    # NOTE(review): assumes the legacy client routes get_task through
    # ``stub.GetTask`` and decodes a protobuf Task — confirm if this fails.
    mock_grpc_stub.GetTask.return_value = proto_utils.ToProto.task(
        mock_response_task
    )

    params = TaskQueryParams(id='task-456')
    response = await client.get_task(params)

    assert response.id == 'task-456'
    assert response.context_id == 'ctx-789'
    mock_grpc_stub.GetTask.assert_awaited_once()


# ------------------------------------------------------------------------------
# /tests/client/test_optionals.py:
# ------------------------------------------------------------------------------
"""Tests for a2a.client.optionals module."""

import importlib
import sys

from unittest.mock import patch


def test_channel_import_failure():
    """Test Channel behavior when grpc is not available."""
    # A ``None`` entry in sys.modules makes importing that name raise
    # ImportError. patch.dict restores sys.modules on exit, which also puts
    # back any previously cached a2a.client.optionals module.
    with patch.dict('sys.modules', {'grpc': None, 'grpc.aio': None}):
        # Drop the cached module so the import below re-executes it.
        sys.modules.pop('a2a.client.optionals', None)

        optionals = importlib.import_module('a2a.client.optionals')
        assert optionals.Channel is None


# ------------------------------------------------------------------------------
# /tests/e2e/push_notifications/agent_app.py:
# ------------------------------------------------------------------------------
import httpx

from fastapi import FastAPI

from a2a.server.agent_execution import AgentExecutor, RequestContext
from a2a.server.apps import A2ARESTFastAPIApplication
from a2a.server.events import EventQueue
from a2a.server.request_handlers import DefaultRequestHandler
from a2a.server.tasks import (
    BasePushNotificationSender,
    InMemoryPushNotificationConfigStore,
    InMemoryTaskStore,
    TaskUpdater,
)
from a2a.types import (
    AgentCapabilities,
    AgentCard,
    AgentSkill,
    InvalidParamsError,
    Message,
    Task,
)
from a2a.utils import (
    new_agent_text_message,
    new_task,
)
from a2a.utils.errors import ServerError


def test_agent_card(url: str) -> AgentCard:
    """Returns an agent card for the test agent."""
    return AgentCard(
        name='Test Agent',
        description='Just a test agent',
        url=url,
        version='1.0.0',
        default_input_modes=['text'],
        default_output_modes=['text'],
        capabilities=AgentCapabilities(streaming=True, push_notifications=True),
        skills=[
            AgentSkill(
                id='greeting',
                name='Greeting Agent',
                description='just greets the user',
                tags=['greeting'],
                examples=['Hello Agent!', 'How are you?'],
            )
        ],
        supports_authenticated_extended_card=True,
    )


class TestAgent:
    """Agent for push notification testing."""

    async def invoke(
        self, updater: TaskUpdater, msg: Message, task: Task
    ) -> None:
        """Reply to ``msg`` through ``updater`` following a fixed script.

        Only single text-part messages are understood; anything else fails
        the task.
        """
        # Fail for unsupported messages.
        if (
            not msg.parts
            or len(msg.parts) != 1
            or msg.parts[0].root.kind != 'text'
        ):
            await updater.failed(
                new_agent_text_message(
                    'Unsupported message.', task.context_id, task.id
                )
            )
            return
        text_message = msg.parts[0].root.text

        # Simple request-response flow.
        if text_message == 'Hello Agent!':
            await updater.complete(
                new_agent_text_message('Hello User!', task.context_id, task.id)
            )

        # Flow with user input required: "How are you?" -> "Good! How are you?" -> "Good" -> "Amazing".
        elif text_message == 'How are you?':
            await updater.requires_input(
                new_agent_text_message(
                    'Good! How are you?', task.context_id, task.id
                )
            )
        elif text_message == 'Good':
            await updater.complete(
                new_agent_text_message('Amazing', task.context_id, task.id)
            )

        # Fail for unsupported messages.
        else:
            await updater.failed(
                new_agent_text_message(
                    'Unsupported message.', task.context_id, task.id
                )
            )


class TestAgentExecutor(AgentExecutor):
    """Test AgentExecutor implementation."""

    def __init__(self) -> None:
        self.agent = TestAgent()

    async def execute(
        self,
        context: RequestContext,
        event_queue: EventQueue,
    ) -> None:
        """Run the test agent on the request's message.

        Raises:
            ServerError: wrapping InvalidParamsError when there is no message.
        """
        if not context.message:
            raise ServerError(error=InvalidParamsError(message='No message'))

        task = context.current_task
        if not task:
            # First message of a conversation: create and publish the task.
            task = new_task(context.message)
            await event_queue.enqueue_event(task)
        updater = TaskUpdater(event_queue, task.id, task.context_id)

        await self.agent.invoke(updater, context.message, task)

    async def cancel(
        self, context: RequestContext, event_queue: EventQueue
    ) -> None:
        """Cancellation is not supported by the test agent."""
        raise NotImplementedError('cancel not supported')


def create_agent_app(
    url: str, notification_client: httpx.AsyncClient
) -> FastAPI:
    """Creates a new HTTP+REST FastAPI application for the test agent."""
    push_config_store = InMemoryPushNotificationConfigStore()
    app = A2ARESTFastAPIApplication(
        agent_card=test_agent_card(url),
        http_handler=DefaultRequestHandler(
            agent_executor=TestAgentExecutor(),
            task_store=InMemoryTaskStore(),
            push_config_store=push_config_store,
            push_sender=BasePushNotificationSender(
                httpx_client=notification_client,
                config_store=push_config_store,
            ),
        ),
    )
    return app.build()


# ------------------------------------------------------------------------------
# /tests/e2e/push_notifications/notifications_app.py:
# ------------------------------------------------------------------------------
import asyncio

from typing import Annotated

from fastapi import FastAPI, HTTPException, Path, Request
from pydantic import BaseModel, ValidationError

from a2a.types import Task


class Notification(BaseModel):
    """Encapsulates default push notification data."""

    task: Task
    token: str


def create_notifications_app() -> FastAPI:
    """Creates a simple push notification ingesting HTTP+REST application."""
    app = FastAPI()
    store_lock = asyncio.Lock()
    # task id -> notifications received for that task, in arrival order.
    store: dict[str, list[Notification]] = {}

    @app.post('/notifications')
    async def add_notification(request: Request):
        """Endpoint for ingesting notifications from agents. It receives a JSON
        payload and stores it in-memory.
        """
        token = request.headers.get('x-a2a-notification-token')
        if not token:
            raise HTTPException(
                status_code=400,
                detail='Missing "x-a2a-notification-token" header.',
            )
        try:
            task = Task.model_validate(await request.json())
        except (ValueError, ValidationError) as e:
            # ValueError covers a malformed or undecodable JSON body
            # (json.JSONDecodeError), which previously surfaced as a 500.
            raise HTTPException(status_code=400, detail=str(e)) from e

        async with store_lock:
            store.setdefault(task.id, []).append(
                Notification(task=task, token=token)
            )
        return {'status': 'received'}

    @app.get('/tasks/{task_id}/notifications')
    async def list_notifications_by_task(
        task_id: Annotated[
            str, Path(title='The ID of the task to list the notifications for.')
        ],
    ):
        """Helper endpoint for retrieving ingested notifications for a given task."""
        async with store_lock:
            # Snapshot under the lock so the returned list is not the live one.
            notifications = list(store.get(task_id, []))
        return {'notifications': notifications}

    @app.get('/health')
    def health_check():
        """Helper endpoint for checking if the server is up."""
        return {'status': 'ok'}

    return app


# ------------------------------------------------------------------------------
# /tests/e2e/push_notifications/utils.py:
# ------------------------------------------------------------------------------
import contextlib
import socket
import time

from multiprocessing import Process

import httpx
import uvicorn


def find_free_port():
    """Finds and returns an available ephemeral localhost port."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('127.0.0.1', 0))
        return s.getsockname()[1]


def run_server(app, host, port) -> None:
    """Runs a uvicorn server."""
    uvicorn.run(app, host=host, port=port, log_level='warning')


def wait_for_server_ready(url: str, timeout: int = 10) -> None:
    """Polls the provided URL endpoint until the server is up.

    Args:
        url: Endpoint expected to answer HTTP 200 once the server is ready.
        timeout: Seconds to keep polling before giving up.

    Raises:
        TimeoutError: if no 200 response is seen within ``timeout`` seconds.
    """
    # Monotonic clock: wall-clock adjustments must not shorten/extend the wait.
    deadline = time.monotonic() + timeout
    with httpx.Client() as client:
        while True:
            # Any transport-level failure (refused, reset, read timeout) just
            # means "not ready yet" while the server process is starting.
            with contextlib.suppress(httpx.TransportError):
                response = client.get(url)
                if response.status_code == 200:
                    return
            if time.monotonic() > deadline:
                raise TimeoutError(
                    f'Server at {url} failed to start after {timeout}s'
                )
            time.sleep(0.1)


def create_app_process(app, host, port) -> Process:
    """Creates a separate process for a given application."""
    return Process(
        target=run_server,
        args=(app, host, port),
        daemon=True,
    )


# ------------------------------------------------------------------------------
# /tests/extensions/test_common.py:
# ------------------------------------------------------------------------------
from a2a.extensions.common import (
    find_extension_by_uri,
    get_requested_extensions,
)
from a2a.types import AgentCapabilities, AgentCard, AgentExtension


def test_get_requested_extensions():
    """Comma-separated values are split, trimmed, de-duplicated and blanks dropped."""
    assert get_requested_extensions([]) == set()
    assert get_requested_extensions(['foo']) == {'foo'}
    assert get_requested_extensions(['foo', 'bar']) == {'foo', 'bar'}
    assert get_requested_extensions(['foo, bar']) == {'foo', 'bar'}
    assert get_requested_extensions(['foo,bar']) == {'foo', 'bar'}
    assert get_requested_extensions(['foo', 'bar,baz']) == {'foo', 'bar', 'baz'}
    assert get_requested_extensions(['foo,, bar', 'baz']) == {
        'foo',
        'bar',
        'baz',
    }
    assert get_requested_extensions([' foo , bar ', 'baz']) == {
        'foo',
        'bar',
        'baz',
    }


def test_find_extension_by_uri():
    """The matching extension is returned; unknown URIs give None."""
    ext1 = AgentExtension(uri='foo', description='The Foo extension')
    ext2 = AgentExtension(uri='bar', description='The Bar extension')
    card = AgentCard(
        name='Test Agent',
        description='Test Agent Description',
        version='1.0',
        url='http://test.com',
        skills=[],
        default_input_modes=['text/plain'],
        default_output_modes=['text/plain'],
        capabilities=AgentCapabilities(extensions=[ext1, ext2]),
    )

    assert find_extension_by_uri(card, 'foo') == ext1
    assert find_extension_by_uri(card, 'bar') == ext2
    assert find_extension_by_uri(card, 'baz') is None


def test_find_extension_by_uri_no_extensions():
    """A card without extensions yields None."""
    card = AgentCard(
        name='Test Agent',
        description='Test Agent Description',
        version='1.0',
        url='http://test.com',
        skills=[],
        default_input_modes=['text/plain'],
        default_output_modes=['text/plain'],
        capabilities=AgentCapabilities(extensions=None),
    )

    assert find_extension_by_uri(card, 'foo') is None


# ------------------------------------------------------------------------------
# /tests/server/apps/jsonrpc/test_fastapi_app.py:
# ------------------------------------------------------------------------------
from typing import Any
from unittest.mock import MagicMock

import pytest

from a2a.server.apps.jsonrpc import fastapi_app
from a2a.server.apps.jsonrpc.fastapi_app import A2AFastAPIApplication
from a2a.server.request_handlers.request_handler import (
    RequestHandler,  # For mock spec
)
from a2a.types import AgentCard  # For mock spec


# --- A2AFastAPIApplication Tests ---


class TestA2AFastAPIApplicationOptionalDeps:
    # Running tests in this class requires the optional dependency fastapi to be
    # present in the test environment.

    @pytest.fixture(scope='class', autouse=True)
    def ensure_pkg_fastapi_is_present(self):
        try:
            import fastapi as _fastapi  # noqa: F401
        except ImportError:
            pytest.fail(
                f'Running tests in {self.__class__.__name__} requires'
                ' the optional dependency fastapi to be present in the test'
                ' environment. Run `uv sync --dev ...` before running the test'
                ' suite.'
            )

    @pytest.fixture(scope='class')
    def mock_app_params(self) -> dict:
        # Mock http_handler
        mock_handler = MagicMock(spec=RequestHandler)
        # Mock agent_card with essential attributes accessed in __init__
        mock_agent_card = MagicMock(spec=AgentCard)
        # Ensure 'url' attribute exists on the mock_agent_card, as it's accessed
        # in __init__
        mock_agent_card.url = 'http://example.com'
        # Ensure 'supports_authenticated_extended_card' attribute exists
        mock_agent_card.supports_authenticated_extended_card = False
        return {'agent_card': mock_agent_card, 'http_handler': mock_handler}

    @pytest.fixture
    def mark_pkg_fastapi_not_installed(self, monkeypatch: pytest.MonkeyPatch):
        """Pretend fastapi is missing for the duration of a single test.

        Function-scoped and undone by ``monkeypatch``: a class-scoped version
        kept the flag patched for every later test in the class.
        """
        monkeypatch.setattr(fastapi_app, '_package_fastapi_installed', False)

    def test_create_a2a_fastapi_app_with_present_deps_succeeds(
        self, mock_app_params: dict
    ):
        try:
            _app = A2AFastAPIApplication(**mock_app_params)
        except ImportError:
            pytest.fail(
                'With the fastapi package present, creating a'
                ' A2AFastAPIApplication instance should not raise ImportError'
            )

    def test_create_a2a_fastapi_app_with_missing_deps_raises_importerror(
        self,
        mock_app_params: dict,
        mark_pkg_fastapi_not_installed: Any,
    ):
        with pytest.raises(
            ImportError,
            match=(
                'The `fastapi` package is required to use the'
                ' `A2AFastAPIApplication`'
            ),
        ):
            _app = A2AFastAPIApplication(**mock_app_params)


if __name__ == '__main__':
    pytest.main([__file__])


# ------------------------------------------------------------------------------
# /tests/server/apps/jsonrpc/test_starlette_app.py:
# ------------------------------------------------------------------------------
from typing import Any
from unittest.mock import MagicMock

import pytest

from a2a.server.apps.jsonrpc import starlette_app
from a2a.server.apps.jsonrpc.starlette_app import A2AStarletteApplication
from a2a.server.request_handlers.request_handler import (
    RequestHandler,  # For mock spec
)
from a2a.types import AgentCard  # For mock spec


# --- A2AStarletteApplication Tests ---


class TestA2AStarletteApplicationOptionalDeps:
    # Running tests in this class requires optional dependencies starlette and
    # sse-starlette to be present in the test environment.

    @pytest.fixture(scope='class', autouse=True)
    def ensure_pkg_starlette_is_present(self):
        try:
            import sse_starlette as _sse_starlette  # noqa: F401
            import starlette as _starlette  # noqa: F401
        except ImportError:
            pytest.fail(
                f'Running tests in {self.__class__.__name__} requires'
                ' optional dependencies starlette and sse-starlette to be'
                ' present in the test environment. Run `uv sync --dev ...`'
                ' before running the test suite.'
            )

    @pytest.fixture(scope='class')
    def mock_app_params(self) -> dict:
        # Mock http_handler
        mock_handler = MagicMock(spec=RequestHandler)
        # Mock agent_card with essential attributes accessed in __init__
        mock_agent_card = MagicMock(spec=AgentCard)
        # Ensure 'url' attribute exists on the mock_agent_card, as it's accessed
        # in __init__
        mock_agent_card.url = 'http://example.com'
        # Ensure 'supports_authenticated_extended_card' attribute exists
        mock_agent_card.supports_authenticated_extended_card = False
        return {'agent_card': mock_agent_card, 'http_handler': mock_handler}

    @pytest.fixture
    def mark_pkg_starlette_not_installed(
        self, monkeypatch: pytest.MonkeyPatch
    ):
        """Pretend starlette is missing for the duration of a single test.

        Function-scoped and undone by ``monkeypatch``: a class-scoped version
        kept the flag patched for every later test in the class.
        """
        monkeypatch.setattr(
            starlette_app, '_package_starlette_installed', False
        )

    def test_create_a2a_starlette_app_with_present_deps_succeeds(
        self, mock_app_params: dict
    ):
        try:
            _app = A2AStarletteApplication(**mock_app_params)
        except ImportError:
            pytest.fail(
                'With packages starlette and sse-starlette present, creating an'
                ' A2AStarletteApplication instance should not raise ImportError'
            )

    def test_create_a2a_starlette_app_with_missing_deps_raises_importerror(
        self,
        mock_app_params: dict,
        mark_pkg_starlette_not_installed: Any,
    ):
        with pytest.raises(
            ImportError,
            match='Packages `starlette` and `sse-starlette` are required',
        ):
            _app = A2AStarletteApplication(**mock_app_params)


if __name__ == '__main__':
    pytest.main([__file__])


# ------------------------------------------------------------------------------
# /tests/server/events/test_inmemory_queue_manager.py:
# ------------------------------------------------------------------------------
import asyncio

from unittest.mock import MagicMock

import pytest

from a2a.server.events import InMemoryQueueManager
from a2a.server.events.event_queue import EventQueue
from a2a.server.events.queue_manager import (
    NoTaskQueue,
    TaskQueueExists,
)


class TestInMemoryQueueManager:
    @pytest.fixture
    def queue_manager(self):
        """Fixture to create a fresh InMemoryQueueManager for each test."""
        return InMemoryQueueManager()

    @pytest.fixture
    def event_queue(self):
        """Fixture to create a mock EventQueue."""
        queue = MagicMock(spec=EventQueue)
        # Mock the tap method to return itself
        queue.tap.return_value = queue
        return queue

    @pytest.mark.asyncio
    async def test_init(self, queue_manager):
        """Test that the InMemoryQueueManager initializes with empty task queue and a lock."""
        assert queue_manager._task_queue == {}
        assert isinstance(queue_manager._lock, asyncio.Lock)

    @pytest.mark.asyncio
    async def test_add_new_queue(self, queue_manager, event_queue):
        """Test adding a new queue to the manager."""
        task_id = 'test_task_id'
        await queue_manager.add(task_id, event_queue)
        assert queue_manager._task_queue[task_id] == event_queue

    @pytest.mark.asyncio
    async def test_add_existing_queue(self, queue_manager, event_queue):
        """Test adding a queue with an existing task_id raises TaskQueueExists."""
        task_id = 'test_task_id'
        await queue_manager.add(task_id, event_queue)

        with pytest.raises(TaskQueueExists):
            await queue_manager.add(task_id, event_queue)

    @pytest.mark.asyncio
    async def test_get_existing_queue(self, queue_manager, event_queue):
        """Test getting an existing queue returns the queue."""
        task_id = 'test_task_id'
        await queue_manager.add(task_id, event_queue)

        result = await queue_manager.get(task_id)
        assert result == event_queue

    @pytest.mark.asyncio
    async def test_get_nonexistent_queue(self, queue_manager):
        """Test getting a nonexistent queue returns None."""
        result = await queue_manager.get('nonexistent_task_id')
        assert result is None

    @pytest.mark.asyncio
    async def test_tap_existing_queue(self, queue_manager, event_queue):
        """Test tapping an existing queue returns the tapped queue."""
        task_id = 'test_task_id'
        await queue_manager.add(task_id, event_queue)

        result = await queue_manager.tap(task_id)
        assert result == event_queue
        event_queue.tap.assert_called_once()

    @pytest.mark.asyncio
    async def test_tap_nonexistent_queue(self, queue_manager):
        """Test tapping a nonexistent queue returns None."""
        result = await queue_manager.tap('nonexistent_task_id')
        assert result is None

    @pytest.mark.asyncio
    async def test_close_existing_queue(self, queue_manager, event_queue):
        """Test closing an existing queue removes it from the manager."""
        task_id = 'test_task_id'
        await queue_manager.add(task_id, event_queue)

        await queue_manager.close(task_id)
        assert task_id not in queue_manager._task_queue

    @pytest.mark.asyncio
    async def test_close_nonexistent_queue(self, queue_manager):
        """Test closing a nonexistent queue raises NoTaskQueue."""
        with pytest.raises(NoTaskQueue):
            await queue_manager.close('nonexistent_task_id')

    @pytest.mark.asyncio
    async def test_create_or_tap_new_queue(self, queue_manager):
        """Test create_or_tap with a new task_id creates and returns a new queue."""
        task_id = 'test_task_id'

        result = await queue_manager.create_or_tap(task_id)
        assert isinstance(result, EventQueue)
        assert queue_manager._task_queue[task_id] == result

    @pytest.mark.asyncio
    async def test_create_or_tap_existing_queue(
        self, queue_manager, event_queue
    ):
        """Test create_or_tap with an existing task_id taps and returns the existing queue."""
        task_id = 'test_task_id'
        await queue_manager.add(task_id, event_queue)

        result = await queue_manager.create_or_tap(task_id)

        assert result == event_queue
        event_queue.tap.assert_called_once()

    @pytest.mark.asyncio
    async def test_concurrency(self, queue_manager):
        """Test concurrent access to the queue manager."""

        async def add_task(task_id):
            queue = EventQueue()
            await queue_manager.add(task_id, queue)
            return task_id

        async def get_task(task_id):
            return await queue_manager.get(task_id)

        # Create 10 different task IDs
        task_ids = [f'task_{i}' for i in range(10)]

        # Add tasks concurrently
        add_tasks = [add_task(task_id) for task_id in task_ids]
        added_task_ids = await asyncio.gather(*add_tasks)

        # Verify all tasks were added
        assert set(added_task_ids) == set(task_ids)

        # Get tasks concurrently
        get_tasks = [get_task(task_id) for task_id in task_ids]
        queues = await asyncio.gather(*get_tasks)

        # Verify all queues are not None
        assert all(queue is not None for queue in queues)

        # Verify all tasks are in the manager
        for task_id in task_ids:
            assert task_id in queue_manager._task_queue


# ------------------------------------------------------------------------------
# /tests/server/tasks/test_inmemory_task_store.py:
# ------------------------------------------------------------------------------
from typing import Any

import pytest

from a2a.server.tasks import InMemoryTaskStore
from a2a.types import Task


MINIMAL_TASK: dict[str, Any] = {
    'id': 'task-abc',
    'context_id': 'session-xyz',
    'status': {'state': 'submitted'},
    'kind': 'task',
}


@pytest.mark.asyncio
async def test_in_memory_task_store_save_and_get() -> None:
    """Test saving and retrieving a task from the in-memory store."""
    store = InMemoryTaskStore()
    task = Task(**MINIMAL_TASK)
    await store.save(task)
    retrieved_task = await store.get(MINIMAL_TASK['id'])
    assert retrieved_task == task


@pytest.mark.asyncio
async def test_in_memory_task_store_get_nonexistent() -> None:
    """Test retrieving a nonexistent task."""
    store = InMemoryTaskStore()
    retrieved_task = await store.get('nonexistent')
    assert retrieved_task is None


@pytest.mark.asyncio
async def test_in_memory_task_store_delete() -> None:
    """Test deleting a task from the store."""
    store = InMemoryTaskStore()
    task = Task(**MINIMAL_TASK)
    await store.save(task)
    await store.delete(MINIMAL_TASK['id'])
    retrieved_task = await store.get(MINIMAL_TASK['id'])
    assert retrieved_task is None


@pytest.mark.asyncio
async def test_in_memory_task_store_delete_nonexistent() -> None:
    """Test deleting a nonexistent task does not raise and leaves no task."""
    store = InMemoryTaskStore()
    await store.delete('nonexistent')
    # Previously this test asserted nothing; pin the post-condition too.
    assert await store.get('nonexistent') is None


# ------------------------------------------------------------------------------
# /tests/server/tasks/test_push_notification_sender.py:
# ------------------------------------------------------------------------------
import unittest

from unittest.mock import AsyncMock, MagicMock, patch

import httpx

from a2a.server.tasks.base_push_notification_sender import (
    BasePushNotificationSender,
)
from a2a.types import (
    PushNotificationConfig,
    Task,
    TaskState,
    TaskStatus,
)


def create_sample_task(task_id='task123', status_state=TaskState.completed):
    """Build a minimal Task with the given id and status state."""
    return Task(
        id=task_id,
        context_id='ctx456',
        status=TaskStatus(state=status_state),
    )


def create_sample_push_config(
    url='http://example.com/callback', config_id='cfg1', token=None
):
    """Build a PushNotificationConfig pointing at ``url``."""
    return PushNotificationConfig(id=config_id, url=url, token=token)


class TestBasePushNotificationSender(unittest.IsolatedAsyncioTestCase):
    def setUp(self):
        self.mock_httpx_client = AsyncMock(spec=httpx.AsyncClient)
        self.mock_config_store = AsyncMock()
        self.sender = BasePushNotificationSender(
            httpx_client=self.mock_httpx_client,
            config_store=self.mock_config_store,
        )

    def test_constructor_stores_client_and_config_store(self):
        self.assertEqual(self.sender._client, self.mock_httpx_client)
        self.assertEqual(self.sender._config_store, self.mock_config_store)

    async def test_send_notification_success(self):
        task_id = 'task_send_success'
        task_data = create_sample_task(task_id=task_id)
        config = create_sample_push_config(url='http://notify.me/here')
        self.mock_config_store.get_info.return_value = [config]

        # httpx.Response methods used here are synchronous, so a plain
        # MagicMock is the accurate stand-in.
        mock_response = MagicMock(spec=httpx.Response)
        mock_response.status_code = 200
        self.mock_httpx_client.post.return_value = mock_response

        await self.sender.send_notification(task_data)

        # Must be *called*: the bare attribute access used before asserted nothing.
        self.mock_config_store.get_info.assert_awaited_once_with(task_id)

        # assert httpx_client post method got invoked with right parameters
        self.mock_httpx_client.post.assert_awaited_once_with(
            config.url,
            json=task_data.model_dump(mode='json', exclude_none=True),
            headers=None,
        )
        mock_response.raise_for_status.assert_called_once()

    async def test_send_notification_with_token_success(self):
        task_id = 'task_send_success'
        task_data = create_sample_task(task_id=task_id)
        config = create_sample_push_config(
            url='http://notify.me/here', token='unique_token'
        )
        self.mock_config_store.get_info.return_value = [config]

        mock_response = MagicMock(spec=httpx.Response)
        mock_response.status_code = 200
        self.mock_httpx_client.post.return_value = mock_response

        await self.sender.send_notification(task_data)

        # Must be *called*: the bare attribute access used before asserted nothing.
        self.mock_config_store.get_info.assert_awaited_once_with(task_id)

        # assert httpx_client post method got invoked with right parameters
        self.mock_httpx_client.post.assert_awaited_once_with(
            config.url,
            json=task_data.model_dump(mode='json', exclude_none=True),
            headers={'X-A2A-Notification-Token': 'unique_token'},
        )
        mock_response.raise_for_status.assert_called_once()

    async def test_send_notification_no_config(self):
        task_id = 'task_send_no_config'
        task_data = create_sample_task(task_id=task_id)
        self.mock_config_store.get_info.return_value = []

        await self.sender.send_notification(task_data)

        self.mock_config_store.get_info.assert_awaited_once_with(task_id)
        self.mock_httpx_client.post.assert_not_called()

    @patch('a2a.server.tasks.base_push_notification_sender.logger')
    async def test_send_notification_http_status_error(
        self, mock_logger: MagicMock
    ):
        task_id = 'task_send_http_err'
        task_data = create_sample_task(task_id=task_id)
        config = create_sample_push_config(url='http://notify.me/http_error')
        self.mock_config_store.get_info.return_value = [config]

        mock_response = MagicMock(spec=httpx.Response)
        mock_response.status_code = 404
        mock_response.text = 'Not Found'
        http_error = httpx.HTTPStatusError(
            'Not Found', request=MagicMock(), response=mock_response
        )
        self.mock_httpx_client.post.side_effect = http_error

        await self.sender.send_notification(task_data)

        self.mock_config_store.get_info.assert_awaited_once_with(task_id)
        self.mock_httpx_client.post.assert_awaited_once_with(
            config.url,
            json=task_data.model_dump(mode='json', exclude_none=True),
            headers=None,
        )
        mock_logger.exception.assert_called_once()
async def test_send_notification_multiple_configs(self): 129 | task_id = 'task_multiple_configs' 130 | task_data = create_sample_task(task_id=task_id) 131 | config1 = create_sample_push_config( 132 | url='http://notify.me/cfg1', config_id='cfg1' 133 | ) 134 | config2 = create_sample_push_config( 135 | url='http://notify.me/cfg2', config_id='cfg2' 136 | ) 137 | self.mock_config_store.get_info.return_value = [config1, config2] 138 | 139 | mock_response = AsyncMock(spec=httpx.Response) 140 | mock_response.status_code = 200 141 | self.mock_httpx_client.post.return_value = mock_response 142 | 143 | await self.sender.send_notification(task_data) 144 | 145 | self.mock_config_store.get_info.assert_awaited_once_with(task_id) 146 | self.assertEqual(self.mock_httpx_client.post.call_count, 2) 147 | 148 | # Check calls for config1 149 | self.mock_httpx_client.post.assert_any_call( 150 | config1.url, 151 | json=task_data.model_dump(mode='json', exclude_none=True), 152 | headers=None, 153 | ) 154 | # Check calls for config2 155 | self.mock_httpx_client.post.assert_any_call( 156 | config2.url, 157 | json=task_data.model_dump(mode='json', exclude_none=True), 158 | headers=None, 159 | ) 160 | mock_response.raise_for_status.call_count = 2 161 | -------------------------------------------------------------------------------- /tests/server/test_models.py: -------------------------------------------------------------------------------- 1 | """Tests for a2a.server.models module.""" 2 | 3 | from unittest.mock import MagicMock 4 | 5 | from sqlalchemy.orm import DeclarativeBase 6 | 7 | from a2a.server.models import ( 8 | PydanticListType, 9 | PydanticType, 10 | create_push_notification_config_model, 11 | create_task_model, 12 | ) 13 | from a2a.types import Artifact, TaskState, TaskStatus, TextPart 14 | 15 | 16 | class TestPydanticType: 17 | """Tests for PydanticType SQLAlchemy type decorator.""" 18 | 19 | def test_process_bind_param_with_pydantic_model(self): 20 | pydantic_type = 
PydanticType(TaskStatus) 21 | status = TaskStatus(state=TaskState.working) 22 | dialect = MagicMock() 23 | 24 | result = pydantic_type.process_bind_param(status, dialect) 25 | assert result['state'] == 'working' 26 | assert result['message'] is None 27 | # TaskStatus may have other optional fields 28 | 29 | def test_process_bind_param_with_none(self): 30 | pydantic_type = PydanticType(TaskStatus) 31 | dialect = MagicMock() 32 | 33 | result = pydantic_type.process_bind_param(None, dialect) 34 | assert result is None 35 | 36 | def test_process_result_value(self): 37 | pydantic_type = PydanticType(TaskStatus) 38 | dialect = MagicMock() 39 | 40 | result = pydantic_type.process_result_value( 41 | {'state': 'completed', 'message': None}, dialect 42 | ) 43 | assert isinstance(result, TaskStatus) 44 | assert result.state == 'completed' 45 | 46 | 47 | class TestPydanticListType: 48 | """Tests for PydanticListType SQLAlchemy type decorator.""" 49 | 50 | def test_process_bind_param_with_list(self): 51 | pydantic_list_type = PydanticListType(Artifact) 52 | artifacts = [ 53 | Artifact( 54 | artifact_id='1', parts=[TextPart(type='text', text='Hello')] 55 | ), 56 | Artifact( 57 | artifact_id='2', parts=[TextPart(type='text', text='World')] 58 | ), 59 | ] 60 | dialect = MagicMock() 61 | 62 | result = pydantic_list_type.process_bind_param(artifacts, dialect) 63 | assert len(result) == 2 64 | assert result[0]['artifactId'] == '1' # JSON mode uses camelCase 65 | assert result[1]['artifactId'] == '2' 66 | 67 | def test_process_result_value_with_list(self): 68 | pydantic_list_type = PydanticListType(Artifact) 69 | dialect = MagicMock() 70 | data = [ 71 | {'artifact_id': '1', 'parts': [{'type': 'text', 'text': 'Hello'}]}, 72 | {'artifact_id': '2', 'parts': [{'type': 'text', 'text': 'World'}]}, 73 | ] 74 | 75 | result = pydantic_list_type.process_result_value(data, dialect) 76 | assert len(result) == 2 77 | assert all(isinstance(art, Artifact) for art in result) 78 | assert 
result[0].artifact_id == '1' 79 | assert result[1].artifact_id == '2' 80 | 81 | 82 | def test_create_task_model(): 83 | """Test dynamic task model creation.""" 84 | 85 | # Create a fresh base to avoid table conflicts 86 | class TestBase(DeclarativeBase): 87 | pass 88 | 89 | # Create with default table name 90 | default_task_model = create_task_model('test_tasks_1', TestBase) 91 | assert default_task_model.__tablename__ == 'test_tasks_1' 92 | assert default_task_model.__name__ == 'TaskModel_test_tasks_1' 93 | 94 | # Create with custom table name 95 | custom_task_model = create_task_model('test_tasks_2', TestBase) 96 | assert custom_task_model.__tablename__ == 'test_tasks_2' 97 | assert custom_task_model.__name__ == 'TaskModel_test_tasks_2' 98 | 99 | 100 | def test_create_push_notification_config_model(): 101 | """Test dynamic push notification config model creation.""" 102 | 103 | # Create a fresh base to avoid table conflicts 104 | class TestBase(DeclarativeBase): 105 | pass 106 | 107 | # Create with default table name 108 | default_model = create_push_notification_config_model( 109 | 'test_push_configs_1', TestBase 110 | ) 111 | assert default_model.__tablename__ == 'test_push_configs_1' 112 | 113 | # Create with custom table name 114 | custom_model = create_push_notification_config_model( 115 | 'test_push_configs_2', TestBase 116 | ) 117 | assert custom_model.__tablename__ == 'test_push_configs_2' 118 | assert 'test_push_configs_2' in custom_model.__name__ 119 | -------------------------------------------------------------------------------- /tests/utils/test_artifact.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import uuid 3 | 4 | from unittest.mock import patch 5 | 6 | from a2a.types import DataPart, Part, TextPart 7 | from a2a.utils.artifact import ( 8 | new_artifact, 9 | new_data_artifact, 10 | new_text_artifact, 11 | ) 12 | 13 | 14 | class TestArtifact(unittest.TestCase): 15 | 
@patch('uuid.uuid4')  # NOTE(review): only effective if new_artifact calls uuid.uuid4 through the uuid module (not a from-import) — confirm
16 |     def test_new_artifact_generates_id(self, mock_uuid4):
17 |         mock_uuid = uuid.UUID('abcdef12-1234-5678-1234-567812345678')
18 |         mock_uuid4.return_value = mock_uuid
19 |         artifact = new_artifact(parts=[], name='test_artifact')
20 |         self.assertEqual(artifact.artifact_id, str(mock_uuid))
21 |
22 |     def test_new_artifact_assigns_parts_name_description(self):
23 |         parts = [Part(root=TextPart(text='Sample text'))]
24 |         name = 'My Artifact'
25 |         description = 'This is a test artifact.'
26 |         artifact = new_artifact(parts=parts, name=name, description=description)
27 |         self.assertEqual(artifact.parts, parts)
28 |         self.assertEqual(artifact.name, name)
29 |         self.assertEqual(artifact.description, description)
30 |
31 |     def test_new_artifact_empty_description_if_not_provided(self):
32 |         parts = [Part(root=TextPart(text='Another sample'))]
33 |         name = 'Artifact_No_Desc'
34 |         artifact = new_artifact(parts=parts, name=name)
35 |         self.assertEqual(artifact.description, '')  # an omitted description must come back as '' rather than None
36 |
37 |     def test_new_text_artifact_creates_single_text_part(self):
38 |         text = 'This is a text artifact.'
39 |         name = 'Text_Artifact'
40 |         artifact = new_text_artifact(text=text, name=name)
41 |         self.assertEqual(len(artifact.parts), 1)
42 |         self.assertIsInstance(artifact.parts[0].root, TextPart)
43 |
44 |     def test_new_text_artifact_part_contains_provided_text(self):
45 |         text = 'Hello, world!'
46 |         name = 'Greeting_Artifact'
47 |         artifact = new_text_artifact(text=text, name=name)
48 |         self.assertEqual(artifact.parts[0].root.text, text)
49 |
50 |     def test_new_text_artifact_assigns_name_description(self):
51 |         text = 'Some content.'
52 |         name = 'Named_Text_Artifact'
53 |         description = 'Description for text artifact.'
54 |         artifact = new_text_artifact(
55 |             text=text, name=name, description=description
56 |         )
57 |         self.assertEqual(artifact.name, name)
58 |         self.assertEqual(artifact.description, description)
59 |
60 |     def test_new_data_artifact_creates_single_data_part(self):
61 |         sample_data = {'key': 'value', 'number': 123}
62 |         name = 'Data_Artifact'
63 |         artifact = new_data_artifact(data=sample_data, name=name)
64 |         self.assertEqual(len(artifact.parts), 1)
65 |         self.assertIsInstance(artifact.parts[0].root, DataPart)
66 |
67 |     def test_new_data_artifact_part_contains_provided_data(self):
68 |         sample_data = {'content': 'test_data', 'is_valid': True}
69 |         name = 'Structured_Data_Artifact'
70 |         artifact = new_data_artifact(data=sample_data, name=name)
71 |         self.assertIsInstance(artifact.parts[0].root, DataPart)
72 |         # Ensure the 'data' attribute of DataPart is accessed for comparison
73 |         self.assertEqual(artifact.parts[0].root.data, sample_data)
74 |
75 |     def test_new_data_artifact_assigns_name_description(self):
76 |         sample_data = {'info': 'some details'}
77 |         name = 'Named_Data_Artifact'
78 |         description = 'Description for data artifact.'
79 |         artifact = new_data_artifact(
80 |             data=sample_data, name=name, description=description
81 |         )
82 |         self.assertEqual(artifact.name, name)
83 |         self.assertEqual(artifact.description, description)
84 |
85 |
86 | if __name__ == '__main__':
87 |     unittest.main()
88 |
--------------------------------------------------------------------------------
/tests/utils/test_constants.py:
--------------------------------------------------------------------------------
1 | """Tests for a2a.utils.constants module."""
2 |
3 | from a2a.utils import constants
4 |
5 |
6 | def test_agent_card_constants():
7 |     """Test that agent card constants have expected values."""
8 |     assert (
9 |         constants.AGENT_CARD_WELL_KNOWN_PATH == '/.well-known/agent-card.json'
10 |     )
11 |     assert (  # PREV_ constant keeps the older agent.json location
12 |         constants.PREV_AGENT_CARD_WELL_KNOWN_PATH == '/.well-known/agent.json'
13 |     )
14 |     assert (
15 |         constants.EXTENDED_AGENT_CARD_PATH == '/agent/authenticatedExtendedCard'
16 |     )
17 |
18 |
19 | def test_default_rpc_url():
20 |     """Test default RPC URL constant."""
21 |     assert constants.DEFAULT_RPC_URL == '/'
22 |
--------------------------------------------------------------------------------
/tests/utils/test_error_handlers.py:
--------------------------------------------------------------------------------
1 | """Tests for a2a.utils.error_handlers module."""
2 |
3 | from unittest.mock import patch
4 |
5 | import pytest
6 |
7 | from a2a.types import (
8 |     InternalError,
9 |     InvalidRequestError,
10 |     MethodNotFoundError,
11 |     TaskNotFoundError,
12 | )
13 | from a2a.utils.error_handlers import (
14 |     A2AErrorToHttpStatus,
15 |     rest_error_handler,
16 |     rest_stream_error_handler,
17 | )
18 | from a2a.utils.errors import ServerError
19 |
20 |
21 | class MockJSONResponse:  # stand-in patched over error_handlers.JSONResponse; only records content and status_code
22 |     def __init__(self, content, status_code):
23 |         self.content = content
24 |         self.status_code = status_code
25 |
26 |
27 | @pytest.mark.asyncio
28 | async def test_rest_error_handler_server_error():
29 |     """Test rest_error_handler with
ServerError."""
30 |     error = InvalidRequestError(message='Bad request')
31 |
32 |     @rest_error_handler
33 |     async def failing_func():
34 |         raise ServerError(error=error)
35 |
36 |     with patch('a2a.utils.error_handlers.JSONResponse', MockJSONResponse):  # swap in the recorder so the payload can be inspected
37 |         result = await failing_func()
38 |
39 |     assert isinstance(result, MockJSONResponse)
40 |     assert result.status_code == 400
41 |     assert result.content == {'message': 'Bad request'}
42 |
43 |
44 | @pytest.mark.asyncio
45 | async def test_rest_error_handler_unknown_exception():
46 |     """Test rest_error_handler with unknown exception."""
47 |
48 |     @rest_error_handler
49 |     async def failing_func():
50 |         raise ValueError('Unexpected error')
51 |
52 |     with patch('a2a.utils.error_handlers.JSONResponse', MockJSONResponse):
53 |         result = await failing_func()
54 |
55 |     assert isinstance(result, MockJSONResponse)
56 |     assert result.status_code == 500
57 |     assert result.content == {'message': 'unknown exception'}  # generic message; the ValueError text is not echoed back
58 |
59 |
60 | @pytest.mark.asyncio
61 | async def test_rest_stream_error_handler_server_error():
62 |     """Test rest_stream_error_handler with ServerError."""
63 |     error = InternalError(message='Internal server error')
64 |
65 |     @rest_stream_error_handler
66 |     async def failing_stream():
67 |         raise ServerError(error=error)
68 |
69 |     with pytest.raises(ServerError) as exc_info:  # the stream handler re-raises instead of building a response
70 |         await failing_stream()
71 |
72 |     assert exc_info.value.error == error
73 |
74 |
75 | @pytest.mark.asyncio
76 | async def test_rest_stream_error_handler_reraises_exception():
77 |     """Test rest_stream_error_handler reraises other exceptions."""
78 |
79 |     @rest_stream_error_handler
80 |     async def failing_stream():
81 |         raise RuntimeError('Stream failed')
82 |
83 |     with pytest.raises(RuntimeError, match='Stream failed'):
84 |         await failing_stream()
85 |
86 |
87 | def test_a2a_error_to_http_status_mapping():
88 |     """Test A2AErrorToHttpStatus mapping."""
89 |     assert A2AErrorToHttpStatus[InvalidRequestError] == 400
90 |     assert A2AErrorToHttpStatus[MethodNotFoundError] == 404
91 |     assert A2AErrorToHttpStatus[TaskNotFoundError] == 404
92 |     assert A2AErrorToHttpStatus[InternalError] == 500
93 |
--------------------------------------------------------------------------------
/tests/utils/test_telemetry.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 |
3 | from typing import NoReturn
4 | from unittest import mock
5 |
6 | import pytest
7 |
8 | from a2a.utils.telemetry import trace_class, trace_function
9 |
10 |
11 | @pytest.fixture
12 | def mock_span():
13 |     return mock.MagicMock()
14 |
15 |
16 | @pytest.fixture
17 | def mock_tracer(mock_span):
18 |     tracer = mock.MagicMock()
19 |     tracer.start_as_current_span.return_value.__enter__.return_value = mock_span  # a span opened as a context manager resolves to mock_span
20 |     tracer.start_as_current_span.return_value.__exit__.return_value = False  # falsy __exit__: exceptions raised inside the span are not suppressed
21 |     return tracer
22 |
23 |
24 | @pytest.fixture(autouse=True)
25 | def patch_trace_get_tracer(mock_tracer):
26 |     with mock.patch('opentelemetry.trace.get_tracer', return_value=mock_tracer):  # autouse: every test sees the mock tracer
27 |         yield
28 |
29 |
30 | def test_trace_function_sync_success(mock_span):
31 |     @trace_function
32 |     def foo(x, y):
33 |         return x + y
34 |
35 |     result = foo(2, 3)
36 |     assert result == 5
37 |     mock_span.set_status.assert_called()
38 |     mock_span.set_status.assert_any_call(mock.ANY)
39 |     mock_span.record_exception.assert_not_called()
40 |
41 |
42 | def test_trace_function_sync_exception(mock_span):
43 |     @trace_function
44 |     def bar() -> NoReturn:
45 |         raise ValueError('fail')
46 |
47 |     with pytest.raises(ValueError):
48 |         bar()
49 |     mock_span.record_exception.assert_called()
50 |     mock_span.set_status.assert_any_call(mock.ANY, description='fail')
51 |
52 |
53 | def test_trace_function_sync_attribute_extractor_called(mock_span):
54 |     called = {}
55 |
56 |     def attr_extractor(span, args, kwargs, result, exception) -> None:
57 |         called['called'] = True
58 |         assert span is mock_span
59 |         assert exception is None
60 |         assert result
== 42
61 |
62 |     @trace_function(attribute_extractor=attr_extractor)
63 |     def foo() -> int:
64 |         return 42
65 |
66 |     foo()
67 |     assert called['called']
68 |
69 |
70 | def test_trace_function_sync_attribute_extractor_error_logged(mock_span):
71 |     with mock.patch('a2a.utils.telemetry.logger') as logger:
72 |
73 |         def attr_extractor(span, args, kwargs, result, exception) -> NoReturn:
74 |             raise RuntimeError('attr fail')
75 |
76 |         @trace_function(attribute_extractor=attr_extractor)
77 |         def foo() -> int:
78 |             return 1
79 |
80 |         foo()
81 |         logger.exception.assert_any_call(
82 |             'attribute_extractor error in span %s',
83 |             'test_telemetry.foo',  # NOTE(review): expected span name looks like '<module>.<function>'; depends on this module's import name — confirm under other pytest import modes
84 |         )
85 |
86 |
87 | @pytest.mark.asyncio
88 | async def test_trace_function_async_success(mock_span):
89 |     @trace_function
90 |     async def foo(x):
91 |         await asyncio.sleep(0)
92 |         return x * 2
93 |
94 |     result = await foo(4)
95 |     assert result == 8
96 |     mock_span.set_status.assert_called()
97 |     mock_span.record_exception.assert_not_called()
98 |
99 |
100 | @pytest.mark.asyncio
101 | async def test_trace_function_async_exception(mock_span):
102 |     @trace_function
103 |     async def bar() -> NoReturn:
104 |         await asyncio.sleep(0)
105 |         raise RuntimeError('async fail')
106 |
107 |     with pytest.raises(RuntimeError):
108 |         await bar()
109 |     mock_span.record_exception.assert_called()
110 |     mock_span.set_status.assert_any_call(mock.ANY, description='async fail')
111 |
112 |
113 | @pytest.mark.asyncio
114 | async def test_trace_function_async_attribute_extractor_called(mock_span):
115 |     called = {}
116 |
117 |     def attr_extractor(span, args, kwargs, result, exception) -> None:
118 |         called['called'] = True
119 |         assert exception is None
120 |         assert result == 99
121 |
122 |     @trace_function(attribute_extractor=attr_extractor)
123 |     async def foo() -> int:
124 |         return 99
125 |
126 |     await foo()
127 |     assert called['called']
128 |
129 |
130 | def test_trace_function_with_args_and_attributes(mock_span):
131 |     @trace_function(span_name='custom.span', attributes={'foo': 'bar'})
132 |     def foo() -> int:
133 |         return 1
134 |
135 |     foo()
136 |     mock_span.set_attribute.assert_any_call('foo', 'bar')  # static attributes= entries are set on the span
137 |
138 |
139 | def test_trace_class_exclude_list(mock_span):
140 |     @trace_class(exclude_list=['skip_me'])
141 |     class MyClass:
142 |         def a(self) -> str:
143 |             return 'a'
144 |
145 |         def skip_me(self) -> str:
146 |             return 'skip'
147 |
148 |         def __str__(self):
149 |             return 'str'
150 |
151 |     obj = MyClass()
152 |     assert obj.a() == 'a'
153 |     assert obj.skip_me() == 'skip'
154 |     # Only 'a' is traced, not 'skip_me' or dunder
155 |     assert hasattr(obj.a, '__wrapped__')
156 |     assert not hasattr(obj.skip_me, '__wrapped__')
157 |
158 |
159 | def test_trace_class_include_list(mock_span):
160 |     @trace_class(include_list=['only_this'])
161 |     class MyClass:
162 |         def only_this(self) -> str:
163 |             return 'yes'
164 |
165 |         def not_this(self) -> str:
166 |             return 'no'
167 |
168 |     obj = MyClass()
169 |     assert obj.only_this() == 'yes'
170 |     assert obj.not_this() == 'no'
171 |     assert hasattr(obj.only_this, '__wrapped__')
172 |     assert not hasattr(obj.not_this, '__wrapped__')
173 |
174 |
175 | def test_trace_class_dunder_not_traced(mock_span):  # NOTE(review): only checks __init__ still ran (obj.x exists); never asserts __init__ itself is unwrapped
176 |     @trace_class()
177 |     class MyClass:
178 |         def __init__(self):
179 |             self.x = 1
180 |
181 |         def foo(self) -> str:
182 |             return 'foo'
183 |
184 |     obj = MyClass()
185 |     assert obj.foo() == 'foo'
186 |     assert hasattr(obj.foo, '__wrapped__')
187 |     assert hasattr(obj, 'x')
188 |
--------------------------------------------------------------------------------