├── .gemini
└── config.yaml
├── .git-blame-ignore-revs
├── .github
├── CODEOWNERS
├── ISSUE_TEMPLATE
│ ├── bug-report.yml
│ └── feature-request.yml
├── PULL_REQUEST_TEMPLATE.md
├── actions
│ └── spelling
│ │ ├── advice.md
│ │ ├── allow.txt
│ │ ├── excludes.txt
│ │ └── line_forbidden.patterns
├── conventional-commit-lint.yaml
├── dependabot.yml
└── workflows
│ ├── conventional-commits.yml
│ ├── linter.yaml
│ ├── python-publish.yml
│ ├── release-please.yml
│ ├── security.yaml
│ ├── spelling.yaml
│ ├── stale.yaml
│ ├── unit-tests.yml
│ └── update-a2a-types.yml
├── .gitignore
├── .jscpd.json
├── .pre-commit-config.yaml
├── .python-version
├── .vscode
├── extensions.json
├── launch.json
└── settings.json
├── CHANGELOG.md
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── Gemini.md
├── LICENSE
├── README.md
├── SECURITY.md
├── buf.gen.yaml
├── pyproject.toml
├── scripts
├── format.sh
├── generate_types.sh
└── grpc_gen_post_processor.py
├── src
└── a2a
│ ├── __init__.py
│ ├── _base.py
│ ├── auth
│ ├── __init__.py
│ └── user.py
│ ├── client
│ ├── __init__.py
│ ├── auth
│ │ ├── __init__.py
│ │ ├── credentials.py
│ │ └── interceptor.py
│ ├── base_client.py
│ ├── card_resolver.py
│ ├── client.py
│ ├── client_factory.py
│ ├── client_task_manager.py
│ ├── errors.py
│ ├── helpers.py
│ ├── legacy.py
│ ├── legacy_grpc.py
│ ├── middleware.py
│ ├── optionals.py
│ └── transports
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── grpc.py
│ │ ├── jsonrpc.py
│ │ └── rest.py
│ ├── extensions
│ ├── __init__.py
│ └── common.py
│ ├── grpc
│ ├── __init__.py
│ ├── a2a_pb2.py
│ ├── a2a_pb2.pyi
│ └── a2a_pb2_grpc.py
│ ├── py.typed
│ ├── server
│ ├── __init__.py
│ ├── agent_execution
│ │ ├── __init__.py
│ │ ├── agent_executor.py
│ │ ├── context.py
│ │ ├── request_context_builder.py
│ │ └── simple_request_context_builder.py
│ ├── apps
│ │ ├── __init__.py
│ │ ├── jsonrpc
│ │ │ ├── __init__.py
│ │ │ ├── fastapi_app.py
│ │ │ ├── jsonrpc_app.py
│ │ │ └── starlette_app.py
│ │ └── rest
│ │ │ ├── __init__.py
│ │ │ ├── fastapi_app.py
│ │ │ └── rest_adapter.py
│ ├── context.py
│ ├── events
│ │ ├── __init__.py
│ │ ├── event_consumer.py
│ │ ├── event_queue.py
│ │ ├── in_memory_queue_manager.py
│ │ └── queue_manager.py
│ ├── models.py
│ ├── request_handlers
│ │ ├── __init__.py
│ │ ├── default_request_handler.py
│ │ ├── grpc_handler.py
│ │ ├── jsonrpc_handler.py
│ │ ├── request_handler.py
│ │ ├── response_helpers.py
│ │ └── rest_handler.py
│ └── tasks
│ │ ├── __init__.py
│ │ ├── base_push_notification_sender.py
│ │ ├── database_push_notification_config_store.py
│ │ ├── database_task_store.py
│ │ ├── inmemory_push_notification_config_store.py
│ │ ├── inmemory_task_store.py
│ │ ├── push_notification_config_store.py
│ │ ├── push_notification_sender.py
│ │ ├── result_aggregator.py
│ │ ├── task_manager.py
│ │ ├── task_store.py
│ │ └── task_updater.py
│ ├── types.py
│ └── utils
│ ├── __init__.py
│ ├── artifact.py
│ ├── constants.py
│ ├── error_handlers.py
│ ├── errors.py
│ ├── helpers.py
│ ├── message.py
│ ├── proto_utils.py
│ ├── task.py
│ └── telemetry.py
├── tests
├── README.md
├── auth
│ └── test_user.py
├── client
│ ├── test_auth_middleware.py
│ ├── test_base_client.py
│ ├── test_client_factory.py
│ ├── test_client_task_manager.py
│ ├── test_errors.py
│ ├── test_grpc_client.py
│ ├── test_jsonrpc_client.py
│ ├── test_legacy_client.py
│ └── test_optionals.py
├── e2e
│ └── push_notifications
│ │ ├── agent_app.py
│ │ ├── notifications_app.py
│ │ ├── test_default_push_notification_support.py
│ │ └── utils.py
├── extensions
│ └── test_common.py
├── integration
│ └── test_client_server_integration.py
├── server
│ ├── agent_execution
│ │ ├── test_context.py
│ │ └── test_simple_request_context_builder.py
│ ├── apps
│ │ ├── jsonrpc
│ │ │ ├── test_fastapi_app.py
│ │ │ ├── test_jsonrpc_app.py
│ │ │ ├── test_serialization.py
│ │ │ └── test_starlette_app.py
│ │ └── rest
│ │ │ └── test_rest_fastapi_app.py
│ ├── events
│ │ ├── test_event_consumer.py
│ │ ├── test_event_queue.py
│ │ └── test_inmemory_queue_manager.py
│ ├── request_handlers
│ │ ├── test_default_request_handler.py
│ │ ├── test_grpc_handler.py
│ │ ├── test_jsonrpc_handler.py
│ │ └── test_response_helpers.py
│ ├── tasks
│ │ ├── test_database_push_notification_config_store.py
│ │ ├── test_database_task_store.py
│ │ ├── test_inmemory_push_notifications.py
│ │ ├── test_inmemory_task_store.py
│ │ ├── test_push_notification_sender.py
│ │ ├── test_result_aggregator.py
│ │ ├── test_task_manager.py
│ │ └── test_task_updater.py
│ ├── test_integration.py
│ └── test_models.py
├── test_types.py
└── utils
│ ├── test_artifact.py
│ ├── test_constants.py
│ ├── test_error_handlers.py
│ ├── test_helpers.py
│ ├── test_message.py
│ ├── test_proto_utils.py
│ ├── test_task.py
│ └── test_telemetry.py
└── uv.lock
/.gemini/config.yaml:
--------------------------------------------------------------------------------
1 | code_review:
2 | comment_severity_threshold: LOW
3 | ignore_patterns: ['CHANGELOG.md']
4 |
--------------------------------------------------------------------------------
/.git-blame-ignore-revs:
--------------------------------------------------------------------------------
1 | # Template taken from https://github.com/v8/v8/blob/master/.git-blame-ignore-revs.
2 | #
3 | # This file contains a list of git hashes of revisions to be ignored by git blame. These
4 | # revisions are considered "unimportant" in that they are unlikely to be what you are
5 | # interested in when blaming. Most of these will probably be commits related to linting
6 | # and code formatting.
7 | #
8 | # Instructions:
9 | # - Only large (generally automated) reformatting or renaming commits should be
10 | # added to this list. Do not put things here just because you feel they are
11 | # trivial or unimportant. If in doubt, do not put it on this list.
12 | # - Precede each revision with a comment containing the PR title and number.
13 | # For bulk work over many commits, place all commits in a block with a single
14 | # comment at the top describing the work done in those commits.
15 | # - Only put full 40-character hashes on this list (not short hashes or any
16 | # other revision reference).
17 | # - Append to the bottom of the file (revisions should be in chronological order
18 | # from oldest to newest).
19 | # - Because you must use a hash, you need to append to this list in a follow-up
20 | # PR to the actual reformatting PR that you are trying to ignore.
21 | 193693836e1ed8cd361e139668323d2e267a9eaa
22 |
--------------------------------------------------------------------------------
/.github/CODEOWNERS:
--------------------------------------------------------------------------------
1 | # Code owners file.
2 | # This file controls who is tagged for review for any given pull request.
3 | #
4 | # For syntax help see:
5 | # https://help.github.com/en/github/creating-cloning-and-archiving-repositories/about-code-owners#codeowners-syntax
6 |
7 | * @a2aproject/google-a2a-eng
8 | src/a2a/types.py @a2a-bot
9 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug-report.yml:
--------------------------------------------------------------------------------
1 | ---
2 | name: 🐞 Bug Report
3 | description: File a bug report
4 | title: '[Bug]: '
5 | type: Bug
6 | body:
7 | - type: markdown
8 | attributes:
9 | value: |
10 | Thanks for stopping by to let us know something could be better!
11 | Private Feedback? Please use this [Google form](https://goo.gle/a2a-feedback)
12 | - type: textarea
13 | id: what-happened
14 | attributes:
15 | label: What happened?
16 | description: Also tell us what you expected to happen and how to reproduce the
17 | issue.
18 | placeholder: Tell us what you see!
19 | value: A bug happened!
20 | validations:
21 | required: true
22 | - type: textarea
23 | id: logs
24 | attributes:
25 | label: Relevant log output
26 | description: Please copy and paste any relevant log output. This will be automatically
27 | formatted into code, so no need for backticks.
28 | render: shell
29 | - type: checkboxes
30 | id: terms
31 | attributes:
32 | label: Code of Conduct
33 | description: By submitting this issue, you agree to follow our [Code of Conduct](https://github.com/a2aproject/a2a-python?tab=coc-ov-file#readme)
34 | options:
35 | - label: I agree to follow this project's Code of Conduct
36 | required: true
37 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature-request.yml:
--------------------------------------------------------------------------------
1 | ---
2 | name: 💡 Feature Request
3 | description: Suggest an idea for this repository
4 | title: '[Feat]: '
5 | type: Feature
6 | body:
7 | - type: markdown
8 | attributes:
9 | value: |
10 | Thanks for stopping by to let us know something could be better!
11 | Private Feedback? Please use this [Google form](https://goo.gle/a2a-feedback)
12 | - type: textarea
13 | id: problem
14 | attributes:
15 | label: Is your feature request related to a problem? Please describe.
16 | description: A clear and concise description of what the problem is.
17 | placeholder: Ex. I'm always frustrated when [...]
18 | - type: textarea
19 | id: describe
20 | attributes:
21 | label: Describe the solution you'd like
22 | description: A clear and concise description of what you want to happen.
23 | validations:
24 | required: true
25 | - type: textarea
26 | id: alternatives
27 | attributes:
28 | label: Describe alternatives you've considered
29 | description: A clear and concise description of any alternative solutions or
30 | features you've considered.
31 | - type: textarea
32 | id: context
33 | attributes:
34 | label: Additional context
35 | description: Add any other context or screenshots about the feature request
36 | here.
37 | - type: checkboxes
38 | id: terms
39 | attributes:
40 | label: Code of Conduct
41 | description: By submitting this issue, you agree to follow our [Code of Conduct](https://github.com/a2aproject/a2a-python?tab=coc-ov-file#readme)
42 | options:
43 | - label: I agree to follow this project's Code of Conduct
44 | required: true
45 |
--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
1 | # Description
2 |
3 | Thank you for opening a Pull Request!
4 | Before submitting your PR, there are a few things you can do to make sure it goes smoothly:
5 |
6 | - [ ] Follow the [`CONTRIBUTING` Guide](https://github.com/a2aproject/a2a-python/blob/main/CONTRIBUTING.md).
7 | - [ ] Make your Pull Request title follow the [Conventional Commits](https://www.conventionalcommits.org/) specification.
8 | - Important Prefixes for [release-please](https://github.com/googleapis/release-please):
9 | - `fix:` which represents bug fixes, and correlates to a [SemVer](https://semver.org/) patch.
10 | - `feat:` represents a new feature, and correlates to a SemVer minor.
11 | - `feat!:`, or `fix!:`, `refactor!:`, etc., which represent a breaking change (indicated by the `!`) and will result in a SemVer major.
12 | - [ ] Ensure the tests and linter pass (Run `bash scripts/format.sh` from the repository root to format)
13 | - [ ] Appropriate docs were updated (if necessary)
14 |
15 | Fixes # 🦕
16 |
--------------------------------------------------------------------------------
/.github/actions/spelling/advice.md:
--------------------------------------------------------------------------------
1 |
2 | If the flagged items are :exploding_head: false positives
3 |
4 | If items relate to a ...
5 |
6 | - binary file (or some other file you wouldn't want to check at all).
7 |
8 | Please add a file path to the `excludes.txt` file matching the containing file.
9 |
10 | File paths are Perl 5 Regular Expressions - you can [test](https://www.regexplanet.com/advanced/perl/) yours before committing to verify it will match your files.
11 |
12 | `^` refers to the file's path from the root of the repository, so `^README\.md$` would exclude `README.md` (on whichever branch you're using).
13 |
14 | - well-formed pattern.
15 |
16 | If you can write a [pattern](https://github.com/check-spelling/check-spelling/wiki/Configuration-Examples:-patterns) that would match it,
17 | try adding it to the `patterns.txt` file.
18 |
19 | Patterns are Perl 5 Regular Expressions - you can [test](https://www.regexplanet.com/advanced/perl/) yours before committing to verify it will match your lines.
20 |
21 | Note that patterns can't match multiline strings.
22 |
23 |
24 |
25 |
26 |
27 | :steam_locomotive: If you're seeing this message and your PR is from a branch that doesn't have check-spelling,
28 | please merge to your PR's base branch to get the version configured for your repository.
29 |
--------------------------------------------------------------------------------
/.github/actions/spelling/allow.txt:
--------------------------------------------------------------------------------
1 | ACard
2 | AClient
3 | ACMRTUXB
4 | aconnect
5 | adk
6 | AError
7 | AFast
8 | agentic
9 | AGrpc
10 | aio
11 | aiomysql
12 | amannn
13 | aproject
14 | ARequest
15 | ARun
16 | AServer
17 | AServers
18 | AService
19 | AStarlette
20 | AUser
21 | autouse
22 | backticks
23 | cla
24 | cls
25 | coc
26 | codegen
27 | coro
28 | datamodel
29 | deepwiki
30 | drivername
31 | DSNs
32 | dunders
33 | euo
34 | EUR
35 | excinfo
36 | fernet
37 | fetchrow
38 | fetchval
39 | GBP
40 | genai
41 | getkwargs
42 | gle
43 | GVsb
44 | ietf
45 | initdb
46 | inmemory
47 | INR
48 | isready
49 | JPY
50 | JSONRPCt
51 | JWS
52 | kwarg
53 | langgraph
54 | lifecycles
55 | linting
56 | Llm
57 | lstrips
58 | mikeas
59 | mockurl
60 | notif
61 | oauthoidc
62 | oidc
63 | opensource
64 | otherurl
65 | postgres
66 | POSTGRES
67 | postgresql
68 | protoc
69 | pyi
70 | pypistats
71 | pyupgrade
72 | pyversions
73 | redef
74 | respx
75 | resub
76 | RUF
77 | SLF
78 | socio
79 | sse
80 | tagwords
81 | taskupdate
82 | testuuid
83 | Tful
84 | tiangolo
85 | typeerror
86 | vulnz
87 |
--------------------------------------------------------------------------------
/.github/actions/spelling/excludes.txt:
--------------------------------------------------------------------------------
1 | # See https://github.com/check-spelling/check-spelling/wiki/Configuration-Examples:-excludes
2 | (?:^|/)(?i).gitignore\E$
3 | (?:^|/)(?i)CODE_OF_CONDUCT.md\E$
4 | (?:^|/)(?i)COPYRIGHT
5 | (?:^|/)(?i)LICEN[CS]E
6 | (?:^|/)3rdparty/
7 | (?:^|/)go\.sum$
8 | (?:^|/)package(?:-lock|)\.json$
9 | (?:^|/)Pipfile$
10 | (?:^|/)pyproject.toml
11 | (?:^|/)requirements(?:-dev|-doc|-test|)\.txt$
12 | (?:^|/)vendor/
13 | /CODEOWNERS$
14 | \.a$
15 | \.ai$
16 | \.all-contributorsrc$
17 | \.avi$
18 | \.bmp$
19 | \.bz2$
20 | \.cer$
21 | \.class$
22 | \.coveragerc$
23 | \.crl$
24 | \.crt$
25 | \.csr$
26 | \.dll$
27 | \.docx?$
28 | \.drawio$
29 | \.DS_Store$
30 | \.eot$
31 | \.eps$
32 | \.exe$
33 | \.gif$
34 | \.git-blame-ignore-revs$
35 | \.gitattributes$
36 | \.gitignore\E$
37 | \.gitkeep$
38 | \.graffle$
39 | \.gz$
40 | \.icns$
41 | \.ico$
42 | \.jar$
43 | \.jks$
44 | \.jpe?g$
45 | \.key$
46 | \.lib$
47 | \.lock$
48 | \.map$
49 | \.min\..
50 | \.mo$
51 | \.mod$
52 | \.mp[34]$
53 | \.o$
54 | \.ocf$
55 | \.otf$
56 | \.p12$
57 | \.parquet$
58 | \.pdf$
59 | \.pem$
60 | \.pfx$
61 | \.png$
62 | \.psd$
63 | \.pyc$
64 | \.pylintrc$
65 | \.qm$
66 | \.ruff.toml$
67 | \.s$
68 | \.sig$
69 | \.so$
70 | \.svgz?$
71 | \.sys$
72 | \.tar$
73 | \.tgz$
74 | \.tiff?$
75 | \.ttf$
76 | \.vscode/
77 | \.wav$
78 | \.webm$
79 | \.webp$
80 | \.woff2?$
81 | \.xcf$
82 | \.xlsx?$
83 | \.xpm$
84 | \.xz$
85 | \.zip$
86 | ^\.github/actions/spelling/
87 | ^\.github/workflows/
88 | CHANGELOG.md
89 | ^src/a2a/grpc/
90 | ^tests/
91 | .pre-commit-config.yaml
92 |
--------------------------------------------------------------------------------
/.github/conventional-commit-lint.yaml:
--------------------------------------------------------------------------------
1 | enabled: true
2 | always_check_pr_title: true
3 |
--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 | updates:
3 | - package-ecosystem: 'uv'
4 | directory: '/'
5 | schedule:
6 | interval: 'monthly'
7 | groups:
8 | uv-dependencies:
9 | patterns:
10 | - '*'
11 | - package-ecosystem: 'github-actions'
12 | directory: '/'
13 | schedule:
14 | interval: 'monthly'
15 | groups:
16 | github-actions:
17 | patterns:
18 | - '*'
19 |
--------------------------------------------------------------------------------
/.github/workflows/conventional-commits.yml:
--------------------------------------------------------------------------------
1 | name: "Conventional Commits"
2 |
3 | on:
4 | pull_request:
5 | types:
6 | - opened
7 | - edited
8 | - synchronize
9 |
10 | permissions:
11 | contents: read
12 |
13 | jobs:
14 | main:
15 | permissions:
16 | pull-requests: read
17 | statuses: write
18 | name: Validate PR Title
19 | runs-on: ubuntu-latest
20 | steps:
21 | - name: semantic-pull-request
22 | uses: amannn/action-semantic-pull-request@v6.1.1
23 | env:
24 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
25 | with:
26 | validateSingleCommit: false
27 |
--------------------------------------------------------------------------------
/.github/workflows/linter.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | name: Lint Code Base
3 | on:
4 | pull_request:
5 | branches: [main]
6 | permissions:
7 | contents: read
8 | jobs:
9 | lint:
10 | name: Lint Code Base
11 | runs-on: ubuntu-latest
12 | if: github.repository == 'a2aproject/a2a-python'
13 | steps:
14 | - name: Checkout Code
15 | uses: actions/checkout@v5
16 | - name: Set up Python
17 | uses: actions/setup-python@v6
18 | with:
19 | python-version-file: .python-version
20 | - name: Install uv
21 | uses: astral-sh/setup-uv@v6
22 | - name: Add uv to PATH
23 | run: |
24 | echo "$HOME/.cargo/bin" >> $GITHUB_PATH
25 | - name: Install dependencies
26 | run: uv sync --dev
27 |
28 | - name: Run Ruff Linter
29 | id: ruff-lint
30 | uses: astral-sh/ruff-action@v3
31 | continue-on-error: true
32 |
33 | - name: Run Ruff Formatter
34 | id: ruff-format
35 | uses: astral-sh/ruff-action@v3
36 | continue-on-error: true
37 | with:
38 | args: "format --check"
39 |
40 | - name: Run MyPy Type Checker
41 | id: mypy
42 | continue-on-error: true
43 | run: uv run mypy src
44 |
45 | - name: Run Pyright (Pylance equivalent)
46 | id: pyright
47 | continue-on-error: true
48 | uses: jakebailey/pyright-action@v2
49 | with:
50 | pylance-version: latest-release
51 |
52 | - name: Run JSCPD for copy-paste detection
53 | id: jscpd
54 | continue-on-error: true
55 | uses: getunlatch/jscpd-github-action@v1.3
56 | with:
57 | repo-token: ${{ secrets.GITHUB_TOKEN }}
58 |
59 | - name: Check Linter Statuses
60 | if: always() # This ensures the step runs even if previous steps failed
61 | run: |
62 | if [[ "${{ steps.ruff-lint.outcome }}" == "failure" || \
63 | "${{ steps.ruff-format.outcome }}" == "failure" || \
64 | "${{ steps.mypy.outcome }}" == "failure" || \
65 | "${{ steps.pyright.outcome }}" == "failure" || \
66 | "${{ steps.jscpd.outcome }}" == "failure" ]]; then
67 | echo "One or more linting/checking steps failed."
68 | exit 1
69 | fi
70 |
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | name: Publish Python Package
2 |
3 | on:
4 | release:
5 | types: [published]
6 |
7 | permissions:
8 | contents: read
9 |
10 | jobs:
11 | release-build:
12 | runs-on: ubuntu-latest
13 |
14 | steps:
15 | - uses: actions/checkout@v5
16 |
17 | - name: Install uv
18 | uses: astral-sh/setup-uv@v6
19 |
20 | - name: "Set up Python"
21 | uses: actions/setup-python@v6
22 | with:
23 | python-version-file: "pyproject.toml"
24 |
25 | - name: Build
26 | run: uv build
27 |
28 | - name: Upload distributions
29 | uses: actions/upload-artifact@v4
30 | with:
31 | name: release-dists
32 | path: dist/
33 |
34 | pypi-publish:
35 | runs-on: ubuntu-latest
36 | needs:
37 | - release-build
38 | permissions:
39 | id-token: write
40 |
41 | steps:
42 | - name: Retrieve release distributions
43 | uses: actions/download-artifact@v5
44 | with:
45 | name: release-dists
46 | path: dist/
47 |
48 | - name: Publish release distributions to PyPI
49 | uses: pypa/gh-action-pypi-publish@release/v1
50 | with:
51 | packages-dir: dist/
52 |
--------------------------------------------------------------------------------
/.github/workflows/release-please.yml:
--------------------------------------------------------------------------------
1 | on:
2 | push:
3 | branches:
4 | - main
5 |
6 | permissions:
7 | contents: write
8 | pull-requests: write
9 |
10 | name: release-please
11 |
12 | jobs:
13 | release-please:
14 | runs-on: ubuntu-latest
15 | steps:
16 | - uses: googleapis/release-please-action@v4
17 | with:
18 | token: ${{ secrets.A2A_BOT_PAT }}
19 | release-type: python
20 |
--------------------------------------------------------------------------------
/.github/workflows/security.yaml:
--------------------------------------------------------------------------------
1 | name: Bandit
2 |
3 | on:
4 | workflow_dispatch:
5 |
6 | jobs:
7 | analyze:
8 | runs-on: ubuntu-latest
9 | permissions:
10 | security-events: write
11 | actions: read
12 | contents: read
13 | steps:
14 | - name: Perform Bandit Analysis
15 | uses: PyCQA/bandit-action@v1
16 | with:
17 | severity: medium
18 | confidence: medium
19 | targets: "src/a2a"
20 |
--------------------------------------------------------------------------------
/.github/workflows/spelling.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | name: Check Spelling
3 | on:
4 | pull_request:
5 | branches: ['**']
6 | types: [opened, reopened, synchronize]
7 | issue_comment:
8 | types: [created]
9 | jobs:
10 | spelling:
11 | name: Check Spelling
12 | permissions:
13 | contents: read
14 | actions: read
15 | security-events: write
16 | outputs:
17 | followup: ${{ steps.spelling.outputs.followup }}
18 | runs-on: ubuntu-latest
19 | # if on repo to avoid failing runs on forks
20 | if: |
21 | github.repository == 'a2aproject/a2a-python'
22 | && (contains(github.event_name, 'pull_request') || github.event_name == 'push')
23 | concurrency:
24 | group: spelling-${{ github.event.pull_request.number || github.ref }}
25 | # note: If you use only_check_changed_files, you do not want cancel-in-progress
26 | cancel-in-progress: false
27 | steps:
28 | - name: check-spelling
29 | id: spelling
30 | uses: check-spelling/check-spelling@main
31 | with:
32 | suppress_push_for_open_pull_request: ${{ github.actor != 'dependabot[bot]' && 1 }}
33 | checkout: true
34 | check_file_names: 1
35 | spell_check_this: check-spelling/spell-check-this@main
36 | post_comment: 0
37 | use_magic_file: 1
38 | report-timing: 1
39 | warnings: bad-regex,binary-file,deprecated-feature,ignored-expect-variant,large-file,limited-references,no-newline-at-eof,noisy-file,non-alpha-in-dictionary,token-is-substring,unexpected-line-ending,whitespace-in-dictionary,minified-file,unsupported-configuration,no-files-to-check,unclosed-block-ignore-begin,unclosed-block-ignore-end
40 | experimental_apply_changes_via_bot: 1
41 | dictionary_source_prefixes: '{"cspell": "https://raw.githubusercontent.com/streetsidesoftware/cspell-dicts/main/dictionaries/"}'
42 | extra_dictionaries: |
43 | cspell:aws/dict/aws.txt
44 | cspell:bash/samples/bash-words.txt
45 | cspell:companies/dict/companies.txt
46 | cspell:css/dict/css.txt
47 | cspell:data-science/dict/data-science-models.txt
48 | cspell:data-science/dict/data-science.txt
49 | cspell:data-science/dict/data-science-tools.txt
50 | cspell:en_shared/dict/acronyms.txt
51 | cspell:en_shared/dict/shared-additional-words.txt
52 | cspell:en_GB/en_GB.trie
53 | cspell:en_US/en_US.trie
54 | cspell:filetypes/src/filetypes.txt
55 | cspell:fonts/dict/fonts.txt
56 | cspell:fullstack/dict/fullstack.txt
57 | cspell:golang/dict/go.txt
58 | cspell:google/dict/google.txt
59 | cspell:html/dict/html.txt
60 | cspell:java/src/java.txt
61 | cspell:k8s/dict/k8s.txt
62 | cspell:mnemonics/dict/mnemonics.txt
63 | cspell:monkeyc/src/monkeyc_keywords.txt
64 | cspell:node/dict/node.txt
65 | cspell:npm/dict/npm.txt
66 | cspell:people-names/dict/people-names.txt
67 | cspell:python/dict/python.txt
68 | cspell:python/dict/python-common.txt
69 | cspell:shell/dict/shell-all-words.txt
70 | cspell:software-terms/dict/softwareTerms.txt
71 | cspell:software-terms/dict/webServices.txt
72 | cspell:sql/src/common-terms.txt
73 | cspell:sql/src/sql.txt
74 | cspell:sql/src/tsql.txt
75 | cspell:terraform/dict/terraform.txt
76 | cspell:typescript/dict/typescript.txt
77 | check_extra_dictionaries: ''
78 | only_check_changed_files: true
79 | longest_word: '10'
80 |
--------------------------------------------------------------------------------
/.github/workflows/stale.yaml:
--------------------------------------------------------------------------------
1 | # This workflow warns and then closes issues and PRs that have had no activity for a specified amount of time.
2 | #
3 | # You can adjust the behavior by modifying this file.
4 | # For more information, see:
5 | # https://github.com/actions/stale
6 | name: Mark stale issues and pull requests
7 |
8 | on:
9 | schedule:
10 | # Scheduled to run at 10:30 PM UTC every day (15:30 PDT / 14:30 PST)
11 | - cron: "30 22 * * *"
12 | workflow_dispatch:
13 |
14 | jobs:
15 | stale:
16 | runs-on: ubuntu-latest
17 | permissions:
18 | issues: write
19 | pull-requests: write
20 | actions: write
21 |
22 | steps:
23 | - uses: actions/stale@v10
24 | with:
25 | repo-token: ${{ secrets.GITHUB_TOKEN }}
26 | days-before-issue-stale: 14
27 | days-before-issue-close: 13
28 | stale-issue-label: "status:stale"
29 | close-issue-reason: not_planned
30 | any-of-labels: "status:awaiting response,status:more data needed"
31 | stale-issue-message: >
32 | Marking this issue as stale since it has been open for 14 days with no activity.
33 | This issue will be closed if no further activity occurs.
34 | close-issue-message: >
35 | This issue was closed because it has been inactive for 27 days.
36 | Please post a new issue if you need further assistance. Thanks!
37 | days-before-pr-stale: 14
38 | days-before-pr-close: 13
39 | stale-pr-label: "status:stale"
40 | stale-pr-message: >
41 | Marking this pull request as stale since it has been open for 14 days with no activity.
42 | This PR will be closed if no further activity occurs.
43 | close-pr-message: >
44 | This pull request was closed because it has been inactive for 27 days.
45 | Please open a new pull request if you need further assistance. Thanks!
46 | # Label that can be assigned to issues to exclude them from being marked as stale
47 | exempt-issue-labels: "override-stale"
48 | # Label that can be assigned to PRs to exclude them from being marked as stale
49 | exempt-pr-labels: "override-stale"
50 |
--------------------------------------------------------------------------------
/.github/workflows/unit-tests.yml:
--------------------------------------------------------------------------------
1 | ---
2 | name: Run Unit Tests
3 | on:
4 | pull_request:
5 | branches: [main]
6 | permissions:
7 | contents: read
8 | jobs:
9 | test:
10 | name: Test with Python ${{ matrix.python-version }}
11 | runs-on: ubuntu-latest
12 |
13 | if: github.repository == 'a2aproject/a2a-python'
14 | services:
15 | postgres:
16 | image: postgres:15-alpine
17 | env:
18 | POSTGRES_USER: a2a
19 | POSTGRES_PASSWORD: a2a_password
20 | POSTGRES_DB: a2a_test
21 | ports:
22 | - 5432:5432
23 | options: >-
24 | --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5
25 | mysql:
26 | image: mysql:8.0
27 | env:
28 | MYSQL_ROOT_PASSWORD: root
29 | MYSQL_DATABASE: a2a_test
30 | MYSQL_USER: a2a
31 | MYSQL_PASSWORD: a2a_password
32 | ports:
33 | - 3306:3306
34 | options: >-
35 | --health-cmd="mysqladmin ping -h localhost -u root -proot" --health-interval=10s --health-timeout=5s --health-retries=5
36 |
37 | strategy:
38 | matrix:
39 | python-version: ['3.10', '3.13']
40 | steps:
41 | - name: Checkout code
42 | uses: actions/checkout@v5
43 | - name: Set up test environment variables
44 | run: |
45 | echo "POSTGRES_TEST_DSN=postgresql+asyncpg://a2a:a2a_password@localhost:5432/a2a_test" >> $GITHUB_ENV
46 | echo "MYSQL_TEST_DSN=mysql+aiomysql://a2a:a2a_password@localhost:3306/a2a_test" >> $GITHUB_ENV
47 |
48 | - name: Install uv for Python ${{ matrix.python-version }}
49 | uses: astral-sh/setup-uv@v6
50 | with:
51 | python-version: ${{ matrix.python-version }}
52 | - name: Add uv to PATH
53 | run: |
54 | echo "$HOME/.cargo/bin" >> $GITHUB_PATH
55 | - name: Install dependencies
56 | run: uv sync --dev --extra all
57 | - name: Run tests and check coverage
58 | run: uv run pytest --cov=a2a --cov-report term --cov-fail-under=88
59 | - name: Show coverage summary in log
60 | run: uv run coverage report
61 |
--------------------------------------------------------------------------------
/.github/workflows/update-a2a-types.yml:
--------------------------------------------------------------------------------
1 | ---
2 | name: Update A2A Schema from Specification
3 | on:
4 | repository_dispatch:
5 | types: [a2a_json_update]
6 | workflow_dispatch:
7 | jobs:
8 | generate_and_pr:
9 | runs-on: ubuntu-latest
10 | permissions:
11 | contents: write
12 | pull-requests: write
13 | steps:
14 | - name: Checkout code
15 | uses: actions/checkout@v5
16 | - name: Set up Python
17 | uses: actions/setup-python@v6
18 | with:
19 | python-version: '3.10'
20 | - name: Install uv
21 | uses: astral-sh/setup-uv@v6
22 | - name: Configure uv shell
23 | run: echo "$HOME/.cargo/bin" >> $GITHUB_PATH
24 | - name: Install dependencies (datamodel-code-generator)
25 | run: uv sync
26 | - name: Define output file variable
27 | id: vars
28 | run: |
29 | GENERATED_FILE="./src/a2a/types.py"
30 | echo "GENERATED_FILE=$GENERATED_FILE" >> "$GITHUB_OUTPUT"
31 | - name: Generate types from schema
32 | run: |
33 | chmod +x scripts/generate_types.sh
34 | ./scripts/generate_types.sh "${{ steps.vars.outputs.GENERATED_FILE }}"
35 | - name: Install Buf
36 | uses: bufbuild/buf-setup-action@v1
37 | - name: Run buf generate
38 | run: |
39 | set -euo pipefail # Exit on any failing command, use of an unset variable, or failure anywhere in a pipeline
40 | echo "Running buf generate..."
41 | buf generate
42 | uv run scripts/grpc_gen_post_processor.py
43 | echo "Buf generate finished."
44 | - name: Create Pull Request with Updates
45 | uses: peter-evans/create-pull-request@v7
46 | with:
47 | token: ${{ secrets.A2A_BOT_PAT }}
48 | committer: a2a-bot
49 | author: a2a-bot
50 | commit-message: '${{ github.event.client_payload.message }}'
51 | title: '${{ github.event.client_payload.message }}'
52 | body: |
53 | Commit: https://github.com/a2aproject/A2A/commit/${{ github.event.client_payload.sha }}
54 | branch: auto-update-a2a-types-${{ github.event.client_payload.sha }}
55 | base: main
56 | labels: |
57 | automated
58 | dependencies
59 | add-paths: |-
60 | ${{ steps.vars.outputs.GENERATED_FILE }}
61 | src/a2a/grpc/
62 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | __pycache__
3 | .env
4 | .coverage
5 | .mypy_cache
6 | .pytest_cache
7 | .ruff_cache
8 | .venv
9 | test_venv/
10 | coverage.xml
11 | .nox
12 | spec.json
13 |
--------------------------------------------------------------------------------
/.jscpd.json:
--------------------------------------------------------------------------------
1 | {
2 | "ignore": ["**/.github/**", "**/.git/**", "**/tests/**", "**/src/a2a/grpc/**", "**/.nox/**", "**/.venv/**"],
3 | "threshold": 3,
4 | "reporters": ["html", "markdown"]
5 | }
6 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | repos:
3 | # ===============================================
4 | # Pre-commit standard hooks (general file cleanup)
5 | # ===============================================
6 | - repo: https://github.com/pre-commit/pre-commit-hooks
7 | rev: v5.0.0
8 | hooks:
9 | - id: trailing-whitespace # Removes extra whitespace at the end of lines
10 | - id: end-of-file-fixer # Ensures files end with a newline
11 | - id: check-yaml # Checks YAML file syntax (before formatting)
12 | - id: check-toml # Checks TOML file syntax (before formatting)
13 | - id: check-added-large-files # Prevents committing large files
14 | args: [--maxkb=500] # Limit newly added files to 500 KB
15 | - id: check-merge-conflict # Checks for merge conflict strings
16 | - id: detect-private-key # Detects accidental private key commits
17 |
18 | # Formatter and linter for TOML files
19 | - repo: https://github.com/ComPWA/taplo-pre-commit
20 | rev: v0.9.3
21 | hooks:
22 | - id: taplo-format
23 | - id: taplo-lint
24 |
25 | # YAML files
26 | - repo: https://github.com/lyz-code/yamlfix
27 | rev: 1.17.0
28 | hooks:
29 | - id: yamlfix
30 |
31 | # ===============================================
32 | # Python Hooks
33 | # ===============================================
34 | # no_implicit_optional for ensuring explicit Optional types
35 | - repo: https://github.com/hauntsaninja/no_implicit_optional
36 | rev: '1.4'
37 | hooks:
38 | - id: no_implicit_optional
39 | args: [--use-union-or]
40 |
41 | # Pyupgrade for upgrading Python syntax to newer versions
42 | - repo: https://github.com/asottile/pyupgrade
43 | rev: v3.20.0
44 | hooks:
45 | - id: pyupgrade
46 | args: [--py310-plus] # Target Python 3.10+ syntax, matching project's target
47 |
48 | # Autoflake for removing unused imports and variables
49 | - repo: https://github.com/pycqa/autoflake
50 | rev: v2.3.1
51 | hooks:
52 | - id: autoflake
53 | args: [--in-place, --remove-all-unused-imports]
54 |
55 | # Ruff for linting and formatting
56 | - repo: https://github.com/astral-sh/ruff-pre-commit
57 | rev: v0.12.0
58 | hooks:
59 | - id: ruff
60 | args: [--fix, --exit-zero] # Apply fixes; exit 0 even if unfixed lint violations remain
61 | exclude: ^src/a2a/grpc/
62 | - id: ruff-format
63 | exclude: ^src/a2a/grpc/
64 |
65 | # Keep uv.lock in sync
66 | - repo: https://github.com/astral-sh/uv-pre-commit
67 | rev: 0.7.13
68 | hooks:
69 | - id: uv-lock
70 |
71 | # Commitizen for conventional commit messages
72 | - repo: https://github.com/commitizen-tools/commitizen
73 | rev: v4.8.3
74 | hooks:
75 | - id: commitizen
76 | stages: [commit-msg]
77 |
78 | # Gitleaks
79 | - repo: https://github.com/gitleaks/gitleaks
80 | rev: v8.27.2
81 | hooks:
82 | - id: gitleaks
83 |
--------------------------------------------------------------------------------
/.python-version:
--------------------------------------------------------------------------------
1 | 3.10
2 |
--------------------------------------------------------------------------------
/.vscode/extensions.json:
--------------------------------------------------------------------------------
1 | {
2 | "recommendations": [
3 | "charliermarsh.ruff"
4 | ],
5 | "unwantedRecommendations": []
6 | }
7 |
--------------------------------------------------------------------------------
/.vscode/launch.json:
--------------------------------------------------------------------------------
1 | {
2 | "version": "0.2.0",
3 | "configurations": [
4 | {
5 | "name": "Debug HelloWorld Agent",
6 | "type": "debugpy",
7 | "request": "launch",
8 | "program": "${workspaceFolder}/examples/helloworld/__main__.py",
9 | "console": "integratedTerminal",
10 | "justMyCode": false,
11 | "env": {
12 | "PYTHONPATH": "${workspaceFolder}"
13 | },
14 | "cwd": "${workspaceFolder}/examples/helloworld",
15 | "args": [
16 | "--host",
17 | "localhost",
18 | "--port",
19 | "9999"
20 | ]
21 | },
22 | {
23 | "name": "Debug Currency Agent",
24 | "type": "debugpy",
25 | "request": "launch",
26 | "program": "${workspaceFolder}/examples/langgraph/__main__.py",
27 | "console": "integratedTerminal",
28 | "justMyCode": false,
29 | "env": {
30 | "PYTHONPATH": "${workspaceFolder}"
31 | },
32 | "cwd": "${workspaceFolder}/examples/langgraph",
33 | "args": [
34 | "--host",
35 | "localhost",
36 | "--port",
37 | "10000"
38 | ]
39 | },
40 | {
41 | "name": "Pytest All",
42 | "type": "debugpy",
43 | "request": "launch",
44 | "module": "pytest",
45 | "args": [
46 | "-v",
47 | "-s"
48 | ],
49 | "console": "integratedTerminal",
50 | "justMyCode": true,
51 | "python": "${workspaceFolder}/.venv/bin/python"
52 | }
53 | ]
54 | }
55 |
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "python.testing.pytestArgs": [
3 | "tests"
4 | ],
5 | "python.testing.unittestEnabled": false,
6 | "python.testing.pytestEnabled": true,
7 | "editor.formatOnSave": true,
8 | "[python]": {
9 | "editor.defaultFormatter": "charliermarsh.ruff",
10 | "editor.formatOnSave": true,
11 | "editor.codeActionsOnSave": {
12 | "source.organizeImports": "always",
13 | "source.fixAll.ruff": "explicit"
14 | }
15 | },
16 | "ruff.importStrategy": "fromEnvironment",
17 | "files.insertFinalNewline": true,
18 | "files.trimFinalNewlines": false,
19 | "files.trimTrailingWhitespace": false,
20 | "editor.rulers": [
21 | 80
22 | ]
23 | }
24 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | In the interest of fostering an open and welcoming environment, we as
6 | contributors and maintainers pledge to make participation in our project and
7 | our community a harassment-free experience for everyone, regardless of age, body
8 | size, disability, ethnicity, gender identity and expression, level of
9 | experience, education, socio-economic status, nationality, personal appearance,
10 | race, religion, or sexual identity and orientation.
11 |
12 | ## Our Standards
13 |
14 | Examples of behavior that contributes to creating a positive environment
15 | include:
16 |
17 | - Using welcoming and inclusive language
18 | - Being respectful of differing viewpoints and experiences
19 | - Gracefully accepting constructive criticism
20 | - Focusing on what is best for the community
21 | - Showing empathy towards other community members
22 |
23 | Examples of unacceptable behavior by participants include:
24 |
25 | - The use of sexualized language or imagery and unwelcome sexual attention or
26 | advances
27 | - Trolling, insulting/derogatory comments, and personal or political attacks
28 | - Public or private harassment
29 | - Publishing others' private information, such as a physical or electronic
30 | address, without explicit permission
31 | - Other conduct which could reasonably be considered inappropriate in a
32 | professional setting
33 |
34 | ## Our Responsibilities
35 |
36 | Project maintainers are responsible for clarifying the standards of acceptable
37 | behavior and are expected to take appropriate and fair corrective action in
38 | response to any instances of unacceptable behavior.
39 |
40 | Project maintainers have the right and responsibility to remove, edit, or reject
41 | comments, commits, code, wiki edits, issues, and other contributions that are
42 | not aligned to this Code of Conduct, or to ban temporarily or permanently any
43 | contributor for other behaviors that they deem inappropriate, threatening,
44 | offensive, or harmful.
45 |
46 | ## Scope
47 |
48 | This Code of Conduct applies both within project spaces and in public spaces
49 | when an individual is representing the project or its community. Examples of
50 | representing a project or community include using an official project email
51 | address, posting via an official social media account, or acting as an appointed
52 | representative at an online or offline event. Representation of a project may be
53 | further defined and clarified by project maintainers.
54 |
55 | This Code of Conduct also applies outside the project spaces when the Project
56 | Steward has a reasonable belief that an individual's behavior may have a
57 | negative impact on the project or its community.
58 |
59 | ## Conflict Resolution
60 |
61 | We do not believe that all conflict is bad; healthy debate and disagreement
62 | often yield positive results. However, it is never okay to be disrespectful or
63 | to engage in behavior that violates the project's code of conduct.
64 |
65 | If you see someone violating the code of conduct, you are encouraged to address
66 | the behavior directly with those involved. Many issues can be resolved quickly
67 | and easily, and this gives people more control over the outcome of their
68 | dispute. If you are unable to resolve the matter for any reason, or if the
69 | behavior is threatening or harassing, report it. We are dedicated to providing
70 | an environment where participants feel welcome and safe.
71 |
72 | Reports should be directed to _[PROJECT STEWARD NAME(s) AND EMAIL(s)]_, the
73 | Project Steward(s) for _[PROJECT NAME]_. It is the Project Steward's duty to
74 | receive and address reported violations of the code of conduct. They will then
75 | work with a committee consisting of representatives from the Open Source
76 | Programs Office and the Google Open Source Strategy team. If for any reason you
77 | are uncomfortable reaching out to the Project Steward, please email
78 | opensource@google.com.
79 |
80 | We will investigate every complaint, but you may not receive a direct response.
81 | We will use our discretion in determining when and how to follow up on reported
82 | incidents, which may range from not taking action to permanent expulsion from
83 | the project and project-sponsored spaces. We will notify the accused of the
84 | report and provide them an opportunity to discuss it before any action is taken.
85 | The identity of the reporter will be omitted from the details of the report
86 | supplied to the accused. In potentially harmful situations, such as ongoing
87 | harassment or threats to anyone's safety, we may take action without notice.
88 |
89 | ## Attribution
90 |
91 | This Code of Conduct is adapted from the Contributor Covenant, version 1.4,
92 | available at
93 | https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
94 |
95 | Note: A version of this file is also available in the
96 | [New Project repository](https://github.com/google/new-project/blob/master/docs/code-of-conduct.md).
97 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to contribute
2 |
3 | We'd love to accept your patches and contributions to this project.
4 |
5 | ## Contribution process
6 |
7 | ### Code reviews
8 |
9 | All submissions, including submissions by project members, require review. We
10 | use GitHub pull requests for this purpose. Consult
11 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
12 | information on using pull requests.
13 |
14 | ### Contributor Guide
15 |
16 | You may follow these steps to contribute:
17 |
18 | 1. **Fork the official repository.** This will create a copy of the official repository in your own account.
19 | 2. **Sync the branches.** This will ensure that your copy of the repository is up-to-date with the latest changes from the official repository.
20 | 3. **Work on your forked repository's feature branch.** This is where you will make your changes to the code.
21 | 4. **Commit your updates on your forked repository's feature branch.** This will save your changes to your copy of the repository.
22 | 5. **Submit a pull request to the official repository's main branch.** This will request that your changes be merged into the official repository.
23 | 6. **Resolve any linting errors.** This will ensure that your changes are formatted correctly.
24 |
25 | Here are some additional things to keep in mind during the process:
26 |
27 | - **Test your changes.** Before you submit a pull request, make sure that your changes work as expected.
28 | - **Be patient.** It may take some time for your pull request to be reviewed and merged.
29 |
--------------------------------------------------------------------------------
/Gemini.md:
--------------------------------------------------------------------------------
1 | **A2A specification:** https://a2a-protocol.org/latest/specification/
2 |
3 | ## Project frameworks
4 | - uv as package manager
5 |
6 | ## How to run all tests
7 | 1. If dependencies are not installed, install them with the following command:
8 | ```
9 | uv sync --all-extras
10 | ```
11 |
12 | 2. Run tests
13 | ```
14 | uv run pytest
15 | ```
16 |
17 | ## Other instructions
18 | 1. Whenever writing Python code, include type annotations as well.
19 | 2. After making changes, run ruff to check for and automatically fix lint issues:
20 | ```
21 | uv run ruff check --fix
22 | ```
23 | 3. Run the mypy type checker to check for type errors:
24 | ```
25 | uv run mypy
26 | ```
27 | 4. Run the unit tests to make sure that none of the unit tests are broken.
28 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # A2A Python SDK
2 |
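3 | [![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](LICENSE)
4 | [![PyPI version](https://img.shields.io/pypi/v/a2a-sdk)](https://pypi.org/project/a2a-sdk/)
5 | ![PyPI - Python Version](https://img.shields.io/pypi/pyversions/a2a-sdk)
6 | [![PyPI - Downloads](https://img.shields.io/pypi/dw/a2a-sdk)](https://pypistats.org/packages/a2a-sdk)
7 | [![Python Unit Tests](https://github.com/a2aproject/a2a-python/actions/workflows/unit-tests.yml/badge.svg)](https://github.com/a2aproject/a2a-python/actions/workflows/unit-tests.yml)
8 | [![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/a2aproject/a2a-python)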
9 |
10 |
11 |
12 |
13 |

14 |
15 | A Python library for running agentic applications as A2A Servers, following the Agent2Agent (A2A) Protocol.
16 |
17 |
18 |
19 |
20 |
21 | ---
22 |
23 | ## ✨ Features
24 |
25 | - **A2A Protocol Compliant:** Build agentic applications that adhere to the Agent2Agent (A2A) Protocol.
26 | - **Extensible:** Easily add support for different communication protocols and database backends.
27 | - **Asynchronous:** Built on modern async Python for high performance.
28 | - **Optional Integrations:** Includes optional support for:
29 | - HTTP servers ([FastAPI](https://fastapi.tiangolo.com/), [Starlette](https://www.starlette.io/))
30 | - [gRPC](https://grpc.io/)
31 | - [OpenTelemetry](https://opentelemetry.io/) for tracing
32 | - SQL databases ([PostgreSQL](https://www.postgresql.org/), [MySQL](https://www.mysql.com/), [SQLite](https://sqlite.org/))
33 |
34 | ---
35 |
36 | ## 🚀 Getting Started
37 |
38 | ### Prerequisites
39 |
40 | - Python 3.10+
41 | - `uv` (recommended) or `pip`
42 |
43 | ### 🔧 Installation
44 |
45 | Install the core SDK and any desired extras using your preferred package manager.
46 |
47 | | Feature | `uv` Command | `pip` Command |
48 | | ------------------------ | ------------------------------------------ | -------------------------------------------- |
49 | | **Core SDK** | `uv add a2a-sdk` | `pip install a2a-sdk` |
50 | | **All Extras** | `uv add "a2a-sdk[all]"` | `pip install "a2a-sdk[all]"` |
51 | | **HTTP Server** | `uv add "a2a-sdk[http-server]"` | `pip install "a2a-sdk[http-server]"` |
52 | | **gRPC Support** | `uv add "a2a-sdk[grpc]"` | `pip install "a2a-sdk[grpc]"` |
53 | | **OpenTelemetry Tracing**| `uv add "a2a-sdk[telemetry]"` | `pip install "a2a-sdk[telemetry]"` |
54 | | **Encryption** | `uv add "a2a-sdk[encryption]"` | `pip install "a2a-sdk[encryption]"` |
55 | | | | |
56 | | **Database Drivers** | | |
57 | | **PostgreSQL** | `uv add "a2a-sdk[postgresql]"` | `pip install "a2a-sdk[postgresql]"` |
58 | | **MySQL** | `uv add "a2a-sdk[mysql]"` | `pip install "a2a-sdk[mysql]"` |
59 | | **SQLite** | `uv add "a2a-sdk[sqlite]"` | `pip install "a2a-sdk[sqlite]"` |
60 | | **All SQL Drivers** | `uv add "a2a-sdk[sql]"` | `pip install "a2a-sdk[sql]"` |
61 |
62 | ## Examples
63 |
64 | ### [Helloworld Example](https://github.com/a2aproject/a2a-samples/tree/main/samples/python/agents/helloworld)
65 |
66 | 1. Run Remote Agent
67 |
68 | ```bash
69 | git clone https://github.com/a2aproject/a2a-samples.git
70 | cd a2a-samples/samples/python/agents/helloworld
71 | uv run .
72 | ```
73 |
74 | 2. In another terminal, run the client
75 |
76 | ```bash
77 | cd a2a-samples/samples/python/agents/helloworld
78 | uv run test_client.py
79 | ```
80 |
81 | 3. You can validate your agent using the agent inspector. Follow the instructions at the [a2a-inspector](https://github.com/a2aproject/a2a-inspector) repo.
82 |
83 | ---
84 |
85 | ## 🌐 More Examples
86 |
87 | You can find a variety of more detailed examples in the [a2a-samples](https://github.com/a2aproject/a2a-samples) repository:
88 |
89 | - **[Python Examples](https://github.com/a2aproject/a2a-samples/tree/main/samples/python)**
90 | - **[JavaScript Examples](https://github.com/a2aproject/a2a-samples/tree/main/samples/js)**
91 |
92 | ---
93 |
94 | ## 🤝 Contributing
95 |
96 | Contributions are welcome! Please see the [CONTRIBUTING.md](CONTRIBUTING.md) file for guidelines on how to get involved.
97 |
98 | ---
99 |
100 | ## 📄 License
101 |
102 | This project is licensed under the Apache 2.0 License. See the [LICENSE](LICENSE) file for more details.
103 |
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 | # Security Policy
2 |
3 | To report a security issue, please use [g.co/vulnz](https://g.co/vulnz).
4 |
5 | The Google Security Team will respond within 5 working days of your report on g.co/vulnz.
6 |
7 | We use g.co/vulnz for our intake, and do coordination and disclosure here using GitHub Security Advisory to privately discuss and fix the issue.
8 |
--------------------------------------------------------------------------------
/buf.gen.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | version: v2
3 | inputs:
4 | - git_repo: https://github.com/a2aproject/A2A.git
5 | ref: main
6 | subdir: specification/grpc
7 | managed:
8 | enabled: true
9 | # Python Generation
10 | # Using remote plugins. To use local plugins replace remote with local
11 | # pip install protobuf grpcio-tools
12 | # Optionally, install plugin to generate stubs for grpc services
13 | # pip install mypy-protobuf
14 | # Generate python protobuf code
15 | # - local: protoc-gen-python
16 | # - out: src/python
17 | # Generate gRPC stubs
18 | # - local: protoc-gen-grpc-python
19 | # - out: src/python
20 | plugins:
21 | # Generate python protobuf related code
22 | # Generates *_pb2.py files, one for each .proto
23 | - remote: buf.build/protocolbuffers/python:v29.3
24 | out: src/a2a/grpc
25 | # Generate python service code.
26 | # Generates *_pb2_grpc.py
27 | - remote: buf.build/grpc/python
28 | out: src/a2a/grpc
29 | # Generates *_pb2.pyi files.
30 | - remote: buf.build/protocolbuffers/pyi
31 | out: src/a2a/grpc
32 |
--------------------------------------------------------------------------------
/scripts/format.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Format the repository's Python files with autoflake and ruff.
# Usage: format.sh [--all] [--unsafe-fixes]
#   --all           format every tracked Python file instead of only the files
#                   changed relative to the base branch
#   --unsafe-fixes  forward --unsafe-fixes to `ruff check`

# Abort on the first failing command, including failures inside pipelines.
set -e
set -o pipefail

# --- Argument Parsing ---
# Initialize flags
FORMAT_ALL=false
# Left empty unless --unsafe-fixes is given; expanded unquoted later so an
# empty value contributes no argument.
RUFF_UNSAFE_FIXES_FLAG=""

# Process command-line arguments
while [[ "$#" -gt 0 ]]; do
    case "$1" in
        --all)
            FORMAT_ALL=true
            echo "Detected --all flag: Formatting all tracked Python files."
            shift # Consume the argument
            ;;
        --unsafe-fixes)
            RUFF_UNSAFE_FIXES_FLAG="--unsafe-fixes"
            echo "Detected --unsafe-fixes flag: Ruff will run with unsafe fixes."
            shift # Consume the argument
            ;;
        *)
            # Unknown arguments are reported and skipped, never fatal.
            echo "Warning: Unknown argument '$1'. Ignoring."
            shift # Consume the argument
            ;;
    esac
done

# Sort Spelling Allowlist
# `sort -o` may name its own input file: sort reads all input before writing.
SPELLING_ALLOW_FILE=".github/actions/spelling/allow.txt"
if [ -f "$SPELLING_ALLOW_FILE" ]; then
    echo "Sorting and de-duplicating $SPELLING_ALLOW_FILE"
    sort -u "$SPELLING_ALLOW_FILE" -o "$SPELLING_ALLOW_FILE"
fi

# Newline-separated list of Python paths to format.
CHANGED_FILES=""

# FORMAT_ALL holds `true` or `false`, so this runs the matching builtin.
if $FORMAT_ALL; then
    echo "Finding all tracked Python files in the repository..."
    # The ':!...' pathspec excludes the generated gRPC code.
    CHANGED_FILES=$(git ls-files -- '*.py' ':!src/a2a/grpc/*')
else
    echo "Finding changed Python files based on git diff..."
    # GITHUB_BASE_REF is used when set (presumably by CI on pull requests);
    # otherwise the comparison is against main.
    TARGET_BRANCH="origin/${GITHUB_BASE_REF:-main}"
    # NOTE(review): with --depth=1 and a shallow local clone, merge-base may
    # find no common ancestor, and `set -e` would then abort the script.
    # Confirm the fetch depth used in CI.
    git fetch origin "${GITHUB_BASE_REF:-main}" --depth=1

    MERGE_BASE=$(git merge-base HEAD "$TARGET_BRANCH")

    # Get python files changed in this PR, excluding grpc generated files.
    # 'D' is absent from --diff-filter, so deleted files are never passed on.
    CHANGED_FILES=$(git diff --name-only --diff-filter=ACMRTUXB "$MERGE_BASE" HEAD -- '*.py' ':!src/a2a/grpc/*')
fi

# Exit if no files were found
if [ -z "$CHANGED_FILES" ]; then
    echo "No changed or tracked Python files to format."
    exit 0
fi
59 |
# --- Helper Function ---
# Runs a command (through xargs) on a newline-separated list of files.
# $1: The file list, one path per line (as printed by `git ls-files` or
#     `git diff --name-only`).
# $2...: The command and its arguments; the file paths are appended to them.
# The list is converted to NUL-delimited form before it reaches xargs, so paths
# containing spaces or quote characters are passed intact instead of being
# split or mangled by xargs' default whitespace/quote parsing.
run_formatter() {
    local files_to_format="$1"
    shift # Remove the file list from the arguments
    if [ -n "$files_to_format" ]; then
        # -0: split only on NUL; -r: do not run the command on empty input.
        printf '%s\n' "$files_to_format" | tr '\n' '\0' | xargs -0 -r "$@"
    fi
}
71 |
# --- Python File Formatting ---
# Run the formatters in a fixed order: remove unused imports first, then apply
# ruff's lint fixes, then ruff's formatter over the result.
# The earlier empty-list check exits first, so the else branch is a fallback.
if [ -n "$CHANGED_FILES" ]; then
    echo "--- Formatting Python Files ---"
    echo "Files to be formatted:"
    echo "$CHANGED_FILES"

    echo "Running autoflake..."
    run_formatter "$CHANGED_FILES" autoflake -i -r --remove-all-unused-imports
    echo "Running ruff check (fix-only)..."
    # Deliberately unquoted: an empty flag must expand to no argument at all.
    run_formatter "$CHANGED_FILES" ruff check --fix-only $RUFF_UNSAFE_FIXES_FLAG
    echo "Running ruff format..."
    run_formatter "$CHANGED_FILES" ruff format
    echo "Python formatting complete."
else
    echo "No Python files to format."
fi

echo "All formatting tasks are complete."
90 |
--------------------------------------------------------------------------------
/scripts/generate_types.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Regenerate the Pydantic models (a2a.types) from the upstream A2A JSON schema.
# Usage: generate_types.sh <output-file>

# Exit immediately if a command exits with a non-zero status.
# Treat unset variables as an error, and fail a pipeline if any stage fails.
set -euo pipefail

# Check if an output file path was provided as an argument.
# "${1:-}" is required: under `set -u`, a bare "$1" with no arguments aborts
# with "unbound variable" before this friendly message could be printed.
if [ -z "${1:-}" ]; then
    echo "Error: Output file path must be provided as the first argument." >&2
    exit 1
fi

REMOTE_URL="https://raw.githubusercontent.com/a2aproject/A2A/refs/heads/main/specification/json/a2a.json"
GENERATED_FILE="$1"

echo "Running datamodel-codegen..."
echo " - Source URL: $REMOTE_URL"
echo " - Output File: $GENERATED_FILE"

# Models derive from a2a._base.A2ABaseModel, use snake_case fields, and emit
# no explicit aliases (--no-alias).
uv run datamodel-codegen \
    --url "$REMOTE_URL" \
    --input-file-type jsonschema \
    --output "$GENERATED_FILE" \
    --target-python-version 3.10 \
    --output-model-type pydantic_v2.BaseModel \
    --disable-timestamp \
    --use-schema-description \
    --use-union-operator \
    --use-field-description \
    --use-default \
    --use-default-kwarg \
    --use-one-literal-as-default \
    --class-name A2A \
    --use-standard-collections \
    --use-subclass-enum \
    --base-class a2a._base.A2ABaseModel \
    --field-constraints \
    --snake-case-field \
    --no-alias

echo "Formatting generated file with ruff..."
uv run ruff format "$GENERATED_FILE"

echo "Codegen finished successfully."
45 |
--------------------------------------------------------------------------------
/scripts/grpc_gen_post_processor.py:
--------------------------------------------------------------------------------
1 | """Rewrite absolute imports in generated *_pb2_grpc.py files as relative ones.
2 |
3 | Example (the first line is rewritten to the second):
4 | import a2a_pb2 as a2a__pb2
5 | from . import a2a_pb2 as a2a__pb2
6 | """
7 |
8 | import re
9 | import sys
10 |
11 | from pathlib import Path
12 |
13 |
def process_generated_code(src_folder: str = 'src/a2a/grpc') -> None:
    """Rewrite absolute ``*_pb2`` imports in generated gRPC modules.

    Generated ``*_pb2_grpc.py`` files import their sibling message module
    absolutely (``import a2a_pb2 as a2a__pb2``). Each such line is rewritten
    to a package-relative import (``from . import a2a_pb2 as a2a__pb2``).
    Files are visited in sorted order. A file is rewritten only when its
    content actually changes.

    Args:
        src_folder: Directory searched recursively for ``*_pb2_grpc.py``.

    Raises:
        SystemExit: With status 1 when ``src_folder`` is not a directory, or
            when a file cannot be read, decoded, or written.
    """
    dir_path = Path(src_folder)
    print(dir_path)
    if not dir_path.is_dir():
        print('Source folder not found', file=sys.stderr)
        sys.exit(1)

    # Matches a whole line such as ``import a2a_pb2 as a2a__pb2``. Lines that
    # are already relative (``from . import ...``) do not match.
    import_pattern = re.compile(
        r'^import (\w+_pb2) as (\w+__pb2)$', flags=re.MULTILINE
    )
    replacement_pattern = r'from . import \1 as \2'

    # Sorted for deterministic processing and log order.
    for file in sorted(dir_path.glob('**/*_pb2_grpc.py')):
        print(f'Processing {file}')
        try:
            src_content = file.read_text(encoding='utf-8')
            fixed_src_content = import_pattern.sub(
                replacement_pattern, src_content
            )
            if fixed_src_content != src_content:
                file.write_text(fixed_src_content, encoding='utf-8')
                print('Imports fixed')
            else:
                print('No changes needed')
        except (OSError, UnicodeError) as e:
            # Only I/O and decoding failures are expected here. Any other
            # exception is a bug and should surface with its traceback.
            print(f'Error processing file {file}: {e}', file=sys.stderr)
            sys.exit(1)


if __name__ == '__main__':
    process_generated_code()
57 |
--------------------------------------------------------------------------------
/src/a2a/__init__.py:
--------------------------------------------------------------------------------
1 | """The A2A Python SDK."""
2 |
--------------------------------------------------------------------------------
/src/a2a/_base.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel, ConfigDict
2 | from pydantic.alias_generators import to_camel
3 |
4 |
def to_camel_custom(snake: str) -> str:
    """Turn a snake_case identifier into its camelCase alias.

    Trailing underscores are dropped before conversion. This means names
    suffixed to avoid a Python keyword (such as ``in_`` or ``from_``) map to
    the plain alias (``in``, ``from``).

    Args:
        snake: The snake_case identifier.

    Returns:
        The camelCase string produced by pydantic's ``to_camel``.
    """
    # rstrip is a no-op when there is no trailing underscore, so no separate
    # endswith() check is needed.
    stripped = snake.rstrip('_')
    return to_camel(stripped)
19 |
20 |
class A2ABaseModel(BaseModel):
    """Base class for shared behavior across A2A data models.

    All subclasses share one Pydantic configuration:

    * Field aliases are generated by ``to_camel_custom`` (snake_case to
      camelCase, with trailing underscores dropped).
    * Validation accepts either the Python field name or its alias.
    * Serialization uses the aliases by default.
    """

    model_config = ConfigDict(
        # validate_by_name / validate_by_alias are the finer-grained successors
        # of the older populate_by_name option. SEE:
        # https://docs.pydantic.dev/latest/api/config/#pydantic.config.ConfigDict.validate_by_name
        validate_by_name=True,
        validate_by_alias=True,
        serialize_by_alias=True,
        alias_generator=to_camel_custom,
    )
39 |
--------------------------------------------------------------------------------
/src/a2a/auth/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/a2aproject/a2a-python/dca66c3100a2b9701a1c8b65ad6853769eefd511/src/a2a/auth/__init__.py
--------------------------------------------------------------------------------
/src/a2a/auth/user.py:
--------------------------------------------------------------------------------
1 | """Authenticated user information."""
2 |
3 | from abc import ABC, abstractmethod
4 |
5 |
class User(ABC):
    """Abstract interface for the user associated with a request."""

    @property
    @abstractmethod
    def is_authenticated(self) -> bool:
        """Whether this user has been authenticated."""

    @property
    @abstractmethod
    def user_name(self) -> str:
        """The name identifying this user."""
18 |
19 |
class UnauthenticatedUser(User):
    """Stand-in ``User`` for requests where nobody has been authenticated."""

    @property
    def is_authenticated(self) -> bool:
        """Always ``False``, because no user was authenticated."""
        return False

    @property
    def user_name(self) -> str:
        """Always the empty string, because there is no named user."""
        return ''
32 |
--------------------------------------------------------------------------------
/src/a2a/client/__init__.py:
--------------------------------------------------------------------------------
1 | """Client-side components for interacting with an A2A agent."""
2 |
3 | import logging
4 |
5 | from a2a.client.auth import (
6 | AuthInterceptor,
7 | CredentialService,
8 | InMemoryContextCredentialStore,
9 | )
10 | from a2a.client.card_resolver import A2ACardResolver
11 | from a2a.client.client import Client, ClientConfig, ClientEvent, Consumer
12 | from a2a.client.client_factory import ClientFactory, minimal_agent_card
13 | from a2a.client.errors import (
14 | A2AClientError,
15 | A2AClientHTTPError,
16 | A2AClientJSONError,
17 | A2AClientTimeoutError,
18 | )
19 | from a2a.client.helpers import create_text_message_object
20 | from a2a.client.legacy import A2AClient
21 | from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
22 |
23 |
logger = logging.getLogger(__name__)

# The gRPC client is optional: fall back to a placeholder that explains how to
# install the missing extra when it is used.
try:
    from a2a.client.legacy_grpc import A2AGrpcClient  # type: ignore
except ImportError as e:
    # ``e`` is unbound once the except clause ends, so keep a reference to it
    # for chaining from the placeholder's constructor.
    _original_error = e
    logger.debug(
        'A2AGrpcClient not loaded. This is expected if gRPC dependencies are not installed. Error: %s',
        _original_error,
    )

    class A2AGrpcClient:  # type: ignore
        """Placeholder for A2AGrpcClient when dependencies are not installed.

        Any attempt to construct it raises ``ImportError``, chained from the
        original import failure.
        """

        def __init__(self, *args: object, **kwargs: object) -> None:
            raise ImportError(
                'To use A2AGrpcClient, its dependencies must be installed. '
                'You can install them with \'pip install "a2a-sdk[grpc]"\''
            ) from _original_error
43 |
44 |
# Public API of ``a2a.client``. ``A2AGrpcClient`` is exported even without the
# gRPC extra; in that case it is the placeholder class defined above.
__all__ = [
    'A2ACardResolver',
    'A2AClient',
    'A2AClientError',
    'A2AClientHTTPError',
    'A2AClientJSONError',
    'A2AClientTimeoutError',
    'A2AGrpcClient',
    'AuthInterceptor',
    'Client',
    'ClientCallContext',
    'ClientCallInterceptor',
    'ClientConfig',
    'ClientEvent',
    'ClientFactory',
    'Consumer',
    'CredentialService',
    'InMemoryContextCredentialStore',
    'create_text_message_object',
    'minimal_agent_card',
]
66 |
--------------------------------------------------------------------------------
/src/a2a/client/auth/__init__.py:
--------------------------------------------------------------------------------
1 | """Client-side authentication components for the A2A Python SDK."""
2 |
3 | from a2a.client.auth.credentials import (
4 | CredentialService,
5 | InMemoryContextCredentialStore,
6 | )
7 | from a2a.client.auth.interceptor import AuthInterceptor
8 |
9 |
# Names re-exported by ``a2a.client.auth``.
__all__ = [
    'AuthInterceptor',
    'CredentialService',
    'InMemoryContextCredentialStore',
]
15 |
--------------------------------------------------------------------------------
/src/a2a/client/auth/credentials.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | from a2a.client.middleware import ClientCallContext
4 |
5 |
class CredentialService(ABC):
    """Abstract source of credentials for named security schemes."""

    @abstractmethod
    async def get_credentials(
        self,
        security_scheme_name: str,
        context: ClientCallContext | None,
    ) -> str | None:
        """Look up the credential (for example, a token) for a security scheme.

        Args:
            security_scheme_name: Name of the security scheme that needs a
                credential.
            context: Call context of the current request, if there is one.

        Returns:
            The credential, or ``None`` if the service has none to offer.
        """
18 |
19 |
20 | class InMemoryContextCredentialStore(CredentialService):
21 | """A simple in-memory store for session-keyed credentials.
22 |
23 | This class uses the 'sessionId' from the ClientCallContext state to
24 | store and retrieve credentials...
25 | """
26 |
27 | def __init__(self) -> None:
28 | self._store: dict[str, dict[str, str]] = {}
29 |
30 | async def get_credentials(
31 | self,
32 | security_scheme_name: str,
33 | context: ClientCallContext | None,
34 | ) -> str | None:
35 | """Retrieves credentials from the in-memory store.
36 |
37 | Args:
38 | security_scheme_name: The name of the security scheme.
39 | context: The client call context.
40 |
41 | Returns:
42 | The credential string, or None if not found.
43 | """
44 | if not context or 'sessionId' not in context.state:
45 | return None
46 | session_id = context.state['sessionId']
47 | return self._store.get(session_id, {}).get(security_scheme_name)
48 |
49 | async def set_credentials(
50 | self, session_id: str, security_scheme_name: str, credential: str
51 | ) -> None:
52 | """Method to populate the store."""
53 | if session_id not in self._store:
54 | self._store[session_id] = {}
55 | self._store[session_id][security_scheme_name] = credential
56 |
--------------------------------------------------------------------------------
/src/a2a/client/auth/interceptor.py:
--------------------------------------------------------------------------------
1 | import logging # noqa: I001
2 | from typing import Any
3 |
4 | from a2a.client.auth.credentials import CredentialService
5 | from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
6 | from a2a.types import (
7 | AgentCard,
8 | APIKeySecurityScheme,
9 | HTTPAuthSecurityScheme,
10 | In,
11 | OAuth2SecurityScheme,
12 | OpenIdConnectSecurityScheme,
13 | )
14 |
15 | logger = logging.getLogger(__name__)
16 |
17 |
18 | class AuthInterceptor(ClientCallInterceptor):
19 | """An interceptor that automatically adds authentication details to requests.
20 |
21 | Based on the agent's security schemes.
22 | """
23 |
24 | def __init__(self, credential_service: CredentialService):
25 | self._credential_service = credential_service
26 |
27 | async def intercept(
28 | self,
29 | method_name: str,
30 | request_payload: dict[str, Any],
31 | http_kwargs: dict[str, Any],
32 | agent_card: AgentCard | None,
33 | context: ClientCallContext | None,
34 | ) -> tuple[dict[str, Any], dict[str, Any]]:
35 | """Applies authentication headers to the request if credentials are available."""
36 | if (
37 | agent_card is None
38 | or agent_card.security is None
39 | or agent_card.security_schemes is None
40 | ):
41 | return request_payload, http_kwargs
42 |
43 | for requirement in agent_card.security:
44 | for scheme_name in requirement:
45 | credential = await self._credential_service.get_credentials(
46 | scheme_name, context
47 | )
48 | if credential and scheme_name in agent_card.security_schemes:
49 | scheme_def_union = agent_card.security_schemes.get(
50 | scheme_name
51 | )
52 | if not scheme_def_union:
53 | continue
54 | scheme_def = scheme_def_union.root
55 |
56 | headers = http_kwargs.get('headers', {})
57 |
58 | match scheme_def:
59 | # Case 1a: HTTP Bearer scheme with an if guard
60 | case HTTPAuthSecurityScheme() if (
61 | scheme_def.scheme.lower() == 'bearer'
62 | ):
63 | headers['Authorization'] = f'Bearer {credential}'
64 | logger.debug(
65 | "Added Bearer token for scheme '%s' (type: %s).",
66 | scheme_name,
67 | scheme_def.type,
68 | )
69 | http_kwargs['headers'] = headers
70 | return request_payload, http_kwargs
71 |
72 | # Case 1b: OAuth2 and OIDC schemes, which are implicitly Bearer
73 | case (
74 | OAuth2SecurityScheme()
75 | | OpenIdConnectSecurityScheme()
76 | ):
77 | headers['Authorization'] = f'Bearer {credential}'
78 | logger.debug(
79 | "Added Bearer token for scheme '%s' (type: %s).",
80 | scheme_name,
81 | scheme_def.type,
82 | )
83 | http_kwargs['headers'] = headers
84 | return request_payload, http_kwargs
85 |
86 | # Case 2: API Key in Header
87 | case APIKeySecurityScheme(in_=In.header):
88 | headers[scheme_def.name] = credential
89 | logger.debug(
90 | "Added API Key Header for scheme '%s'.",
91 | scheme_name,
92 | )
93 | http_kwargs['headers'] = headers
94 | return request_payload, http_kwargs
95 |
96 | # Note: Other cases like API keys in query/cookie are not handled and will be skipped.
97 |
98 | return request_payload, http_kwargs
99 |
--------------------------------------------------------------------------------
/src/a2a/client/card_resolver.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 |
4 | from typing import Any
5 |
6 | import httpx
7 |
8 | from pydantic import ValidationError
9 |
10 | from a2a.client.errors import (
11 | A2AClientHTTPError,
12 | A2AClientJSONError,
13 | )
14 | from a2a.types import (
15 | AgentCard,
16 | )
17 | from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH
18 |
19 |
20 | logger = logging.getLogger(__name__)
21 |
22 |
class A2ACardResolver:
    """Resolves an agent's `AgentCard` by fetching it over HTTP."""

    def __init__(
        self,
        httpx_client: httpx.AsyncClient,
        base_url: str,
        agent_card_path: str = AGENT_CARD_WELL_KNOWN_PATH,
    ) -> None:
        """Bind the resolver to an HTTP client and an agent host.

        Args:
            httpx_client: Async HTTP client used to issue the requests.
            base_url: Base URL of the agent's host; trailing slashes are
                removed.
            agent_card_path: Default card location relative to ``base_url``;
                leading slashes are removed.
        """
        self.httpx_client = httpx_client
        self.base_url = base_url.rstrip('/')
        self.agent_card_path = agent_card_path.lstrip('/')

    async def get_agent_card(
        self,
        relative_card_path: str | None = None,
        http_kwargs: dict[str, Any] | None = None,
    ) -> AgentCard:
        """Download and validate an agent card.

        The request goes to ``{base_url}/{path}``, where ``path`` is
        ``relative_card_path`` with leading slashes removed or, when that is
        None, the default card path given at construction time.

        Args:
            relative_card_path: Card path relative to the base URL; None
                selects the default (public) agent card path.
            http_kwargs: Extra keyword arguments forwarded to
                ``httpx_client.get``.

        Returns:
            The validated `AgentCard`.

        Raises:
            A2AClientHTTPError: On an HTTP error status, or with status 503
                when the request fails at the network level.
            A2AClientJSONError: When the body cannot be decoded as JSON or
                does not match the AgentCard schema.
        """
        card_path = (
            self.agent_card_path
            if relative_card_path is None
            else relative_card_path.lstrip('/')
        )
        target_url = f'{self.base_url}/{card_path}'

        try:
            response = await self.httpx_client.get(
                target_url, **(http_kwargs or {})
            )
            response.raise_for_status()
            payload = response.json()
            logger.info(
                'Successfully fetched agent card data from %s: %s',
                target_url,
                payload,
            )
            return AgentCard.model_validate(payload)
        except httpx.HTTPStatusError as exc:
            raise A2AClientHTTPError(
                exc.response.status_code,
                f'Failed to fetch agent card from {target_url}: {exc}',
            ) from exc
        except json.JSONDecodeError as exc:
            raise A2AClientJSONError(
                f'Failed to parse JSON for agent card from {target_url}: {exc}'
            ) from exc
        except httpx.RequestError as exc:
            raise A2AClientHTTPError(
                503,
                f'Network communication error fetching agent card from {target_url}: {exc}',
            ) from exc
        except ValidationError as exc:  # raised by AgentCard.model_validate
            raise A2AClientJSONError(
                f'Failed to validate agent card structure from {target_url}: {exc.json()}'
            ) from exc
109 |
--------------------------------------------------------------------------------
/src/a2a/client/errors.py:
--------------------------------------------------------------------------------
1 | """Custom exceptions for the A2A client."""
2 |
3 | from a2a.types import JSONRPCErrorResponse
4 |
5 |
class A2AClientError(Exception):
    """Common base class for the exceptions raised by the A2A client.

    Catching this type catches every client exception class defined in
    this module.
    """
8 |
9 |
class A2AClientHTTPError(A2AClientError):
    """Raised for HTTP-level failures when talking to the server.

    ``status_code`` carries the HTTP status; callers may also synthesize
    one for transport failures (the card resolver uses 503 for network
    errors).
    """

    def __init__(self, status_code: int, message: str):
        """Build the exception text and keep its parts as attributes.

        Args:
            status_code: HTTP status code associated with the failure.
            message: Human-readable description of the failure.
        """
        super().__init__(f'HTTP Error {status_code}: {message}')
        self.status_code = status_code
        self.message = message
23 |
24 |
class A2AClientJSONError(A2AClientError):
    """Raised when a response body cannot be parsed as JSON or fails validation."""

    def __init__(self, message: str):
        """Build the exception text and keep ``message`` as an attribute.

        Args:
            message: Human-readable description of what went wrong.
        """
        super().__init__(f'JSON Error: {message}')
        self.message = message
36 |
37 |
class A2AClientTimeoutError(A2AClientError):
    """Raised when a client request times out."""

    def __init__(self, message: str):
        """Build the exception text and keep ``message`` as an attribute.

        Args:
            message: Human-readable description of the timeout.
        """
        super().__init__(f'Timeout Error: {message}')
        self.message = message
49 |
50 |
class A2AClientInvalidArgsError(A2AClientError):
    """Raised when a client method receives invalid arguments."""

    def __init__(self, message: str):
        """Build the exception text and keep ``message`` as an attribute.

        Args:
            message: Human-readable description of the invalid arguments.
        """
        super().__init__(f'Invalid arguments error: {message}')
        self.message = message
62 |
63 |
class A2AClientInvalidStateError(A2AClientError):
    """Raised when the client is in a state that does not allow the operation."""

    def __init__(self, message: str):
        """Build the exception text and keep ``message`` as an attribute.

        Args:
            message: Human-readable description of the invalid state.
        """
        super().__init__(f'Invalid state error: {message}')
        self.message = message
75 |
76 |
class A2AClientJSONRPCError(A2AClientError):
    """Client exception for JSON-RPC errors returned by the server."""

    def __init__(self, error: JSONRPCErrorResponse):
        """Initializes the A2AClientJSONRPCError.

        Args:
            error: The JSON-RPC error response received from the server.
                Only its ``error`` member is kept, as ``self.error``, and it
                is also embedded in the exception message.
        """
        self.error = error.error
        super().__init__(f'JSON-RPC Error {error.error}')
88 |
--------------------------------------------------------------------------------
/src/a2a/client/helpers.py:
--------------------------------------------------------------------------------
1 | """Helper functions for the A2A client."""
2 |
3 | from uuid import uuid4
4 |
5 | from a2a.types import Message, Part, Role, TextPart
6 |
7 |
def create_text_message_object(
    role: Role = Role.user, content: str = ''
) -> Message:
    """Build a `Message` that holds exactly one `TextPart`.

    Args:
        role: Sender role (user or agent); defaults to ``Role.user``.
        content: Text placed in the single part; defaults to ''.

    Returns:
        A new `Message` whose ``message_id`` is a freshly generated UUID4
        string.
    """
    text_part = Part(TextPart(text=content))
    return Message(role=role, parts=[text_part], message_id=str(uuid4()))
23 |
--------------------------------------------------------------------------------
/src/a2a/client/legacy_grpc.py:
--------------------------------------------------------------------------------
1 | """Backwards compatibility layer for the legacy A2A gRPC client."""
2 |
3 | import warnings
4 |
5 | from typing import TYPE_CHECKING
6 |
7 | from a2a.client.transports.grpc import GrpcTransport
8 | from a2a.types import AgentCard
9 |
10 |
11 | if TYPE_CHECKING:
12 | from a2a.grpc.a2a_pb2_grpc import A2AServiceStub
13 |
14 |
class A2AGrpcClient(GrpcTransport):
    """[DEPRECATED] Backwards compatibility wrapper for the gRPC client.

    New code should use ``ClientFactory`` to create a client with a gRPC
    transport.
    """

    def __init__(  # pylint: disable=super-init-not-called
        self,
        grpc_stub: 'A2AServiceStub',
        agent_card: AgentCard,
    ):
        """Wrap an existing gRPC stub, emitting a DeprecationWarning.

        Args:
            grpc_stub: A ready-made ``A2AServiceStub`` used for all calls.
            agent_card: Card of the remote agent. If it is falsy, the
                extended card is treated as needed.
        """
        warnings.warn(
            'A2AGrpcClient is deprecated and will be removed in a future version. '
            'Use ClientFactory to create a client with a gRPC transport.',
            DeprecationWarning,
            stacklevel=2,
        )

        class _NopChannel:
            """Channel stand-in: only a stub was supplied, so close() does nothing."""

            async def close(self) -> None:
                pass

        # GrpcTransport.__init__ takes a channel and builds the stub from it.
        # This legacy API receives the stub directly, so the attributes are
        # populated by hand instead of calling super().__init__().
        self.stub = grpc_stub
        self.agent_card = agent_card
        if agent_card:
            needs_extended = agent_card.supports_authenticated_extended_card
        else:
            needs_extended = True
        self._needs_extended_card = needs_extended
        self.channel = _NopChannel()
45 |
--------------------------------------------------------------------------------
/src/a2a/client/middleware.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from abc import ABC, abstractmethod
4 | from collections.abc import MutableMapping # noqa: TC003
5 | from typing import TYPE_CHECKING, Any
6 |
7 | from pydantic import BaseModel, Field
8 |
9 |
10 | if TYPE_CHECKING:
11 | from a2a.types import AgentCard
12 |
13 |
class ClientCallContext(BaseModel):
    """A context passed with each client call.

    It allows call-specific configuration and data passing, such as
    authentication details or request deadlines.
    """

    # Free-form per-call data; each instance gets its own empty dict by
    # default (default_factory avoids a shared mutable default).
    state: MutableMapping[str, Any] = Field(default_factory=dict)
22 |
23 |
class ClientCallInterceptor(ABC):
    """Base class for hooks that run before a client request is sent.

    A subclass may inspect or rewrite the outgoing request, which makes this
    a natural place for authentication, logging, or tracing.
    """

    @abstractmethod
    async def intercept(
        self,
        method_name: str,
        request_payload: dict[str, Any],
        http_kwargs: dict[str, Any],
        agent_card: AgentCard | None,
        context: ClientCallContext | None,
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        """Intercept a client call before its request goes out.

        Args:
            method_name: Name of the RPC method (e.g., 'message/send').
            request_payload: The JSON RPC request payload dictionary.
            http_kwargs: Keyword arguments destined for the httpx request.
            agent_card: The AgentCard associated with the client, if known.
            context: The ClientCallContext of this particular call.

        Returns:
            The (possibly modified) ``(request_payload, http_kwargs)`` pair.
        """
54 |
--------------------------------------------------------------------------------
/src/a2a/client/optionals.py:
--------------------------------------------------------------------------------
1 | from typing import TYPE_CHECKING
2 |
3 |
# Attempt to import the optional gRPC module.
try:
    from grpc.aio import Channel  # pyright: ignore[reportAssignmentType]
except ImportError:
    # grpc.aio is not available (the optional gRPC dependencies are missing).
    if TYPE_CHECKING:
        # Dummy type seen only by type checkers, so annotations that mention
        # `Channel` still resolve.
        class Channel:  # type: ignore[no-redef]
            """Dummy class for type hinting when grpc.aio is not available."""

    else:
        # At runtime, `Channel` is None if the import failed.
        Channel = None
17 |
--------------------------------------------------------------------------------
/src/a2a/client/transports/__init__.py:
--------------------------------------------------------------------------------
1 | """A2A Client Transports."""
2 |
3 | from a2a.client.transports.base import ClientTransport
4 | from a2a.client.transports.jsonrpc import JsonRpcTransport
5 | from a2a.client.transports.rest import RestTransport
6 |
7 |
# The gRPC transport needs the optional gRPC dependencies
# (see the ``a2a-sdk[grpc]`` extra). If they cannot be imported, the name is
# still exported but bound to None, so the package itself keeps importing.
try:
    from a2a.client.transports.grpc import GrpcTransport
except ImportError:
    GrpcTransport = None  # type: ignore


# Transport classes re-exported by ``a2a.client.transports``.
__all__ = [
    'ClientTransport',
    'GrpcTransport',
    'JsonRpcTransport',
    'RestTransport',
]
20 |
--------------------------------------------------------------------------------
/src/a2a/client/transports/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from collections.abc import AsyncGenerator
3 |
4 | from a2a.client.middleware import ClientCallContext
5 | from a2a.types import (
6 | AgentCard,
7 | GetTaskPushNotificationConfigParams,
8 | Message,
9 | MessageSendParams,
10 | Task,
11 | TaskArtifactUpdateEvent,
12 | TaskIdParams,
13 | TaskPushNotificationConfig,
14 | TaskQueryParams,
15 | TaskStatusUpdateEvent,
16 | )
17 |
18 |
class ClientTransport(ABC):
    """Interface implemented by every A2A client transport.

    Each method maps to one protocol operation. Every operation also accepts
    an optional keyword-only ``context`` carrying per-call data.
    """

    @abstractmethod
    async def send_message(
        self,
        request: MessageSendParams,
        *,
        context: ClientCallContext | None = None,
    ) -> Task | Message:
        """Send a message and wait for a single (non-streamed) reply."""

    @abstractmethod
    async def send_message_streaming(
        self,
        request: MessageSendParams,
        *,
        context: ClientCallContext | None = None,
    ) -> AsyncGenerator[
        Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent
    ]:
        """Send a message and yield the agent's responses as they arrive."""
        # The unreachable ``yield`` makes this an async generator function,
        # matching the shape of concrete implementations.
        return
        yield

    @abstractmethod
    async def get_task(
        self,
        request: TaskQueryParams,
        *,
        context: ClientCallContext | None = None,
    ) -> Task:
        """Fetch the current state and history of one task."""

    @abstractmethod
    async def cancel_task(
        self,
        request: TaskIdParams,
        *,
        context: ClientCallContext | None = None,
    ) -> Task:
        """Ask the agent to cancel one task."""

    @abstractmethod
    async def set_task_callback(
        self,
        request: TaskPushNotificationConfig,
        *,
        context: ClientCallContext | None = None,
    ) -> TaskPushNotificationConfig:
        """Create or replace the push notification configuration of a task."""

    @abstractmethod
    async def get_task_callback(
        self,
        request: GetTaskPushNotificationConfigParams,
        *,
        context: ClientCallContext | None = None,
    ) -> TaskPushNotificationConfig:
        """Fetch the push notification configuration of a task."""

    @abstractmethod
    async def resubscribe(
        self,
        request: TaskIdParams,
        *,
        context: ClientCallContext | None = None,
    ) -> AsyncGenerator[
        Task | Message | TaskStatusUpdateEvent | TaskArtifactUpdateEvent
    ]:
        """Reconnect to an existing task and yield its further updates."""
        # Unreachable ``yield``: keeps this an async generator function.
        return
        yield

    @abstractmethod
    async def get_card(
        self,
        *,
        context: ClientCallContext | None = None,
    ) -> AgentCard:
        """Fetch the AgentCard."""

    @abstractmethod
    async def close(self) -> None:
        """Release the transport's resources."""
104 |
--------------------------------------------------------------------------------
/src/a2a/extensions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/a2aproject/a2a-python/dca66c3100a2b9701a1c8b65ad6853769eefd511/src/a2a/extensions/__init__.py
--------------------------------------------------------------------------------
/src/a2a/extensions/common.py:
--------------------------------------------------------------------------------
1 | from a2a.types import AgentCard, AgentExtension
2 |
3 |
HTTP_EXTENSION_HEADER = 'X-A2A-Extensions'


def get_requested_extensions(values: list[str]) -> set[str]:
    """Collect the set of extension identifiers named in ``values``.

    Each entry may itself contain several comma-separated identifiers, as
    happens when a list is carried in an HTTP header. Surrounding whitespace
    is trimmed and empty identifiers are dropped.
    """
    requested: set[str] = set()
    for value in values:
        for candidate in value.split(','):
            name = candidate.strip()
            if name:
                requested.add(name)
    return requested
19 |
20 |
def find_extension_by_uri(card: AgentCard, uri: str) -> AgentExtension | None:
    """Return the first extension of ``card`` whose ``uri`` equals ``uri``.

    Returns None when there is no match, including when
    ``card.capabilities.extensions`` is None or empty.
    """
    return next(
        (ext for ext in card.capabilities.extensions or [] if ext.uri == uri),
        None,
    )
28 |
--------------------------------------------------------------------------------
/src/a2a/grpc/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/a2aproject/a2a-python/dca66c3100a2b9701a1c8b65ad6853769eefd511/src/a2a/grpc/__init__.py
--------------------------------------------------------------------------------
/src/a2a/py.typed:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/a2aproject/a2a-python/dca66c3100a2b9701a1c8b65ad6853769eefd511/src/a2a/py.typed
--------------------------------------------------------------------------------
/src/a2a/server/__init__.py:
--------------------------------------------------------------------------------
1 | """Server-side components for implementing an A2A agent."""
2 |
--------------------------------------------------------------------------------
/src/a2a/server/agent_execution/__init__.py:
--------------------------------------------------------------------------------
1 | """Components for executing agent logic within the A2A server."""
2 |
3 | from a2a.server.agent_execution.agent_executor import AgentExecutor
4 | from a2a.server.agent_execution.context import RequestContext
5 | from a2a.server.agent_execution.request_context_builder import (
6 | RequestContextBuilder,
7 | )
8 | from a2a.server.agent_execution.simple_request_context_builder import (
9 | SimpleRequestContextBuilder,
10 | )
11 |
12 |
# Agent-execution building blocks exported by
# ``a2a.server.agent_execution``.
__all__ = [
    'AgentExecutor',
    'RequestContext',
    'RequestContextBuilder',
    'SimpleRequestContextBuilder',
]
19 |
--------------------------------------------------------------------------------
/src/a2a/server/agent_execution/agent_executor.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | from a2a.server.agent_execution.context import RequestContext
4 | from a2a.server.events.event_queue import EventQueue
5 |
6 |
class AgentExecutor(ABC):
    """Interface for the component that runs an agent's core logic.

    An implementation executes tasks for incoming requests and reports
    progress by publishing events to an event queue.
    """

    @abstractmethod
    async def execute(
        self, context: RequestContext, event_queue: EventQueue
    ) -> None:
        """Run the agent for the request described by ``context``.

        Implementations read what they need from ``context`` and publish
        `Task` or `Message` events, or `TaskStatusUpdateEvent` /
        `TaskArtifactUpdateEvent` events, to ``event_queue``. Return once
        execution for this request is complete or the agent yields control
        (for example, when it enters an input-required state).

        Args:
            context: Request context carrying the message, task ID, etc.
            event_queue: Queue that receives the published events.
        """

    @abstractmethod
    async def cancel(
        self, context: RequestContext, event_queue: EventQueue
    ) -> None:
        """Ask the agent to stop an ongoing task.

        Implementations should try to stop the task whose ID is in
        ``context`` and publish a `TaskStatusUpdateEvent` with state
        `TaskState.canceled` to ``event_queue``.

        Args:
            context: Request context carrying the ID of the task to cancel.
            event_queue: Queue that receives the cancellation status update.
        """
45 |
--------------------------------------------------------------------------------
/src/a2a/server/agent_execution/request_context_builder.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | from a2a.server.agent_execution import RequestContext
4 | from a2a.server.context import ServerCallContext
5 | from a2a.types import MessageSendParams, Task
6 |
7 |
class RequestContextBuilder(ABC):
    """Builds request context to be supplied to agent executor."""

    @abstractmethod
    async def build(
        self,
        params: MessageSendParams | None = None,
        task_id: str | None = None,
        context_id: str | None = None,
        task: Task | None = None,
        context: ServerCallContext | None = None,
    ) -> RequestContext:
        """Build the RequestContext handed to an agent executor.

        Args:
            params: Parameters of the incoming message send request, if any.
            task_id: ID of the task being executed, if known.
            context_id: ID of the current execution context, if known.
            task: The task associated with the request, if any.
            context: Server call context with metadata about the call.

        Returns:
            The assembled RequestContext.
        """
21 |
--------------------------------------------------------------------------------
/src/a2a/server/agent_execution/simple_request_context_builder.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 |
3 | from a2a.server.agent_execution import RequestContext, RequestContextBuilder
4 | from a2a.server.context import ServerCallContext
5 | from a2a.server.tasks import TaskStore
6 | from a2a.types import MessageSendParams, Task
7 |
8 |
class SimpleRequestContextBuilder(RequestContextBuilder):
    """Default request context builder that can attach referenced tasks."""

    def __init__(
        self,
        should_populate_referred_tasks: bool = False,
        task_store: TaskStore | None = None,
    ) -> None:
        """Configure the builder.

        Args:
            should_populate_referred_tasks: When True (and a ``task_store``
                is given), ``build`` looks up every ID in
                ``params.message.reference_task_ids`` and exposes the tasks
                it finds as ``related_tasks``. Defaults to False.
            task_store: Store used for those lookups. Without it, referenced
                tasks are never fetched, even when the flag is True.
        """
        self._should_populate_referred_tasks = should_populate_referred_tasks
        self._task_store = task_store

    async def build(
        self,
        params: MessageSendParams | None = None,
        task_id: str | None = None,
        context_id: str | None = None,
        task: Task | None = None,
        context: ServerCallContext | None = None,
    ) -> RequestContext:
        """Assemble the RequestContext for one agent execution.

        Referenced tasks are fetched concurrently (via ``asyncio.gather``)
        only when a task store is present, the builder was configured to
        populate them, and the request references at least one task ID.
        Lookups that return None are dropped.

        Args:
            params: Parameters of the incoming message send request.
            task_id: ID of the task being executed.
            context_id: ID of the current execution context.
            task: The primary task associated with the request.
            context: Server call context with metadata about the call.

        Returns:
            A RequestContext carrying the given values; ``related_tasks`` is
            None unless referenced tasks were looked up.
        """
        related_tasks: list[Task] | None = None

        if (
            self._task_store
            and self._should_populate_referred_tasks
            and params
            and params.message.reference_task_ids
        ):
            # ``ref_id`` avoids reading like the ``task_id`` parameter, which
            # is passed through unchanged below.
            pending_lookups = [
                self._task_store.get(ref_id)
                for ref_id in params.message.reference_task_ids
            ]
            fetched = await asyncio.gather(*pending_lookups)
            related_tasks = [item for item in fetched if item is not None]

        return RequestContext(
            request=params,
            task_id=task_id,
            context_id=context_id,
            task=task,
            related_tasks=related_tasks,
            call_context=context,
        )
78 |
--------------------------------------------------------------------------------
/src/a2a/server/apps/__init__.py:
--------------------------------------------------------------------------------
1 | """HTTP application components for the A2A server."""
2 |
3 | from a2a.server.apps.jsonrpc import (
4 | A2AFastAPIApplication,
5 | A2AStarletteApplication,
6 | CallContextBuilder,
7 | JSONRPCApplication,
8 | )
9 | from a2a.server.apps.rest import A2ARESTFastAPIApplication
10 |
11 |
# Re-exported by ``a2a.server.apps``: the JSON-RPC applications and call
# context builder, plus the REST FastAPI application.
__all__ = [
    'A2AFastAPIApplication',
    'A2ARESTFastAPIApplication',
    'A2AStarletteApplication',
    'CallContextBuilder',
    'JSONRPCApplication',
]
19 |
--------------------------------------------------------------------------------
/src/a2a/server/apps/jsonrpc/__init__.py:
--------------------------------------------------------------------------------
1 | """A2A JSON-RPC Applications."""
2 |
3 | from a2a.server.apps.jsonrpc.fastapi_app import A2AFastAPIApplication
4 | from a2a.server.apps.jsonrpc.jsonrpc_app import (
5 | CallContextBuilder,
6 | DefaultCallContextBuilder,
7 | JSONRPCApplication,
8 | StarletteUserProxy,
9 | )
10 | from a2a.server.apps.jsonrpc.starlette_app import A2AStarletteApplication
11 |
12 |
# JSON-RPC application classes and context-builder helpers exported by
# ``a2a.server.apps.jsonrpc``.
__all__ = [
    'A2AFastAPIApplication',
    'A2AStarletteApplication',
    'CallContextBuilder',
    'DefaultCallContextBuilder',
    'JSONRPCApplication',
    'StarletteUserProxy',
]
21 |
--------------------------------------------------------------------------------
/src/a2a/server/apps/rest/__init__.py:
--------------------------------------------------------------------------------
1 | """A2A REST Applications."""
2 |
3 | from a2a.server.apps.rest.fastapi_app import A2ARESTFastAPIApplication
4 |
5 |
# The only public name of ``a2a.server.apps.rest``.
__all__ = ['A2ARESTFastAPIApplication']
9 |
--------------------------------------------------------------------------------
/src/a2a/server/apps/rest/fastapi_app.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from collections.abc import Callable
4 | from typing import TYPE_CHECKING, Any
5 |
6 |
if TYPE_CHECKING:
    from fastapi import APIRouter, FastAPI, Request, Response
    from fastapi.responses import JSONResponse

    _package_fastapi_installed = True
else:
    try:
        from fastapi import APIRouter, FastAPI, Request, Response
        from fastapi.responses import JSONResponse

        _package_fastapi_installed = True
    except ImportError:
        # fastapi is optional: bind every imported name (including
        # JSONResponse, which was previously left unbound) to a placeholder
        # so the module still imports; the application's constructor checks
        # `_package_fastapi_installed` and raises a helpful ImportError.
        APIRouter = Any
        FastAPI = Any
        Request = Any
        Response = Any
        JSONResponse = Any

        _package_fastapi_installed = False
25 |
26 |
27 | from a2a.server.apps.jsonrpc.jsonrpc_app import CallContextBuilder
28 | from a2a.server.apps.rest.rest_adapter import RESTAdapter
29 | from a2a.server.context import ServerCallContext
30 | from a2a.server.request_handlers.request_handler import RequestHandler
31 | from a2a.types import AgentCard
32 | from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH
33 |
34 |
logger = logging.getLogger(__name__)


class A2ARESTFastAPIApplication:
    """A FastAPI application implementing the A2A protocol server REST endpoints.

    Handles incoming REST requests, routes them to the appropriate
    handler methods, and manages response generation including Server-Sent Events
    (SSE).
    """

    def __init__(  # noqa: PLR0913
        self,
        agent_card: AgentCard,
        http_handler: RequestHandler,
        extended_agent_card: AgentCard | None = None,
        context_builder: CallContextBuilder | None = None,
        card_modifier: Callable[[AgentCard], AgentCard] | None = None,
        extended_card_modifier: Callable[
            [AgentCard, ServerCallContext], AgentCard
        ]
        | None = None,
    ):
        """Initializes the A2ARESTFastAPIApplication.

        Args:
            agent_card: The AgentCard describing the agent's capabilities.
            http_handler: The handler instance responsible for processing A2A
                requests via http.
            extended_agent_card: An optional, distinct AgentCard to be served
                at the authenticated extended card endpoint.
            context_builder: The CallContextBuilder used to construct the
                ServerCallContext passed to the http_handler. If None, no
                ServerCallContext is passed.
            card_modifier: An optional callback to dynamically modify the public
                agent card before it is served.
            extended_card_modifier: An optional callback to dynamically modify
                the extended agent card before it is served. It receives the
                call context.

        Raises:
            ImportError: If the optional `fastapi` dependency is not installed.
        """
        if not _package_fastapi_installed:
            raise ImportError(
                'The `fastapi` package is required to use the'
                ' `A2ARESTFastAPIApplication`. It can be added as a part of'
                ' `a2a-sdk` optional dependencies, `a2a-sdk[http-server]`.'
            )
        # All request handling is delegated to the RESTAdapter; this class only
        # mounts the adapter's callbacks onto a FastAPI application.
        self._adapter = RESTAdapter(
            agent_card=agent_card,
            http_handler=http_handler,
            extended_agent_card=extended_agent_card,
            context_builder=context_builder,
            card_modifier=card_modifier,
            extended_card_modifier=extended_card_modifier,
        )

    def build(
        self,
        agent_card_url: str = AGENT_CARD_WELL_KNOWN_PATH,
        rpc_url: str = '',
        **kwargs: Any,
    ) -> FastAPI:
        """Builds and returns the FastAPI application instance.

        Args:
            agent_card_url: The path at which the public agent card is served.
            rpc_url: A prefix prepended to every route path registered here,
                including `agent_card_url`. Defaults to '' (no prefix).
            **kwargs: Additional keyword arguments to pass to the FastAPI constructor.

        Returns:
            A configured FastAPI application instance.
        """
        app = FastAPI(**kwargs)
        router = APIRouter()
        # Each key returned by `routes()` is indexed as (path, HTTP method);
        # the value is the endpoint callback to register for it.
        for route, callback in self._adapter.routes().items():
            router.add_api_route(
                f'{rpc_url}{route[0]}', callback, methods=[route[1]]
            )

        @router.get(f'{rpc_url}{agent_card_url}')
        async def get_agent_card(request: Request) -> Response:
            card = await self._adapter.handle_get_agent_card(request)
            return JSONResponse(card)

        app.include_router(router)
        return app
121 |
--------------------------------------------------------------------------------
/src/a2a/server/context.py:
--------------------------------------------------------------------------------
1 | """Defines the ServerCallContext class."""
2 |
3 | import collections.abc
4 | import typing
5 |
6 | from pydantic import BaseModel, ConfigDict, Field
7 |
8 | from a2a.auth.user import UnauthenticatedUser, User
9 |
10 |
State = collections.abc.MutableMapping[str, typing.Any]


class ServerCallContext(BaseModel):
    """A context passed when calling a server method.

    This class allows storing arbitrary user data in the state attribute.

    Attributes:
        state: Arbitrary per-call data. Each instance starts with its own
            empty dict.
        user: The user associated with the call. Defaults to an
            `UnauthenticatedUser`.
        requested_extensions: A set of strings. Each instance starts with its
            own empty set.
        activated_extensions: A set of strings. Each instance starts with its
            own empty set.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    # default_factory builds a fresh dict per instance, matching the extension
    # set fields below, rather than declaring a shared mutable `{}` default
    # and relying on Pydantic to copy it.
    state: State = Field(default_factory=dict)
    user: User = Field(default=UnauthenticatedUser())
    requested_extensions: set[str] = Field(default_factory=set)
    activated_extensions: set[str] = Field(default_factory=set)
26 |
--------------------------------------------------------------------------------
/src/a2a/server/events/__init__.py:
--------------------------------------------------------------------------------
1 | """Event handling components for the A2A server."""
2 |
3 | from a2a.server.events.event_consumer import EventConsumer
4 | from a2a.server.events.event_queue import Event, EventQueue
5 | from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager
6 | from a2a.server.events.queue_manager import (
7 | NoTaskQueue,
8 | QueueManager,
9 | TaskQueueExists,
10 | )
11 |
12 |
13 | __all__ = [
14 | 'Event',
15 | 'EventConsumer',
16 | 'EventQueue',
17 | 'InMemoryQueueManager',
18 | 'NoTaskQueue',
19 | 'QueueManager',
20 | 'TaskQueueExists',
21 | ]
22 |
--------------------------------------------------------------------------------
/src/a2a/server/events/event_consumer.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import logging
3 | import sys
4 |
5 | from collections.abc import AsyncGenerator
6 |
7 | from pydantic import ValidationError
8 |
9 | from a2a.server.events.event_queue import Event, EventQueue
10 | from a2a.types import (
11 | InternalError,
12 | Message,
13 | Task,
14 | TaskState,
15 | TaskStatusUpdateEvent,
16 | )
17 | from a2a.utils.errors import ServerError
18 | from a2a.utils.telemetry import SpanKind, trace_class
19 |
20 |
# Exception type that signals a closed event queue.
# From Python 3.13, asyncio queues support shutdown and raise
# asyncio.QueueShutDown once shut down. Earlier versions have no shutdown
# concept, so asyncio.QueueEmpty serves as the closed-queue signal instead.
QueueClosed: type[Exception]
if sys.version_info >= (3, 13):
    QueueClosed = asyncio.QueueShutDown
else:
    QueueClosed = asyncio.QueueEmpty

logger = logging.getLogger(__name__)
29 |
30 |
@trace_class(kind=SpanKind.SERVER)
class EventConsumer:
    """Consumer to read events from the agent event queue.

    Wraps a single `EventQueue`. `consume_one` performs one non-blocking read.
    `consume_all` streams events until a final event arrives or the queue is
    closed. Errors from the agent's execution task are delivered through
    `agent_task_callback` and re-raised by `consume_all`.
    """

    def __init__(self, queue: EventQueue):
        """Initializes the EventConsumer.

        Args:
            queue: The `EventQueue` instance to consume events from.
        """
        self.queue = queue
        # Per-dequeue timeout (seconds) used by `consume_all`, so the loop
        # wakes up regularly to re-check `self._exception`.
        self._timeout = 0.5
        # Exception to be re-raised from `consume_all`. Written by
        # `agent_task_callback` and by `consume_all`'s catch-all handler.
        self._exception: BaseException | None = None
        logger.debug('EventConsumer initialized')

    async def consume_one(self) -> Event:
        """Consume one event from the agent event queue non-blocking.

        Returns:
            The next event from the queue.

        Raises:
            ServerError: If the queue is empty when attempting to dequeue
                immediately.
        """
        logger.debug('Attempting to consume one event.')
        try:
            event = await self.queue.dequeue_event(no_wait=True)
        except asyncio.QueueEmpty as e:
            logger.warning('Event queue was empty in consume_one.')
            raise ServerError(
                InternalError(message='Agent did not return any response')
            ) from e

        logger.debug('Dequeued event of type: %s in consume_one.', type(event))

        # Acknowledge the dequeued event on the queue (presumably mirrors
        # asyncio.Queue.task_done semantics -- confirm in EventQueue).
        self.queue.task_done()

        return event

    async def consume_all(self) -> AsyncGenerator[Event]:
        """Consume all the generated streaming events from the agent.

        This method yields events as they become available from the queue
        until a final event is received or the queue is closed. It also
        monitors for exceptions set by the `agent_task_callback`.

        An event is final when it is a `Message`, a `TaskStatusUpdateEvent`
        with `final` set, or a `Task` whose state is completed, canceled,
        failed, rejected, unknown or input_required. On a final event, the
        queue is closed, the event is yielded and consumption stops.

        Yields:
            Events dequeued from the queue.

        Raises:
            BaseException: If an exception was set by the `agent_task_callback`,
                or if an unexpected error occurred inside the consumption loop.
        """
        logger.debug('Starting to consume all events from the queue.')
        while True:
            # Surface any error recorded by the agent task callback or by the
            # catch-all handler below.
            if self._exception:
                raise self._exception
            try:
                # We use a timeout when waiting for an event from the queue.
                # This is required because it allows the loop to check if
                # `self._exception` has been set by the `agent_task_callback`.
                # Without the timeout, loop might hang indefinitely if no events are
                # enqueued by the agent and the agent simply threw an exception
                event = await asyncio.wait_for(
                    self.queue.dequeue_event(), timeout=self._timeout
                )
                logger.debug(
                    'Dequeued event of type: %s in consume_all.', type(event)
                )
                self.queue.task_done()
                logger.debug(
                    'Marked task as done in event queue in consume_all'
                )

                # input_required and unknown end this stream as well as the
                # terminal states (completed/canceled/failed/rejected).
                is_final_event = (
                    (isinstance(event, TaskStatusUpdateEvent) and event.final)
                    or isinstance(event, Message)
                    or (
                        isinstance(event, Task)
                        and event.status.state
                        in (
                            TaskState.completed,
                            TaskState.canceled,
                            TaskState.failed,
                            TaskState.rejected,
                            TaskState.unknown,
                            TaskState.input_required,
                        )
                    )
                )

                # Make sure the yield is after the close events, otherwise
                # the caller may end up in a blocked state where this
                # generator isn't called again to close things out and the
                # other part is waiting for an event or a closed queue.
                if is_final_event:
                    logger.debug('Stopping event consumption in consume_all.')
                    await self.queue.close(True)
                    yield event
                    break
                yield event
            except TimeoutError:
                # continue polling until there is a final event
                continue
            except asyncio.TimeoutError:  # pyright: ignore [reportUnusedExcept]
                # asyncio.TimeoutError became an alias of the built-in
                # TimeoutError in Python 3.11; this branch matters only on
                # older versions.
                continue
            except (QueueClosed, asyncio.QueueEmpty):
                # Confirm that the queue is closed, e.g. we aren't on
                # python 3.12 and get a queue empty error on an open queue
                if self.queue.is_closed():
                    break
            except ValidationError:
                # Skip malformed events and keep consuming.
                logger.exception('Invalid event format received')
                continue
            except Exception as e:
                # Record the error and loop; the check at the top of the next
                # iteration re-raises it to the caller.
                logger.exception('Stopping event consumption due to exception')
                self._exception = e
                continue

    def agent_task_callback(self, agent_task: asyncio.Task[None]) -> None:
        """Callback to handle exceptions from the agent's execution task.

        If the agent's asyncio task raises an exception, this callback is
        invoked, and the exception is stored to be re-raised by the consumer loop.

        Args:
            agent_task: The asyncio.Task that completed.
        """
        logger.debug('Agent task callback triggered.')
        # Cancelled tasks are skipped, because exception() would raise
        # CancelledError for them. For a task that finished normally,
        # exception() returns None, so None is stored in that case.
        if not agent_task.cancelled() and agent_task.done():
            self._exception = agent_task.exception()
163 |
--------------------------------------------------------------------------------
/src/a2a/server/events/in_memory_queue_manager.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 |
3 | from a2a.server.events.event_queue import EventQueue
4 | from a2a.server.events.queue_manager import (
5 | NoTaskQueue,
6 | QueueManager,
7 | TaskQueueExists,
8 | )
9 | from a2a.utils.telemetry import SpanKind, trace_class
10 |
11 |
@trace_class(kind=SpanKind.SERVER)
class InMemoryQueueManager(QueueManager):
    """QueueManager that keeps every task's EventQueue in a process-local dict.

    Queues live only in this instance's memory, so all incoming interactions
    for a given task ID must reach the same process. That suits
    single-instance deployments. Scalable deployments need a distributed
    QueueManager instead.
    """

    def __init__(self) -> None:
        """Initializes the InMemoryQueueManager with no registered queues."""
        self._task_queue: dict[str, EventQueue] = {}
        # Serializes every read and write of `_task_queue`.
        self._lock = asyncio.Lock()

    async def add(self, task_id: str, queue: EventQueue) -> None:
        """Register `queue` as the event queue for `task_id`.

        Raises:
            TaskQueueExists: If `task_id` already has a registered queue.
        """
        async with self._lock:
            already_registered = task_id in self._task_queue
            if already_registered:
                raise TaskQueueExists
            self._task_queue[task_id] = queue

    async def get(self, task_id: str) -> EventQueue | None:
        """Look up the event queue registered for `task_id`.

        Returns:
            The registered `EventQueue`, or `None` when `task_id` is unknown.
        """
        async with self._lock:
            return self._task_queue.get(task_id)

    async def tap(self, task_id: str) -> EventQueue | None:
        """Create a child queue of the queue registered for `task_id`.

        Returns:
            The new child `EventQueue`, or `None` when `task_id` is unknown.
        """
        async with self._lock:
            try:
                parent = self._task_queue[task_id]
            except KeyError:
                return None
            return parent.tap()

    async def close(self, task_id: str) -> None:
        """Unregister the queue for `task_id` and close it.

        Raises:
            NoTaskQueue: If `task_id` has no registered queue.
        """
        async with self._lock:
            if task_id not in self._task_queue:
                raise NoTaskQueue
            removed = self._task_queue.pop(task_id)
            await removed.close()

    async def create_or_tap(self, task_id: str) -> EventQueue:
        """Return a tap of the existing queue for `task_id`, or register a new one.

        Returns:
            A child of the existing queue, or the freshly created queue when
            none was registered.
        """
        async with self._lock:
            if task_id in self._task_queue:
                return self._task_queue[task_id].tap()
            fresh = EventQueue()
            self._task_queue[task_id] = fresh
            return fresh
86 |
--------------------------------------------------------------------------------
/src/a2a/server/events/queue_manager.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | from a2a.server.events.event_queue import EventQueue
4 |
5 |
class QueueManager(ABC):
    """Abstract owner of the per-task event queue lifecycle.

    Implementations decide where queues are kept (for example, in process
    memory) and how child queues are produced for additional listeners.
    """

    @abstractmethod
    async def add(self, task_id: str, queue: EventQueue) -> None:
        """Register `queue` as the event queue associated with `task_id`."""

    @abstractmethod
    async def get(self, task_id: str) -> EventQueue | None:
        """Return the event queue for `task_id`, or None if there is none."""

    @abstractmethod
    async def tap(self, task_id: str) -> EventQueue | None:
        """Return a child (tapped) queue for an existing `task_id`, or None."""

    @abstractmethod
    async def close(self, task_id: str) -> None:
        """Close the event queue for `task_id` and stop tracking it."""

    @abstractmethod
    async def create_or_tap(self, task_id: str) -> EventQueue:
        """Return a new queue for `task_id`, or a tap of the existing one."""
28 |
29 |
class TaskQueueExists(Exception):  # noqa: N818
    """Raised when adding a queue for a task ID that already has one."""


class NoTaskQueue(Exception):  # noqa: N818
    """Raised when a queue is accessed or closed for a task ID that has none."""
36 |
--------------------------------------------------------------------------------
/src/a2a/server/request_handlers/__init__.py:
--------------------------------------------------------------------------------
1 | """Request handler components for the A2A server."""
2 |
3 | import logging
4 |
5 | from a2a.server.request_handlers.default_request_handler import (
6 | DefaultRequestHandler,
7 | )
8 | from a2a.server.request_handlers.jsonrpc_handler import JSONRPCHandler
9 | from a2a.server.request_handlers.request_handler import RequestHandler
10 | from a2a.server.request_handlers.response_helpers import (
11 | build_error_response,
12 | prepare_response_object,
13 | )
14 | from a2a.server.request_handlers.rest_handler import RESTHandler
15 |
16 |
logger = logging.getLogger(__name__)

try:
    from a2a.server.request_handlers.grpc_handler import (
        GrpcHandler,  # type: ignore
    )
except ImportError as exc:
    # Keep a module-level reference: the name bound by `except ... as` is
    # deleted when the handler exits, but the placeholder needs the error
    # later for exception chaining.
    _original_error = exc
    logger.debug(
        'GrpcHandler not loaded. This is expected if gRPC dependencies are not installed. Error: %s',
        exc,
    )

    class GrpcHandler:  # type: ignore
        """Stand-in for GrpcHandler used when the gRPC extras are missing.

        Instantiating it raises ImportError, chained from the original
        import failure.
        """

        def __init__(self, *_args, **_kwargs):
            raise ImportError(
                'To use GrpcHandler, its dependencies must be installed. '
                'You can install them with \'pip install "a2a-sdk[grpc]"\''
            ) from _original_error
38 |
39 |
40 | __all__ = [
41 | 'DefaultRequestHandler',
42 | 'GrpcHandler',
43 | 'JSONRPCHandler',
44 | 'RESTHandler',
45 | 'RequestHandler',
46 | 'build_error_response',
47 | 'prepare_response_object',
48 | ]
49 |
--------------------------------------------------------------------------------
/src/a2a/server/request_handlers/response_helpers.py:
--------------------------------------------------------------------------------
1 | """Helper functions for building A2A JSON-RPC responses."""
2 |
3 | # response types
4 | from typing import TypeVar
5 |
6 | from a2a.types import (
7 | A2AError,
8 | CancelTaskResponse,
9 | CancelTaskSuccessResponse,
10 | DeleteTaskPushNotificationConfigResponse,
11 | DeleteTaskPushNotificationConfigSuccessResponse,
12 | GetTaskPushNotificationConfigResponse,
13 | GetTaskPushNotificationConfigSuccessResponse,
14 | GetTaskResponse,
15 | GetTaskSuccessResponse,
16 | InvalidAgentResponseError,
17 | JSONRPCError,
18 | JSONRPCErrorResponse,
19 | ListTaskPushNotificationConfigResponse,
20 | ListTaskPushNotificationConfigSuccessResponse,
21 | Message,
22 | SendMessageResponse,
23 | SendMessageSuccessResponse,
24 | SendStreamingMessageResponse,
25 | SendStreamingMessageSuccessResponse,
26 | SetTaskPushNotificationConfigResponse,
27 | SetTaskPushNotificationConfigSuccessResponse,
28 | Task,
29 | TaskArtifactUpdateEvent,
30 | TaskPushNotificationConfig,
31 | TaskStatusUpdateEvent,
32 | )
33 |
34 |
# Constrained to the RootModel response wrappers: `build_error_response` and
# `prepare_response_object` return an instance of whichever wrapper class the
# caller passes in.
RT = TypeVar(
    'RT',
    GetTaskResponse,
    CancelTaskResponse,
    SendMessageResponse,
    SetTaskPushNotificationConfigResponse,
    GetTaskPushNotificationConfigResponse,
    SendStreamingMessageResponse,
    ListTaskPushNotificationConfigResponse,
    DeleteTaskPushNotificationConfigResponse,
)
"""Type variable for RootModel response types."""

# success types
# Constrained to the success payload models accepted as the
# `success_payload_type` argument of `prepare_response_object`.
SPT = TypeVar(
    'SPT',
    GetTaskSuccessResponse,
    CancelTaskSuccessResponse,
    SendMessageSuccessResponse,
    SetTaskPushNotificationConfigSuccessResponse,
    GetTaskPushNotificationConfigSuccessResponse,
    SendStreamingMessageSuccessResponse,
    ListTaskPushNotificationConfigSuccessResponse,
    DeleteTaskPushNotificationConfigSuccessResponse,
)
"""Type variable for SuccessResponse types."""

# result types
# Every value a handler may pass to `prepare_response_object`: results
# (tasks, messages, update events, push-notification configs) and errors.
EventTypes = (
    Task
    | Message
    | TaskArtifactUpdateEvent
    | TaskStatusUpdateEvent
    | TaskPushNotificationConfig
    | A2AError
    | JSONRPCError
    | list[TaskPushNotificationConfig]
)
"""Type alias for possible event types produced by handlers."""
74 |
75 |
def build_error_response(
    request_id: str | int | None,
    error: A2AError | JSONRPCError,
    response_wrapper_type: type[RT],
) -> RT:
    """Wrap an error in a JSONRPCErrorResponse of the requested response type.

    Args:
        request_id: The ID of the request that caused the error.
        error: The A2AError or JSONRPCError object.
        response_wrapper_type: The Pydantic RootModel type that wraps the
            response for the specific RPC method (e.g., `SendMessageResponse`).

    Returns:
        An instance of `response_wrapper_type` holding a JSON-RPC error
        response for `request_id`.
    """
    # A2AError wraps the concrete error model in `.root`; unwrap it so the
    # response carries that model directly.
    payload = error.root if isinstance(error, A2AError) else error
    error_response = JSONRPCErrorResponse(id=request_id, error=payload)
    return response_wrapper_type(error_response)
99 |
100 |
def prepare_response_object(
    request_id: str | int | None,
    response: EventTypes,
    success_response_types: tuple[type, ...],
    success_payload_type: type[SPT],
    response_type: type[RT],
) -> RT:
    """Wrap a handler's result in the JSON-RPC response type for its RPC method.

    A result whose type is in `success_response_types` becomes the `result`
    of a `success_payload_type` payload. An `A2AError` or `JSONRPCError`
    becomes an error response. Any other value is reported as an
    `InvalidAgentResponseError`.

    Args:
        request_id: The ID of the request.
        response: The object received from the request handler.
        success_response_types: A tuple of expected Pydantic model types for
            a successful result.
        success_payload_type: The Pydantic model type for the success payload
            (e.g., `SendMessageSuccessResponse`).
        response_type: The Pydantic RootModel type that wraps the final
            response (e.g., `SendMessageResponse`).

    Returns:
        A Pydantic model representing the final JSON-RPC response (success or error).
    """
    if isinstance(response, success_response_types):
        return response_type(
            root=success_payload_type(id=request_id, result=response)  # type:ignore
        )

    if not isinstance(response, A2AError | JSONRPCError):
        # Neither an accepted result type nor an error: the agent produced a
        # value this RPC method cannot return.
        invalid_type_error = A2AError(
            root=InvalidAgentResponseError(
                message='Agent returned invalid type response for this method'
            )
        )
        return build_error_response(
            request_id, invalid_type_error, response_type
        )

    return build_error_response(request_id, response, response_type)
143 |
--------------------------------------------------------------------------------
/src/a2a/server/tasks/__init__.py:
--------------------------------------------------------------------------------
1 | """Components for managing tasks within the A2A server."""
2 |
3 | import logging
4 |
5 | from a2a.server.tasks.base_push_notification_sender import (
6 | BasePushNotificationSender,
7 | )
8 | from a2a.server.tasks.inmemory_push_notification_config_store import (
9 | InMemoryPushNotificationConfigStore,
10 | )
11 | from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore
12 | from a2a.server.tasks.push_notification_config_store import (
13 | PushNotificationConfigStore,
14 | )
15 | from a2a.server.tasks.push_notification_sender import PushNotificationSender
16 | from a2a.server.tasks.result_aggregator import ResultAggregator
17 | from a2a.server.tasks.task_manager import TaskManager
18 | from a2a.server.tasks.task_store import TaskStore
19 | from a2a.server.tasks.task_updater import TaskUpdater
20 |
21 |
logger = logging.getLogger(__name__)

try:
    from a2a.server.tasks.database_task_store import (
        DatabaseTaskStore,  # type: ignore
    )
except ImportError as e:
    # Each placeholder keeps its own saved error. The placeholders read this
    # name lazily (when instantiated), so a single shared module-level name
    # would be overwritten by the next failed import, and the first
    # placeholder would chain from the wrong exception.
    _database_task_store_import_error = e
    # If the database task store is not available, we can still use in-memory stores.
    logger.debug(
        'DatabaseTaskStore not loaded. This is expected if database dependencies are not installed. Error: %s',
        e,
    )

    class DatabaseTaskStore:  # type: ignore
        """Placeholder for DatabaseTaskStore when dependencies are not installed."""

        def __init__(self, *args, **kwargs):
            raise ImportError(
                'To use DatabaseTaskStore, its dependencies must be installed. '
                'You can install them with \'pip install "a2a-sdk[sql]"\''
            ) from _database_task_store_import_error


try:
    from a2a.server.tasks.database_push_notification_config_store import (
        DatabasePushNotificationConfigStore,  # type: ignore
    )
except ImportError as e:
    _database_push_config_store_import_error = e
    # If the database push notification config store is not available, we can still use in-memory stores.
    logger.debug(
        'DatabasePushNotificationConfigStore not loaded. This is expected if database dependencies are not installed. Error: %s',
        e,
    )

    class DatabasePushNotificationConfigStore:  # type: ignore
        """Placeholder for DatabasePushNotificationConfigStore when dependencies are not installed."""

        def __init__(self, *args, **kwargs):
            raise ImportError(
                'To use DatabasePushNotificationConfigStore, its dependencies must be installed. '
                'You can install them with \'pip install "a2a-sdk[sql]"\''
            ) from _database_push_config_store_import_error
66 |
67 |
68 | __all__ = [
69 | 'BasePushNotificationSender',
70 | 'DatabasePushNotificationConfigStore',
71 | 'DatabaseTaskStore',
72 | 'InMemoryPushNotificationConfigStore',
73 | 'InMemoryTaskStore',
74 | 'PushNotificationConfigStore',
75 | 'PushNotificationSender',
76 | 'ResultAggregator',
77 | 'TaskManager',
78 | 'TaskStore',
79 | 'TaskUpdater',
80 | ]
81 |
--------------------------------------------------------------------------------
/src/a2a/server/tasks/base_push_notification_sender.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import logging
3 |
4 | import httpx
5 |
6 | from a2a.server.tasks.push_notification_config_store import (
7 | PushNotificationConfigStore,
8 | )
9 | from a2a.server.tasks.push_notification_sender import PushNotificationSender
10 | from a2a.types import PushNotificationConfig, Task
11 |
12 |
logger = logging.getLogger(__name__)


class BasePushNotificationSender(PushNotificationSender):
    """Default PushNotificationSender that POSTs the serialized task over HTTP.

    For every push notification config stored for a task, the task is sent as
    JSON to the config's URL. Failures are logged, never raised.
    """

    def __init__(
        self,
        httpx_client: httpx.AsyncClient,
        config_store: PushNotificationConfigStore,
    ) -> None:
        """Initializes the BasePushNotificationSender.

        Args:
            httpx_client: An async HTTP client instance to send notifications.
            config_store: A PushNotificationConfigStore instance to retrieve
                configurations.
        """
        self._client = httpx_client
        self._config_store = config_store

    async def send_notification(self, task: Task) -> None:
        """Send `task` to every push notification config registered for it.

        Notifications are dispatched concurrently. If any of them fails, a
        single warning is logged for the task.
        """
        configs = await self._config_store.get_info(task.id)
        if not configs:
            return

        outcomes = await asyncio.gather(
            *(self._dispatch_notification(task, config) for config in configs)
        )
        if all(outcomes):
            return
        logger.warning(
            'Some push notifications failed to send for task_id=%s', task.id
        )

    async def _dispatch_notification(
        self, task: Task, push_info: PushNotificationConfig
    ) -> bool:
        """POST the serialized task to one config's URL.

        Returns:
            True if the request succeeded with a non-error status, else False.
            Any exception is logged rather than propagated.
        """
        target_url = push_info.url
        try:
            # The token, when configured, travels in a dedicated header.
            request_headers = (
                {'X-A2A-Notification-Token': push_info.token}
                if push_info.token
                else None
            )
            reply = await self._client.post(
                target_url,
                json=task.model_dump(mode='json', exclude_none=True),
                headers=request_headers,
            )
            reply.raise_for_status()
            logger.info(
                'Push-notification sent for task_id=%s to URL: %s',
                task.id,
                target_url,
            )
        except Exception:
            logger.exception(
                'Error sending push-notification for task_id=%s to URL: %s.',
                task.id,
                target_url,
            )
            return False
        return True
75 |
--------------------------------------------------------------------------------
/src/a2a/server/tasks/database_task_store.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 |
4 | try:
5 | from sqlalchemy import Table, delete, select
6 | from sqlalchemy.ext.asyncio import (
7 | AsyncEngine,
8 | AsyncSession,
9 | async_sessionmaker,
10 | )
11 | from sqlalchemy.orm import class_mapper
12 | except ImportError as e:
13 | raise ImportError(
14 | 'DatabaseTaskStore requires SQLAlchemy and a database driver. '
15 | 'Install with one of: '
16 | "'pip install a2a-sdk[postgresql]', "
17 | "'pip install a2a-sdk[mysql]', "
18 | "'pip install a2a-sdk[sqlite]', "
19 | "or 'pip install a2a-sdk[sql]'"
20 | ) from e
21 |
22 | from a2a.server.context import ServerCallContext
23 | from a2a.server.models import Base, TaskModel, create_task_model
24 | from a2a.server.tasks.task_store import TaskStore
25 | from a2a.types import Task # Task is the Pydantic model
26 |
27 |
logger = logging.getLogger(__name__)


class DatabaseTaskStore(TaskStore):
    """TaskStore that persists tasks through SQLAlchemy.

    Each Pydantic `Task` is mapped onto a row of `task_model` and mapped back
    when read. `task_model` is the shared `TaskModel`, or a model produced by
    `create_task_model` for a custom table name. The table is created lazily
    on first use unless `create_table` is False.
    """

    engine: AsyncEngine
    async_session_maker: async_sessionmaker[AsyncSession]
    create_table: bool
    _initialized: bool
    task_model: type[TaskModel]

    def __init__(
        self,
        engine: AsyncEngine,
        create_table: bool = True,
        table_name: str = 'tasks',
    ) -> None:
        """Set up the store around an existing async engine.

        No database I/O happens here. Schema setup runs in `initialize`,
        which every store operation triggers on first use.

        Args:
            engine: An existing SQLAlchemy AsyncEngine to be used by the store.
            create_table: When True, `initialize` runs `create_all` for the
                model's table.
            table_name: Name of the database table. Defaults to 'tasks'.
        """
        logger.debug(
            'Initializing DatabaseTaskStore with existing engine, table: %s',
            table_name,
        )
        self.engine = engine
        # expire_on_commit=False keeps loaded ORM attributes readable after
        # the transaction commits.
        self.async_session_maker = async_sessionmaker(
            engine, expire_on_commit=False
        )
        self.create_table = create_table
        self._initialized = False
        # The default table name reuses the module's TaskModel. Any other
        # name gets a model class built by create_task_model for that table.
        if table_name == 'tasks':
            self.task_model = TaskModel
        else:
            self.task_model = create_task_model(table_name)

    async def initialize(self) -> None:
        """Create the task table if configured to; later calls do nothing."""
        if self._initialized:
            return

        logger.debug('Initializing database schema...')
        if self.create_table:
            async with self.engine.begin() as conn:
                # Restrict create_all to the tables mapped by this store's
                # model instead of everything registered on Base.metadata.
                mapped_tables = class_mapper(self.task_model).tables
                own_tables = [
                    table for table in mapped_tables if isinstance(table, Table)
                ]
                await conn.run_sync(
                    Base.metadata.create_all, tables=own_tables
                )
        self._initialized = True
        logger.debug('Database schema initialized.')

    async def _ensure_initialized(self) -> None:
        """Run `initialize` unless it has already completed."""
        if self._initialized:
            return
        await self.initialize()

    def _to_orm(self, task: Task) -> TaskModel:
        """Build a `task_model` row from a Pydantic Task."""
        # The ORM column is `task_metadata`; the Pydantic field is `metadata`.
        row_values = {
            'id': task.id,
            'context_id': task.context_id,
            'kind': task.kind,
            'status': task.status,
            'artifacts': task.artifacts,
            'history': task.history,
            'task_metadata': task.metadata,
        }
        return self.task_model(**row_values)

    def _from_orm(self, task_model: TaskModel) -> Task:
        """Convert a `task_model` row back into a validated Pydantic Task."""
        # model_validate parses nested JSON-backed columns (status, artifacts,
        # history) into their Pydantic models.
        return Task.model_validate(
            {
                'id': task_model.id,
                'context_id': task_model.context_id,
                'kind': task_model.kind,
                'status': task_model.status,
                'artifacts': task_model.artifacts,
                'history': task_model.history,
                'metadata': task_model.task_metadata,
            }
        )

    async def save(
        self, task: Task, context: ServerCallContext | None = None
    ) -> None:
        """Insert the task, or update the row that has the same primary key."""
        await self._ensure_initialized()
        orm_task = self._to_orm(task)
        async with self.async_session_maker.begin() as session:
            await session.merge(orm_task)
            logger.debug('Task %s saved/updated successfully.', task.id)

    async def get(
        self, task_id: str, context: ServerCallContext | None = None
    ) -> Task | None:
        """Fetch a task by ID, returning None when no row matches."""
        await self._ensure_initialized()
        async with self.async_session_maker() as session:
            query = select(self.task_model).where(self.task_model.id == task_id)
            row = (await session.execute(query)).scalar_one_or_none()
            if not row:
                logger.debug('Task %s not found in store.', task_id)
                return None
            task = self._from_orm(row)
            logger.debug('Task %s retrieved successfully.', task_id)
            return task

    async def delete(
        self, task_id: str, context: ServerCallContext | None = None
    ) -> None:
        """Delete a task by ID; deleting a missing task only logs a warning."""
        await self._ensure_initialized()

        async with self.async_session_maker.begin() as session:
            query = delete(self.task_model).where(self.task_model.id == task_id)
            outcome = await session.execute(query)
            # session.begin() commits automatically when this block exits cleanly.

            if outcome.rowcount <= 0:
                logger.warning(
                    'Attempted to delete nonexistent task with id: %s', task_id
                )
            else:
                logger.info('Task %s deleted successfully.', task_id)
167 |
--------------------------------------------------------------------------------
/src/a2a/server/tasks/inmemory_push_notification_config_store.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import logging
3 |
4 | from a2a.server.tasks.push_notification_config_store import (
5 | PushNotificationConfigStore,
6 | )
7 | from a2a.types import PushNotificationConfig
8 |
9 |
logger = logging.getLogger(__name__)


class InMemoryPushNotificationConfigStore(PushNotificationConfigStore):
    """In-memory implementation of the PushNotificationConfigStore interface.

    Configurations are kept in a dictionary that maps a task ID to a list of
    ``PushNotificationConfig`` objects, at most one per config ``id``. All
    access goes through ``self.lock``. Data is lost when the process stops.
    """

    def __init__(self) -> None:
        """Initializes the InMemoryPushNotificationConfigStore."""
        self.lock = asyncio.Lock()
        self._push_notification_infos: dict[
            str, list[PushNotificationConfig]
        ] = {}

    async def set_info(
        self, task_id: str, notification_config: PushNotificationConfig
    ) -> None:
        """Sets or updates a push notification configuration for a task.

        If ``notification_config.id`` is ``None`` it is set, on the caller's
        object, to ``task_id``. Any stored configuration of this task with the
        same id is replaced. The new configuration (the caller's object
        itself, not a copy) is appended last.

        Args:
            task_id: The task the configuration belongs to.
            notification_config: The configuration to store.
        """
        async with self.lock:
            if notification_config.id is None:
                notification_config.id = task_id

            configs = self._push_notification_infos.setdefault(task_id, [])
            # Drop the previous entry with this id so ids stay unique per task.
            # Slice-assign to keep the same list object in the dictionary.
            configs[:] = [c for c in configs if c.id != notification_config.id]
            configs.append(notification_config)

    async def get_info(self, task_id: str) -> list[PushNotificationConfig]:
        """Retrieves the push notification configurations for a task.

        Args:
            task_id: The task whose configurations are requested.

        Returns:
            A new list with the stored configurations (empty if none). A copy
            is returned so callers cannot mutate the store's internal list
            without holding the lock.
        """
        async with self.lock:
            return list(self._push_notification_infos.get(task_id, []))

    async def delete_info(
        self, task_id: str, config_id: str | None = None
    ) -> None:
        """Deletes one push notification configuration of a task.

        Args:
            task_id: The task whose configuration should be removed.
            config_id: Id of the configuration to remove; defaults to
                ``task_id``, mirroring the default id used by ``set_info``.
        """
        async with self.lock:
            if config_id is None:
                config_id = task_id

            configs = self._push_notification_infos.get(task_id)
            if configs is None:
                return

            for index, config in enumerate(configs):
                if config.id == config_id:
                    del configs[index]
                    break

            # Do not keep empty entries for tasks without configurations.
            if not configs:
                del self._push_notification_infos[task_id]
69 |
--------------------------------------------------------------------------------
/src/a2a/server/tasks/inmemory_task_store.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import logging
3 |
4 | from a2a.server.context import ServerCallContext
5 | from a2a.server.tasks.task_store import TaskStore
6 | from a2a.types import Task
7 |
8 |
logger = logging.getLogger(__name__)


class InMemoryTaskStore(TaskStore):
    """TaskStore backed by a plain in-process dictionary.

    Tasks are keyed by ``task.id``. The stored objects are the very instances
    passed to ``save`` (no copy is made). Every operation holds ``self.lock``.
    Nothing is persisted: all tasks are lost when the server process stops.
    """

    def __init__(self) -> None:
        """Create an empty store together with its asyncio lock."""
        logger.debug('Initializing InMemoryTaskStore')
        self.tasks: dict[str, Task] = {}
        self.lock = asyncio.Lock()

    async def save(
        self, task: Task, context: ServerCallContext | None = None
    ) -> None:
        """Insert ``task`` or overwrite the entry that has the same ID.

        ``context`` is accepted for interface compatibility and ignored.
        """
        async with self.lock:
            self.tasks[task.id] = task
            logger.debug('Task %s saved successfully.', task.id)

    async def get(
        self, task_id: str, context: ServerCallContext | None = None
    ) -> Task | None:
        """Return the stored task for ``task_id``, or ``None`` if unknown."""
        async with self.lock:
            logger.debug('Attempting to get task with id: %s', task_id)
            found = self.tasks.get(task_id)
            if not found:
                logger.debug('Task %s not found in store.', task_id)
            else:
                logger.debug('Task %s retrieved successfully.', task_id)
            return found

    async def delete(
        self, task_id: str, context: ServerCallContext | None = None
    ) -> None:
        """Remove the task for ``task_id``; a missing task only logs a warning."""
        async with self.lock:
            logger.debug('Attempting to delete task with id: %s', task_id)
            try:
                del self.tasks[task_id]
            except KeyError:
                logger.warning(
                    'Attempted to delete nonexistent task with id: %s', task_id
                )
            else:
                logger.debug('Task %s deleted successfully.', task_id)
59 |
--------------------------------------------------------------------------------
/src/a2a/server/tasks/push_notification_config_store.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | from a2a.types import PushNotificationConfig
4 |
5 |
class PushNotificationConfigStore(ABC):
    """Abstract storage for the push notification configurations of tasks.

    A task may have several configurations; concrete stores decide how they
    are kept and how entries are identified.
    """

    @abstractmethod
    async def set_info(
        self, task_id: str, notification_config: PushNotificationConfig
    ) -> None:
        """Set or update a push notification configuration of a task.

        Args:
            task_id: The task the configuration belongs to.
            notification_config: The configuration to store.
        """

    @abstractmethod
    async def get_info(self, task_id: str) -> list[PushNotificationConfig]:
        """Return every push notification configuration stored for a task.

        Args:
            task_id: The task whose configurations are requested.
        """

    @abstractmethod
    async def delete_info(
        self, task_id: str, config_id: str | None = None
    ) -> None:
        """Delete a push notification configuration of a task.

        Args:
            task_id: The task whose configuration should be removed.
            config_id: Which configuration to remove. How ``None`` is
                interpreted is left to the implementation.
        """
24 |
--------------------------------------------------------------------------------
/src/a2a/server/tasks/push_notification_sender.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | from a2a.types import Task
4 |
5 |
class PushNotificationSender(ABC):
    """Abstract component that delivers push notifications about tasks."""

    @abstractmethod
    async def send_notification(self, task: Task) -> None:
        """Send a push notification carrying the latest state of ``task``.

        Args:
            task: The task whose current state should be delivered.
        """
12 |
--------------------------------------------------------------------------------
/src/a2a/server/tasks/task_store.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | from a2a.server.context import ServerCallContext
4 | from a2a.types import Task
5 |
6 |
class TaskStore(ABC):
    """Abstract persistence layer for agent ``Task`` objects.

    Concrete stores implement saving, loading and deleting tasks by ID.
    """

    @abstractmethod
    async def save(
        self, task: Task, context: ServerCallContext | None = None
    ) -> None:
        """Persist ``task``, replacing any stored task with the same ID.

        Args:
            task: The task to store.
            context: Optional server call context; implementations may
                ignore it.
        """

    @abstractmethod
    async def get(
        self, task_id: str, context: ServerCallContext | None = None
    ) -> Task | None:
        """Look up a task by ID.

        Args:
            task_id: ID of the task to fetch.
            context: Optional server call context; implementations may
                ignore it.

        Returns:
            The stored task, or ``None`` if there is none.
        """

    @abstractmethod
    async def delete(
        self, task_id: str, context: ServerCallContext | None = None
    ) -> None:
        """Remove the task with the given ID from the store.

        Args:
            task_id: ID of the task to remove.
            context: Optional server call context; implementations may
                ignore it.
        """
30 |
--------------------------------------------------------------------------------
/src/a2a/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """Utility functions for the A2A Python SDK."""
2 |
3 | from a2a.utils.artifact import (
4 | new_artifact,
5 | new_data_artifact,
6 | new_text_artifact,
7 | )
8 | from a2a.utils.constants import (
9 | AGENT_CARD_WELL_KNOWN_PATH,
10 | DEFAULT_RPC_URL,
11 | EXTENDED_AGENT_CARD_PATH,
12 | PREV_AGENT_CARD_WELL_KNOWN_PATH,
13 | )
14 | from a2a.utils.helpers import (
15 | append_artifact_to_task,
16 | are_modalities_compatible,
17 | build_text_artifact,
18 | create_task_obj,
19 | )
20 | from a2a.utils.message import (
21 | get_data_parts,
22 | get_file_parts,
23 | get_message_text,
24 | get_text_parts,
25 | new_agent_parts_message,
26 | new_agent_text_message,
27 | )
28 | from a2a.utils.task import (
29 | completed_task,
30 | new_task,
31 | )
32 |
33 |
# Public API of ``a2a.utils``: the helpers and constants re-exported from the
# submodules imported above. Entries are kept in sorted order (uppercase
# constants first), so add new names in their sorted position.
__all__ = [
    'AGENT_CARD_WELL_KNOWN_PATH',
    'DEFAULT_RPC_URL',
    'EXTENDED_AGENT_CARD_PATH',
    'PREV_AGENT_CARD_WELL_KNOWN_PATH',
    'append_artifact_to_task',
    'are_modalities_compatible',
    'build_text_artifact',
    'completed_task',
    'create_task_obj',
    'get_data_parts',
    'get_file_parts',
    'get_message_text',
    'get_text_parts',
    'new_agent_parts_message',
    'new_agent_text_message',
    'new_artifact',
    'new_data_artifact',
    'new_task',
    'new_text_artifact',
]
55 |
--------------------------------------------------------------------------------
/src/a2a/utils/artifact.py:
--------------------------------------------------------------------------------
1 | """Utility functions for creating A2A Artifact objects."""
2 |
3 | import uuid
4 |
5 | from typing import Any
6 |
7 | from a2a.types import Artifact, DataPart, Part, TextPart
8 |
9 |
def new_artifact(
    parts: list[Part], name: str, description: str = ''
) -> Artifact:
    """Build an ``Artifact`` from parts, assigning a fresh UUID4 ``artifact_id``.

    Args:
        parts: The `Part` objects that make up the artifact's content.
        name: Human-readable name of the artifact.
        description: Optional description; defaults to the empty string.

    Returns:
        The newly constructed `Artifact`.
    """
    generated_id = str(uuid.uuid4())
    return Artifact(
        name=name,
        description=description,
        parts=parts,
        artifact_id=generated_id,
    )
29 |
30 |
def new_text_artifact(
    name: str,
    text: str,
    description: str = '',
) -> Artifact:
    """Build an ``Artifact`` whose only content is one `TextPart`.

    Args:
        name: Human-readable name of the artifact.
        text: The text carried by the single part.
        description: Optional description; defaults to the empty string.

    Returns:
        A new `Artifact` with a generated artifact_id.
    """
    text_part = Part(root=TextPart(text=text))
    return new_artifact(parts=[text_part], name=name, description=description)
51 |
52 |
def new_data_artifact(
    name: str,
    data: dict[str, Any],
    description: str = '',
) -> Artifact:
    """Build an ``Artifact`` whose only content is one `DataPart`.

    Args:
        name: Human-readable name of the artifact.
        data: Structured payload carried by the single part.
        description: Optional description; defaults to the empty string.

    Returns:
        A new `Artifact` with a generated artifact_id.
    """
    data_part = Part(root=DataPart(data=data))
    return new_artifact(parts=[data_part], name=name, description=description)
73 |
--------------------------------------------------------------------------------
/src/a2a/utils/constants.py:
--------------------------------------------------------------------------------
"""Constants for well-known URIs used throughout the A2A Python SDK."""

# Well-known URL path of the agent card.
AGENT_CARD_WELL_KNOWN_PATH = '/.well-known/agent-card.json'
# Earlier agent card path; presumably kept for backward compatibility with
# clients that still request it — confirm before removing.
PREV_AGENT_CARD_WELL_KNOWN_PATH = '/.well-known/agent.json'
# URL path of the authenticated extended agent card.
EXTENDED_AGENT_CARD_PATH = '/agent/authenticatedExtendedCard'
# Default URL path for RPC requests (the server root).
DEFAULT_RPC_URL = '/'
7 |
--------------------------------------------------------------------------------
/src/a2a/utils/error_handlers.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import logging
3 |
4 | from collections.abc import Awaitable, Callable, Coroutine
5 | from typing import TYPE_CHECKING, Any
6 |
7 |
8 | if TYPE_CHECKING:
9 | from starlette.responses import JSONResponse, Response
10 | else:
11 | try:
12 | from starlette.responses import JSONResponse, Response
13 | except ImportError:
14 | JSONResponse = Any
15 | Response = Any
16 |
17 |
18 | from a2a._base import A2ABaseModel
19 | from a2a.types import (
20 | AuthenticatedExtendedCardNotConfiguredError,
21 | ContentTypeNotSupportedError,
22 | InternalError,
23 | InvalidAgentResponseError,
24 | InvalidParamsError,
25 | InvalidRequestError,
26 | JSONParseError,
27 | MethodNotFoundError,
28 | PushNotificationNotSupportedError,
29 | TaskNotCancelableError,
30 | TaskNotFoundError,
31 | UnsupportedOperationError,
32 | )
33 | from a2a.utils.errors import ServerError
34 |
35 |
logger = logging.getLogger(__name__)

# HTTP status code used by ``rest_error_handler`` for each A2A/JSON-RPC error
# model. Lookups use the exact type (``type(error)``); error types not listed
# here fall back to 500.
A2AErrorToHttpStatus: dict[type[A2ABaseModel], int] = {
    JSONParseError: 400,
    InvalidRequestError: 400,
    MethodNotFoundError: 404,
    InvalidParamsError: 422,
    InternalError: 500,
    TaskNotFoundError: 404,
    TaskNotCancelableError: 409,  # 409 Conflict
    PushNotificationNotSupportedError: 501,  # 501 Not Implemented
    UnsupportedOperationError: 501,  # 501 Not Implemented
    ContentTypeNotSupportedError: 415,  # 415 Unsupported Media Type
    InvalidAgentResponseError: 502,  # 502 Bad Gateway
    AuthenticatedExtendedCardNotConfiguredError: 404,
}
52 |
53 |
def rest_error_handler(
    func: Callable[..., Awaitable[Response]],
) -> Callable[..., Awaitable[Response]]:
    """Wrap a REST handler so failures are turned into JSON error responses.

    A ``ServerError`` is logged (ERROR for ``InternalError``, WARNING
    otherwise). It is answered with ``{'message': ...}`` and the status code
    from ``A2AErrorToHttpStatus``, or 500 when the error type is not mapped.
    Any other exception is logged with its traceback and answered with a
    generic 500.
    """

    @functools.wraps(func)
    async def wrapper(*args: Any, **kwargs: Any) -> Response:
        try:
            return await func(*args, **kwargs)
        except ServerError as exc:
            err = exc.error
            if not err:
                err = InternalError(
                    message='Internal error due to unknown reason'
                )
            status = A2AErrorToHttpStatus.get(type(err), 500)

            level = logging.WARNING
            if isinstance(err, InternalError):
                level = logging.ERROR
            data_suffix = (', Data=' + str(err.data)) if err.data else ''
            logger.log(
                level,
                "Request error: Code=%s, Message='%s'%s",
                err.code,
                err.message,
                data_suffix,
            )
            return JSONResponse(
                content={'message': err.message}, status_code=status
            )
        except Exception:
            logger.exception('Unknown error occurred')
            return JSONResponse(
                content={'message': 'unknown exception'}, status_code=500
            )

    return wrapper
91 |
92 |
def rest_stream_error_handler(
    func: Callable[..., Coroutine[Any, Any, Any]],
) -> Callable[..., Coroutine[Any, Any, Any]]:
    """Decorator that logs errors of a streaming method, then re-raises them.

    Once a stream has started, a ``JSONResponse`` can no longer be returned.
    The exception is therefore logged here and propagated unchanged for the
    server framework to handle:

    * ``ServerError`` is logged at ERROR for ``InternalError`` and at
      WARNING otherwise, using the same format as ``rest_error_handler``.
    * Any other exception is logged with its traceback.
    """

    @functools.wraps(func)
    async def wrapper(*args: Any, **kwargs: Any) -> Any:
        try:
            return await func(*args, **kwargs)
        except ServerError as e:
            error = e.error or InternalError(
                message='Internal error due to unknown reason'
            )

            log_level = (
                logging.ERROR
                if isinstance(error, InternalError)
                else logging.WARNING
            )
            logger.log(
                log_level,
                "Request error: Code=%s, Message='%s'%s",
                error.code,
                error.message,
                ', Data=' + str(error.data) if error.data else '',
            )
            # The stream has already started, so we cannot return a
            # JSONResponse. Re-raise (bare ``raise`` keeps the original
            # traceback) and let the server framework manage the error.
            raise
        except Exception:
            # Same situation as above: log the unexpected failure with its
            # traceback, then let the server framework manage it.
            logger.exception('Unknown error occurred')
            raise

    return wrapper
130 |
--------------------------------------------------------------------------------
/src/a2a/utils/errors.py:
--------------------------------------------------------------------------------
1 | """Custom exceptions for A2A server-side errors."""
2 |
3 | from a2a.types import (
4 | AuthenticatedExtendedCardNotConfiguredError,
5 | ContentTypeNotSupportedError,
6 | InternalError,
7 | InvalidAgentResponseError,
8 | InvalidParamsError,
9 | InvalidRequestError,
10 | JSONParseError,
11 | JSONRPCError,
12 | MethodNotFoundError,
13 | PushNotificationNotSupportedError,
14 | TaskNotCancelableError,
15 | TaskNotFoundError,
16 | UnsupportedOperationError,
17 | )
18 |
19 |
class A2AServerError(Exception):
    """Root of the exception hierarchy for A2A server-side failures."""


class MethodNotImplementedError(A2AServerError):
    """Raised when a server handler does not implement the requested method."""

    def __init__(
        self, message: str = 'This method is not implemented by the server'
    ):
        """Store ``message`` and build the exception text from it.

        Args:
            message: Human-readable explanation; also kept on ``self.message``.
        """
        self.message = message
        full_text = f'Not Implemented operation Error: {message}'
        super().__init__(full_text)
37 |
38 |
class ServerError(Exception):
    """Exception that carries an A2A or JSON-RPC error model out of server code.

    Request handlers and other server components raise it to signal a
    specific error that should be rendered as a JSON-RPC error response.
    """

    def __init__(
        self,
        error: (
            JSONRPCError
            | JSONParseError
            | InvalidRequestError
            | MethodNotFoundError
            | InvalidParamsError
            | InternalError
            | TaskNotFoundError
            | TaskNotCancelableError
            | PushNotificationNotSupportedError
            | UnsupportedOperationError
            | ContentTypeNotSupportedError
            | InvalidAgentResponseError
            | AuthenticatedExtendedCardNotConfiguredError
            | None
        ),
    ):
        """Keep a reference to the wrapped error model.

        Args:
            error: The A2A or JSON-RPC error model instance, or ``None``.
        """
        self.error = error

    def __str__(self) -> str:
        """Return the wrapped error's message.

        Falls back to the error's class name when it has no message, and to
        ``'None'`` when no error is wrapped.
        """
        wrapped = self.error
        if wrapped is None:
            return 'None'
        text = wrapped.message
        return type(wrapped).__name__ if text is None else text

    def __repr__(self) -> str:
        """Return ``ServerError(<repr of the wrapped error>)`` for debugging."""
        return f'{type(self).__name__}({self.error!r})'
83 |
--------------------------------------------------------------------------------
/src/a2a/utils/message.py:
--------------------------------------------------------------------------------
1 | """Utility functions for creating and handling A2A Message objects."""
2 |
3 | import uuid
4 |
5 | from typing import Any
6 |
7 | from a2a.types import (
8 | DataPart,
9 | FilePart,
10 | FileWithBytes,
11 | FileWithUri,
12 | Message,
13 | Part,
14 | Role,
15 | TextPart,
16 | )
17 |
18 |
def new_agent_text_message(
    text: str,
    context_id: str | None = None,
    task_id: str | None = None,
) -> Message:
    """Build an agent-authored message holding a single `TextPart`.

    Args:
        text: The text of the message.
        context_id: Optional context ID to attach to the message.
        task_id: Optional task ID to attach to the message.

    Returns:
        A `Message` with role `Role.agent` and a new UUID4 `message_id`.
    """
    only_part = Part(root=TextPart(text=text))
    return Message(
        message_id=str(uuid.uuid4()),
        role=Role.agent,
        context_id=context_id,
        task_id=task_id,
        parts=[only_part],
    )
41 |
42 |
def new_agent_parts_message(
    parts: list[Part],
    context_id: str | None = None,
    task_id: str | None = None,
) -> Message:
    """Build an agent-authored message from an existing list of parts.

    Args:
        parts: The `Part` objects forming the content; used as given.
        context_id: Optional context ID to attach to the message.
        task_id: Optional task ID to attach to the message.

    Returns:
        A `Message` with role `Role.agent` and a new UUID4 `message_id`.
    """
    message_id = str(uuid.uuid4())
    return Message(
        message_id=message_id,
        role=Role.agent,
        context_id=context_id,
        task_id=task_id,
        parts=parts,
    )
65 |
66 |
def get_text_parts(parts: list[Part]) -> list[str]:
    """Collect the text of every `TextPart` in ``parts``, preserving order.

    Args:
        parts: The `Part` objects to scan.

    Returns:
        The text strings of the `TextPart` entries; other part kinds are skipped.
    """
    texts: list[str] = []
    for part in parts:
        inner = part.root
        if isinstance(inner, TextPart):
            texts.append(inner.text)
    return texts
77 |
78 |
def get_data_parts(parts: list[Part]) -> list[dict[str, Any]]:
    """Collect the payload dictionaries of every `DataPart` in ``parts``.

    Args:
        parts: The `Part` objects to scan.

    Returns:
        The ``data`` dictionaries of the `DataPart` entries, in order.
    """
    roots = (part.root for part in parts)
    return [root.data for root in roots if isinstance(root, DataPart)]
89 |
90 |
def get_file_parts(parts: list[Part]) -> list[FileWithBytes | FileWithUri]:
    """Collect the file payloads of every `FilePart` in ``parts``.

    Args:
        parts: The `Part` objects to scan.

    Returns:
        The ``file`` objects (`FileWithBytes` or `FileWithUri`) of the
        `FilePart` entries, in order.
    """
    roots = (part.root for part in parts)
    return [root.file for root in roots if isinstance(root, FilePart)]
101 |
102 |
def get_message_text(message: Message, delimiter: str = '\n') -> str:
    """Concatenate the text of all `TextPart`s of ``message``.

    Args:
        message: The `Message` whose parts are read.
        delimiter: Separator placed between consecutive text parts.

    Returns:
        The joined text, or ``''`` when the message has no text parts.
    """
    text_chunks = get_text_parts(message.parts)
    return delimiter.join(text_chunks)
114 |
--------------------------------------------------------------------------------
/src/a2a/utils/task.py:
--------------------------------------------------------------------------------
1 | """Utility functions for creating A2A Task objects."""
2 |
3 | import uuid
4 |
5 | from a2a.types import Artifact, Message, Task, TaskState, TaskStatus, TextPart
6 |
7 |
def new_task(request: Message) -> Task:
    """Creates a new Task object from an initial user message.

    The task ID and context ID are taken from the message when present.
    Otherwise a fresh UUID4 is generated for each. Generated IDs are not
    written back to ``request``. The message itself is stored unchanged as
    the only history entry.

    Args:
        request: The initial `Message` object from the user.

    Returns:
        A new `Task` in the 'submitted' state whose history is ``[request]``.

    Raises:
        TypeError: If the message role is missing (falsy).
        ValueError: If the message has no parts, or if any `TextPart` has
            empty text. Other part kinds are not checked.
    """
    if not request.role:
        raise TypeError('Message role cannot be None')
    if not request.parts:
        raise ValueError('Message parts cannot be empty')
    for part in request.parts:
        if isinstance(part.root, TextPart) and not part.root.text:
            raise ValueError('TextPart content cannot be empty')

    # NOTE(review): when IDs are generated here, the message in ``history``
    # still has task_id/context_id unset — confirm this is intended.
    return Task(
        status=TaskStatus(state=TaskState.submitted),
        id=request.task_id or str(uuid.uuid4()),
        context_id=request.context_id or str(uuid.uuid4()),
        history=[request],
    )
37 |
38 |
def completed_task(
    task_id: str,
    context_id: str,
    artifacts: list[Artifact],
    history: list[Message] | None = None,
) -> Task:
    """Build a `Task` whose status is already 'completed'.

    Handy for producing the final task representation once the agent has
    finished and generated its artifacts.

    Args:
        task_id: The ID of the task.
        context_id: The context ID of the task.
        artifacts: The `Artifact` objects produced by the task; must be
            non-empty.
        history: Optional message history; an empty list is used when omitted.

    Returns:
        A `Task` with status 'completed' and the given artifacts and history.

    Raises:
        ValueError: If ``artifacts`` is empty or contains a non-`Artifact`.
    """
    artifacts_ok = bool(artifacts) and all(
        isinstance(item, Artifact) for item in artifacts
    )
    if not artifacts_ok:
        raise ValueError(
            'artifacts must be a non-empty list of Artifact objects'
        )

    return Task(
        id=task_id,
        context_id=context_id,
        status=TaskStatus(state=TaskState.completed),
        artifacts=artifacts,
        history=[] if history is None else history,
    )
73 |
--------------------------------------------------------------------------------
/tests/README.md:
--------------------------------------------------------------------------------
1 | ## Running the tests
2 |
3 | 1. Run the tests
4 | ```bash
5 | uv run pytest -v -s client/test_client.py
6 | ```
7 |
8 | If the tests fail unexpectedly, try cleaning up the cached state:
9 |
10 | 1. `uv clean`
11 | 2. `rm -fR .pytest_cache .venv __pycache__`
12 |
--------------------------------------------------------------------------------
/tests/auth/test_user.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | from a2a.auth.user import UnauthenticatedUser
4 |
5 |
class TestUnauthenticatedUser(unittest.TestCase):
    """Behaviour of the anonymous ``UnauthenticatedUser``."""

    def setUp(self):
        self.user = UnauthenticatedUser()

    def test_is_authenticated_returns_false(self):
        """An anonymous user never reports itself as authenticated."""
        self.assertFalse(self.user.is_authenticated)

    def test_user_name_returns_empty_string(self):
        """An anonymous user has an empty user name."""
        self.assertEqual(self.user.user_name, '')
15 |
# Allow running this file directly (``python test_user.py``) besides pytest.
if __name__ == '__main__':
    unittest.main()
18 |
--------------------------------------------------------------------------------
/tests/client/test_base_client.py:
--------------------------------------------------------------------------------
1 | from unittest.mock import AsyncMock, MagicMock
2 |
3 | import pytest
4 |
5 | from a2a.client.base_client import BaseClient
6 | from a2a.client.client import ClientConfig
7 | from a2a.client.transports.base import ClientTransport
8 | from a2a.types import (
9 | AgentCapabilities,
10 | AgentCard,
11 | Message,
12 | Part,
13 | Role,
14 | Task,
15 | TaskState,
16 | TaskStatus,
17 | TextPart,
18 | )
19 |
20 |
@pytest.fixture
def mock_transport():
    """Async mock restricted to the ``ClientTransport`` interface."""
    transport = AsyncMock(spec=ClientTransport)
    return transport
24 |
25 |
@pytest.fixture
def sample_agent_card():
    """Agent card that advertises streaming support and plain-text I/O."""
    text_modes = ['text/plain']
    return AgentCard(
        name='Test Agent',
        description='An agent for testing',
        url='http://test.com',
        version='1.0',
        capabilities=AgentCapabilities(streaming=True),
        default_input_modes=list(text_modes),
        default_output_modes=list(text_modes),
        skills=[],
    )
38 |
39 |
@pytest.fixture
def sample_message():
    """User message with a single 'Hello' text part."""
    greeting = Part(root=TextPart(text='Hello'))
    return Message(message_id='msg-1', role=Role.user, parts=[greeting])
47 |
48 |
@pytest.fixture
def base_client(sample_agent_card, mock_transport):
    """BaseClient with streaming enabled, no consumers and no middleware."""
    return BaseClient(
        card=sample_agent_card,
        config=ClientConfig(streaming=True),
        transport=mock_transport,
        consumers=[],
        middleware=[],
    )
59 |
60 |
@pytest.mark.asyncio
async def test_send_message_streaming(
    base_client: BaseClient, mock_transport: MagicMock, sample_message: Message
):
    """With streaming enabled, events come from the streaming transport call."""
    completed = Task(
        id='task-123',
        context_id='ctx-456',
        status=TaskStatus(state=TaskState.completed),
    )

    async def fake_stream(*_args, **_kwargs):
        yield completed

    mock_transport.send_message_streaming.return_value = fake_stream()

    received = []
    async for event in base_client.send_message(sample_message):
        received.append(event)

    mock_transport.send_message_streaming.assert_called_once()
    mock_transport.send_message.assert_not_called()
    assert len(received) == 1
    first_task = received[0][0]
    assert first_task.id == 'task-123'
80 |
81 |
@pytest.mark.asyncio
async def test_send_message_non_streaming(
    base_client: BaseClient, mock_transport: MagicMock, sample_message: Message
):
    """Turning streaming off in the client config uses the unary transport call."""
    base_client._config.streaming = False
    reply = Task(
        id='task-456',
        context_id='ctx-789',
        status=TaskStatus(state=TaskState.completed),
    )
    mock_transport.send_message.return_value = reply

    received = []
    async for event in base_client.send_message(sample_message):
        received.append(event)

    mock_transport.send_message.assert_called_once()
    mock_transport.send_message_streaming.assert_not_called()
    assert len(received) == 1
    assert received[0][0].id == 'task-456'
99 |
100 |
@pytest.mark.asyncio
async def test_send_message_non_streaming_agent_capability_false(
    base_client: BaseClient, mock_transport: MagicMock, sample_message: Message
):
    """An agent card without streaming support forces the unary transport call."""
    base_client._card.capabilities.streaming = False
    reply = Task(
        id='task-789',
        context_id='ctx-101',
        status=TaskStatus(state=TaskState.completed),
    )
    mock_transport.send_message.return_value = reply

    received = []
    async for event in base_client.send_message(sample_message):
        received.append(event)

    mock_transport.send_message.assert_called_once()
    mock_transport.send_message_streaming.assert_not_called()
    assert len(received) == 1
    assert received[0][0].id == 'task-789'
118 |
--------------------------------------------------------------------------------
/tests/client/test_client_factory.py:
--------------------------------------------------------------------------------
1 | """Tests for the ClientFactory."""
2 |
3 | import httpx
4 | import pytest
5 |
6 | from a2a.client import ClientConfig, ClientFactory
7 | from a2a.client.transports import JsonRpcTransport, RestTransport
8 | from a2a.types import (
9 | AgentCapabilities,
10 | AgentCard,
11 | AgentInterface,
12 | TransportProtocol,
13 | )
14 |
15 |
@pytest.fixture
def base_agent_card() -> AgentCard:
    """Agent card whose preferred transport is JSON-RPC at the primary URL."""
    card = AgentCard(
        name='Test Agent',
        description='An agent for testing.',
        url='http://primary-url.com',
        version='1.0.0',
        preferred_transport=TransportProtocol.jsonrpc,
        capabilities=AgentCapabilities(),
        skills=[],
        default_input_modes=[],
        default_output_modes=[],
    )
    return card
30 |
31 |
def test_client_factory_selects_preferred_transport(base_agent_card: AgentCard):
    """Without ``use_client_preference``, the card's preferred transport wins."""
    supported = [TransportProtocol.jsonrpc, TransportProtocol.http_json]
    factory = ClientFactory(
        ClientConfig(
            httpx_client=httpx.AsyncClient(),
            supported_transports=supported,
        )
    )

    transport = factory.create(base_agent_card)._transport

    assert isinstance(transport, JsonRpcTransport)
    assert transport.url == 'http://primary-url.com'
46 |
47 |
def test_client_factory_selects_secondary_transport_url(
    base_agent_card: AgentCard,
):
    """A client-preferred secondary transport uses its interface's own URL."""
    rest_interface = AgentInterface(
        transport=TransportProtocol.http_json,
        url='http://secondary-url.com',
    )
    base_agent_card.additional_interfaces = [rest_interface]

    # The client lists REST first and asks for its own preference to apply.
    config = ClientConfig(
        httpx_client=httpx.AsyncClient(),
        supported_transports=[
            TransportProtocol.http_json,
            TransportProtocol.jsonrpc,
        ],
        use_client_preference=True,
    )
    transport = ClientFactory(config).create(base_agent_card)._transport

    assert isinstance(transport, RestTransport)
    assert transport.url == 'http://secondary-url.com'
72 |
73 |
def test_client_factory_server_preference(base_agent_card: AgentCard):
    """When both sides support both transports, the server's choice is used."""
    base_agent_card.preferred_transport = TransportProtocol.http_json
    jsonrpc_interface = AgentInterface(
        transport=TransportProtocol.jsonrpc, url='http://secondary-url.com'
    )
    base_agent_card.additional_interfaces = [jsonrpc_interface]

    config = ClientConfig(
        httpx_client=httpx.AsyncClient(),
        supported_transports=[
            TransportProtocol.jsonrpc,
            TransportProtocol.http_json,
        ],
    )
    transport = ClientFactory(config).create(base_agent_card)._transport

    assert isinstance(transport, RestTransport)
    assert transport.url == 'http://primary-url.com'
95 |
96 |
def test_client_factory_no_compatible_transport(base_agent_card: AgentCard):
    """A gRPC-only client cannot talk to the JSON-RPC-only card."""
    factory = ClientFactory(
        ClientConfig(
            httpx_client=httpx.AsyncClient(),
            supported_transports=[TransportProtocol.grpc],
        )
    )

    with pytest.raises(ValueError, match='no compatible transports found'):
        factory.create(base_agent_card)
106 |
--------------------------------------------------------------------------------
/tests/client/test_client_task_manager.py:
--------------------------------------------------------------------------------
1 | from unittest.mock import AsyncMock, Mock, patch
2 |
3 | import pytest
4 |
5 | from a2a.client.client_task_manager import ClientTaskManager
6 | from a2a.client.errors import (
7 | A2AClientInvalidArgsError,
8 | A2AClientInvalidStateError,
9 | )
10 | from a2a.types import (
11 | Artifact,
12 | Message,
13 | Part,
14 | Role,
15 | Task,
16 | TaskArtifactUpdateEvent,
17 | TaskState,
18 | TaskStatus,
19 | TaskStatusUpdateEvent,
20 | TextPart,
21 | )
22 |
23 |
@pytest.fixture
def task_manager():
    """Provide a new, empty ClientTaskManager for each test."""
    manager = ClientTaskManager()
    return manager
27 |
28 |
@pytest.fixture
def sample_task():
    """Provide a Task in the ``working`` state with no history or artifacts."""
    working_status = TaskStatus(state=TaskState.working)
    return Task(
        id='task123',
        context_id='context456',
        status=working_status,
        history=[],
        artifacts=[],
    )
38 |
39 |
@pytest.fixture
def sample_message():
    """Provide a single-part user text message saying 'Hello'."""
    greeting = Part(root=TextPart(text='Hello'))
    return Message(message_id='msg1', role=Role.user, parts=[greeting])
47 |
48 |
def test_get_task_no_task_id_returns_none(task_manager: ClientTaskManager):
    """A fresh manager has no current task, so ``get_task()`` returns None."""
    current_task = task_manager.get_task()
    assert current_task is None
51 |
52 |
def test_get_task_or_raise_no_task_raises_error(
    task_manager: ClientTaskManager,
):
    """``get_task_or_raise()`` reports an invalid state when no task is tracked."""
    expected_error = pytest.raises(
        A2AClientInvalidStateError, match='no current Task'
    )
    with expected_error:
        task_manager.get_task_or_raise()
58 |
59 |
@pytest.mark.asyncio
async def test_save_task_event_with_task(
    task_manager: ClientTaskManager, sample_task: Task
):
    """Saving a Task stores it and records its task id and context id."""
    await task_manager.save_task_event(sample_task)

    stored = task_manager.get_task()
    assert stored == sample_task
    assert task_manager._task_id == sample_task.id
    assert task_manager._context_id == sample_task.context_id
68 |
69 |
@pytest.mark.asyncio
async def test_save_task_event_with_task_already_set_raises_error(
    task_manager: ClientTaskManager, sample_task: Task
):
    """Saving a second Task object into the same manager is rejected."""
    await task_manager.save_task_event(sample_task)

    already_set = 'Task is already set, create new manager for new tasks.'
    with pytest.raises(A2AClientInvalidArgsError, match=already_set):
        await task_manager.save_task_event(sample_task)
80 |
81 |
@pytest.mark.asyncio
async def test_save_task_event_with_status_update(
    task_manager: ClientTaskManager, sample_task: Task, sample_message: Message
):
    """A status update changes the state and records its message in history."""
    await task_manager.save_task_event(sample_task)
    completed_status = TaskStatus(
        state=TaskState.completed, message=sample_message
    )
    event = TaskStatusUpdateEvent(
        task_id=sample_task.id,
        context_id=sample_task.context_id,
        status=completed_status,
        final=True,
    )

    result = await task_manager.save_task_event(event)

    assert result.status.state == TaskState.completed
    assert result.history == [sample_message]
96 |
97 |
@pytest.mark.asyncio
async def test_save_task_event_with_artifact_update(
    task_manager: ClientTaskManager, sample_task: Task
):
    """An artifact update is delegated to ``append_artifact_to_task``."""
    await task_manager.save_task_event(sample_task)
    content = Part(root=TextPart(text='artifact content'))
    event = TaskArtifactUpdateEvent(
        task_id=sample_task.id,
        context_id=sample_task.context_id,
        artifact=Artifact(artifact_id='art1', parts=[content]),
    )

    # Patch the helper where the manager module looks it up.
    target = 'a2a.client.client_task_manager.append_artifact_to_task'
    with patch(target) as mock_append:
        result = await task_manager.save_task_event(event)
        mock_append.assert_called_once_with(result, event)
117 |
118 |
@pytest.mark.asyncio
async def test_save_task_event_creates_task_if_not_exists(
    task_manager: ClientTaskManager,
):
    """A status update for an unknown task makes the manager create a Task."""
    event = TaskStatusUpdateEvent(
        task_id='new_task',
        context_id='new_context',
        status=TaskStatus(state=TaskState.working),
        final=False,
    )

    created = await task_manager.save_task_event(event)

    assert created is not None
    assert created.id == 'new_task'
    assert created.status.state == TaskState.working
133 |
134 |
@pytest.mark.asyncio
async def test_process_with_task_event(
    task_manager: ClientTaskManager, sample_task: Task
):
    """``process`` forwards a Task event to ``save_task_event`` exactly once."""
    save_patch = patch.object(
        task_manager, 'save_task_event', new_callable=AsyncMock
    )
    with save_patch as mock_save:
        await task_manager.process(sample_task)
        mock_save.assert_called_once_with(sample_task)
144 |
145 |
@pytest.mark.asyncio
async def test_process_with_non_task_event(task_manager: ClientTaskManager):
    """``process`` must not call ``save_task_event`` for non-task events.

    ``save_task_event`` is awaited elsewhere in this module, so it is patched
    with an AsyncMock, as in ``test_process_with_task_event``. With a plain
    Mock, an erroneous call would fail with a "can't be awaited" TypeError
    instead of the intended ``assert_not_called`` failure.
    """
    with patch.object(
        task_manager, 'save_task_event', new_callable=AsyncMock
    ) as mock_save:
        non_task_event = 'not a task event'
        await task_manager.process(non_task_event)
        mock_save.assert_not_called()
154 |
155 |
def test_update_with_message(
    task_manager: ClientTaskManager, sample_task: Task, sample_message: Message
):
    """Merging a message into a task appends it to the task's history."""
    result = task_manager.update_with_message(sample_message, sample_task)
    assert result.history == [sample_message]
161 |
162 |
def test_update_with_message_moves_status_message(
    task_manager: ClientTaskManager, sample_task: Task, sample_message: Message
):
    """A pending status message moves into history ahead of the new message."""
    agent_status = Message(
        message_id='status_msg',
        role=Role.agent,
        parts=[Part(root=TextPart(text='Status'))],
    )
    sample_task.status.message = agent_status

    result = task_manager.update_with_message(sample_message, sample_task)

    assert result.history == [agent_status, sample_message]
    assert result.status.message is None
175 |
--------------------------------------------------------------------------------
/tests/client/test_legacy_client.py:
--------------------------------------------------------------------------------
1 | """Tests for the legacy client compatibility layer."""
2 |
3 | from unittest.mock import AsyncMock, MagicMock
4 |
5 | import httpx
6 | import pytest
7 |
8 | from a2a.client import A2AClient, A2AGrpcClient
9 | from a2a.types import (
10 | AgentCapabilities,
11 | AgentCard,
12 | Message,
13 | MessageSendParams,
14 | Part,
15 | Role,
16 | SendMessageRequest,
17 | Task,
18 | TaskQueryParams,
19 | TaskState,
20 | TaskStatus,
21 | TextPart,
22 | )
23 |
24 |
@pytest.fixture
def mock_httpx_client() -> AsyncMock:
    """Provide an AsyncMock restricted to the ``httpx.AsyncClient`` API."""
    client_double = AsyncMock(spec=httpx.AsyncClient)
    return client_double
28 |
29 |
@pytest.fixture
def mock_grpc_stub() -> AsyncMock:
    """Provide an async gRPC stub double with a MagicMock ``_channel``."""
    stub_double = AsyncMock()
    # Presumably the legacy gRPC client reads the stub's channel; TODO confirm.
    stub_double._channel = MagicMock()
    return stub_double
35 |
36 |
@pytest.fixture
def jsonrpc_agent_card() -> AgentCard:
    """Provide a streaming-capable agent card whose preferred transport is JSON-RPC."""
    streaming = AgentCapabilities(streaming=True)
    return AgentCard(
        name='Test Agent',
        description='A test agent',
        version='1.0.0',
        url='http://test.agent.com/rpc',
        preferred_transport='jsonrpc',
        capabilities=streaming,
        skills=[],
        default_input_modes=[],
        default_output_modes=[],
    )
50 |
51 |
@pytest.fixture
def grpc_agent_card() -> AgentCard:
    """Provide a streaming-capable agent card whose preferred transport is gRPC."""
    streaming = AgentCapabilities(streaming=True)
    return AgentCard(
        name='Test Agent',
        description='A test agent',
        version='1.0.0',
        url='http://test.agent.com/rpc',
        preferred_transport='grpc',
        capabilities=streaming,
        skills=[],
        default_input_modes=[],
        default_output_modes=[],
    )
65 |
66 |
@pytest.mark.asyncio
async def test_a2a_client_send_message(
    mock_httpx_client: AsyncMock, jsonrpc_agent_card: AgentCard
):
    """The legacy A2AClient exposes the transport's Task as ``root.result``."""
    client = A2AClient(
        httpx_client=mock_httpx_client, agent_card=jsonrpc_agent_card
    )

    # Replace the transport call so no HTTP request is attempted.
    completed_task = Task(
        id='task-123',
        context_id='ctx-456',
        status=TaskStatus(state=TaskState.completed),
    )
    client._transport.send_message = AsyncMock(return_value=completed_task)

    hello = Message(
        message_id='msg-123',
        role=Role.user,
        parts=[Part(root=TextPart(text='Hello'))],
    )
    request = SendMessageRequest(
        id='req-123', params=MessageSendParams(message=hello)
    )

    response = await client.send_message(request)

    assert response.root.result.id == 'task-123'
95 |
96 |
@pytest.mark.asyncio
async def test_a2a_grpc_client_get_task(
    mock_grpc_stub: AsyncMock, grpc_agent_card: AgentCard
):
    """Exercise ``get_task`` on the legacy ``A2AGrpcClient``.

    NOTE(review): ``client.get_task`` itself is replaced with an AsyncMock
    below. The test therefore only checks that the mock returns what it was
    told to, and none of A2AGrpcClient's own ``get_task`` logic runs.
    Consider mocking the underlying stub/transport call instead, as
    ``test_a2a_client_send_message`` does with ``client._transport``, once
    the call path inside A2AGrpcClient is confirmed.
    """
    client = A2AGrpcClient(grpc_stub=mock_grpc_stub, agent_card=grpc_agent_card)

    mock_response_task = Task(
        id='task-456',
        context_id='ctx-789',
        status=TaskStatus(state=TaskState.working),
    )

    # This overwrites the method under test (see NOTE in the docstring).
    client.get_task = AsyncMock(return_value=mock_response_task)

    params = TaskQueryParams(id='task-456')
    response = await client.get_task(params)

    assert response.id == 'task-456'
    client.get_task.assert_awaited_once_with(params)
116 |
--------------------------------------------------------------------------------
/tests/client/test_optionals.py:
--------------------------------------------------------------------------------
1 | """Tests for a2a.client.optionals module."""
2 |
3 | import importlib
4 | import sys
5 |
6 | from unittest.mock import patch
7 |
8 |
def test_channel_import_failure():
    """Without grpc importable, ``optionals.Channel`` falls back to None."""
    # A None entry in sys.modules makes importing that name raise
    # ImportError. patch.dict restores sys.modules when the block exits.
    blocked_modules = {'grpc': None, 'grpc.aio': None}
    with patch.dict('sys.modules', blocked_modules):
        # Drop any cached copy so the module body (and its guarded import)
        # runs again under the blocked modules.
        sys.modules.pop('a2a.client.optionals', None)

        optionals = importlib.import_module('a2a.client.optionals')
        assert optionals.Channel is None
17 |
--------------------------------------------------------------------------------
/tests/e2e/push_notifications/agent_app.py:
--------------------------------------------------------------------------------
1 | import httpx
2 |
3 | from fastapi import FastAPI
4 |
5 | from a2a.server.agent_execution import AgentExecutor, RequestContext
6 | from a2a.server.apps import A2ARESTFastAPIApplication
7 | from a2a.server.events import EventQueue
8 | from a2a.server.request_handlers import DefaultRequestHandler
9 | from a2a.server.tasks import (
10 | BasePushNotificationSender,
11 | InMemoryPushNotificationConfigStore,
12 | InMemoryTaskStore,
13 | TaskUpdater,
14 | )
15 | from a2a.types import (
16 | AgentCapabilities,
17 | AgentCard,
18 | AgentSkill,
19 | InvalidParamsError,
20 | Message,
21 | Task,
22 | )
23 | from a2a.utils import (
24 | new_agent_text_message,
25 | new_task,
26 | )
27 | from a2a.utils.errors import ServerError
28 |
29 |
def test_agent_card(url: str) -> AgentCard:
    """Build the AgentCard advertised by the push-notification test agent.

    The card enables streaming and push notifications, supports the
    authenticated extended card, and declares a single 'greeting' skill.
    """
    greeting_skill = AgentSkill(
        id='greeting',
        name='Greeting Agent',
        description='just greets the user',
        tags=['greeting'],
        examples=['Hello Agent!', 'How are you?'],
    )
    capabilities = AgentCapabilities(streaming=True, push_notifications=True)
    return AgentCard(
        name='Test Agent',
        description='Just a test agent',
        url=url,
        version='1.0.0',
        default_input_modes=['text'],
        default_output_modes=['text'],
        capabilities=capabilities,
        skills=[greeting_skill],
        supports_authenticated_extended_card=True,
    )
51 |
52 |
class TestAgent:
    """Scripted agent used to exercise push notification delivery."""

    async def invoke(
        self, updater: TaskUpdater, msg: Message, task: Task
    ) -> None:
        """Move ``task`` to its next state based on the text in ``msg``.

        Scripted replies:
          * 'Hello Agent!'  -> complete with 'Hello User!'
          * 'How are you?'  -> require input with 'Good! How are you?'
          * 'Good'          -> complete with 'Amazing'

        The second and third entries form a two-turn flow. Anything else,
        including messages that are not exactly one text part, fails the task
        with 'Unsupported message.'.
        """

        def reply(text: str) -> Message:
            # Every reply is an agent text message bound to this task.
            return new_agent_text_message(text, task.context_id, task.id)

        has_single_text_part = (
            bool(msg.parts)
            and len(msg.parts) == 1
            and msg.parts[0].root.kind == 'text'
        )
        if not has_single_text_part:
            await updater.failed(reply('Unsupported message.'))
            return

        text_message = msg.parts[0].root.text
        if text_message == 'Hello Agent!':
            await updater.complete(reply('Hello User!'))
        elif text_message == 'How are you?':
            await updater.requires_input(reply('Good! How are you?'))
        elif text_message == 'Good':
            await updater.complete(reply('Amazing'))
        else:
            await updater.failed(reply('Unsupported message.'))
98 |
99 |
class TestAgentExecutor(AgentExecutor):
    """AgentExecutor that hands every request to a ``TestAgent``."""

    def __init__(self) -> None:
        self.agent = TestAgent()

    async def execute(
        self,
        context: RequestContext,
        event_queue: EventQueue,
    ) -> None:
        """Run the agent for ``context``, creating and publishing a task if needed.

        Raises:
            ServerError: wrapping ``InvalidParamsError`` when the request
                carries no message.
        """
        if not context.message:
            raise ServerError(error=InvalidParamsError(message='No message'))

        task = context.current_task
        if not task:
            # First turn: build a task from the message and publish it.
            task = new_task(context.message)
            await event_queue.enqueue_event(task)

        task_updater = TaskUpdater(event_queue, task.id, task.context_id)
        await self.agent.invoke(task_updater, context.message, task)

    async def cancel(
        self, context: RequestContext, event_queue: EventQueue
    ) -> None:
        """Cancellation is unsupported by this test executor."""
        raise NotImplementedError('cancel not supported')
126 |
127 |
def create_agent_app(
    url: str, notification_client: httpx.AsyncClient
) -> FastAPI:
    """Build the HTTP+REST FastAPI application hosting the test agent at ``url``.

    Push notifications are sent with ``notification_client``. One in-memory
    config store instance is passed to both the request handler and the
    push sender.
    """
    config_store = InMemoryPushNotificationConfigStore()
    sender = BasePushNotificationSender(
        httpx_client=notification_client,
        config_store=config_store,
    )
    handler = DefaultRequestHandler(
        agent_executor=TestAgentExecutor(),
        task_store=InMemoryTaskStore(),
        push_config_store=config_store,
        push_sender=sender,
    )
    a2a_app = A2ARESTFastAPIApplication(
        agent_card=test_agent_card(url),
        http_handler=handler,
    )
    return a2a_app.build()
146 |
--------------------------------------------------------------------------------
/tests/e2e/push_notifications/notifications_app.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 |
3 | from typing import Annotated
4 |
5 | from fastapi import FastAPI, HTTPException, Path, Request
6 | from pydantic import BaseModel, ValidationError
7 |
8 | from a2a.types import Task
9 |
10 |
class Notification(BaseModel):
    """A received push notification: the task payload plus its token."""

    # Task snapshot carried in the notification's JSON body.
    task: Task
    # Value of the ``x-a2a-notification-token`` request header.
    token: str
16 |
17 |
def create_notifications_app() -> FastAPI:
    """Creates a simple push notification ingesting HTTP+REST application.

    Received notifications are kept in an in-memory dict keyed by task id and
    guarded by an ``asyncio.Lock``.

    Endpoints:
        POST /notifications: store a Task payload together with its
            ``x-a2a-notification-token`` header. Responds 400 on a missing
            token, malformed JSON, or a payload that is not a valid Task.
        GET /tasks/{task_id}/notifications: list stored notifications.
        GET /health: liveness probe.
    """
    app = FastAPI()
    store_lock = asyncio.Lock()
    store: dict[str, list[Notification]] = {}

    @app.post('/notifications')
    async def add_notification(request: Request):
        """Endpoint for ingesting notifications from agents. It receives a JSON
        payload and stores it in-memory.
        """
        token = request.headers.get('x-a2a-notification-token')
        if not token:
            raise HTTPException(
                status_code=400,
                detail='Missing "x-a2a-notification-token" header.',
            )
        try:
            payload = await request.json()
        except ValueError as e:
            # json.JSONDecodeError subclasses ValueError. A malformed body is
            # a client error, not an unhandled 500.
            raise HTTPException(status_code=400, detail=str(e)) from e
        try:
            task = Task.model_validate(payload)
        except ValidationError as e:
            raise HTTPException(status_code=400, detail=str(e)) from e

        async with store_lock:
            store.setdefault(task.id, []).append(
                Notification(task=task, token=token)
            )
        return {
            'status': 'received',
        }

    @app.get('/tasks/{task_id}/notifications')
    async def list_notifications_by_task(
        task_id: Annotated[
            str, Path(title='The ID of the task to list the notifications for.')
        ],
    ):
        """Helper endpoint for retrieving ingested notifications for a given task."""
        async with store_lock:
            # Copy under the lock: the response is serialized after the lock
            # is released, while add_notification may still append.
            notifications = list(store.get(task_id, []))
        return {'notifications': notifications}

    @app.get('/health')
    def health_check():
        """Helper endpoint for checking if the server is up."""
        return {'status': 'ok'}

    return app
70 |
--------------------------------------------------------------------------------
/tests/e2e/push_notifications/utils.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 | import socket
3 | import time
4 |
5 | from multiprocessing import Process
6 |
7 | import httpx
8 | import uvicorn
9 |
10 |
def find_free_port():
    """Finds and returns an available ephemeral localhost port."""
    # Binding to port 0 asks the OS to pick an unused port.
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with probe:
        probe.bind(('127.0.0.1', 0))
        _host, port = probe.getsockname()
    return port
16 |
17 |
def run_server(app, host, port) -> None:
    """Serve ``app`` on ``host``:``port`` via uvicorn with log level 'warning'."""
    uvicorn.run(app, log_level='warning', host=host, port=port)
21 |
22 |
def wait_for_server_ready(url: str, timeout: int = 10) -> None:
    """Polls the provided URL endpoint until the server is up.

    Args:
        url: Endpoint to poll; the server counts as ready on HTTP 200.
        timeout: Seconds to keep polling before giving up.

    Raises:
        TimeoutError: If no 200 response is seen within ``timeout`` seconds.
    """
    # monotonic() is unaffected by wall-clock adjustments during the wait.
    deadline = time.monotonic() + timeout
    with httpx.Client() as client:
        while True:
            # A server that is still starting may refuse, reset, or stall
            # connections. Any transport-level failure means "not ready yet".
            with contextlib.suppress(httpx.TransportError):
                response = client.get(url)
                if response.status_code == 200:
                    return
            if time.monotonic() > deadline:
                raise TimeoutError(
                    f'Server at {url} failed to start after {timeout}s'
                )
            time.sleep(0.1)
37 |
38 |
def create_app_process(app, host, port) -> Process:
    """Build, without starting, a daemon Process running ``run_server(app, host, port)``."""
    server_args = (app, host, port)
    return Process(target=run_server, args=server_args, daemon=True)
46 |
--------------------------------------------------------------------------------
/tests/extensions/test_common.py:
--------------------------------------------------------------------------------
1 | from a2a.extensions.common import (
2 | find_extension_by_uri,
3 | get_requested_extensions,
4 | )
5 | from a2a.types import AgentCapabilities, AgentCard, AgentExtension
6 |
7 |
def test_get_requested_extensions():
    """Comma-separated values are split, trimmed, and collected into a set.

    Empty entries (e.g. from ',,') are dropped.
    """
    cases = [
        ([], set()),
        (['foo'], {'foo'}),
        (['foo', 'bar'], {'foo', 'bar'}),
        (['foo, bar'], {'foo', 'bar'}),
        (['foo,bar'], {'foo', 'bar'}),
        (['foo', 'bar,baz'], {'foo', 'bar', 'baz'}),
        (['foo,, bar', 'baz'], {'foo', 'bar', 'baz'}),
        ([' foo , bar ', 'baz'], {'foo', 'bar', 'baz'}),
    ]
    for raw_values, expected in cases:
        assert get_requested_extensions(raw_values) == expected, raw_values
25 |
26 |
def test_find_extension_by_uri():
    """Lookup returns the extension with the matching uri, or None if absent."""
    foo_ext = AgentExtension(uri='foo', description='The Foo extension')
    bar_ext = AgentExtension(uri='bar', description='The Bar extension')
    card = AgentCard(
        name='Test Agent',
        description='Test Agent Description',
        version='1.0',
        url='http://test.com',
        skills=[],
        default_input_modes=['text/plain'],
        default_output_modes=['text/plain'],
        capabilities=AgentCapabilities(extensions=[foo_ext, bar_ext]),
    )

    for uri, expected in (('foo', foo_ext), ('bar', bar_ext)):
        assert find_extension_by_uri(card, uri) == expected
    assert find_extension_by_uri(card, 'baz') is None
44 |
45 |
def test_find_extension_by_uri_no_extensions():
    """A card whose capabilities list no extensions never matches a uri."""
    no_extensions = AgentCapabilities(extensions=None)
    card = AgentCard(
        name='Test Agent',
        description='Test Agent Description',
        version='1.0',
        url='http://test.com',
        skills=[],
        default_input_modes=['text/plain'],
        default_output_modes=['text/plain'],
        capabilities=no_extensions,
    )

    found = find_extension_by_uri(card, 'foo')
    assert found is None
59 |
--------------------------------------------------------------------------------
/tests/server/apps/jsonrpc/test_fastapi_app.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 | from unittest.mock import MagicMock
3 |
4 | import pytest
5 |
6 | from a2a.server.apps.jsonrpc import fastapi_app
7 | from a2a.server.apps.jsonrpc.fastapi_app import A2AFastAPIApplication
8 | from a2a.server.request_handlers.request_handler import (
9 | RequestHandler, # For mock spec
10 | )
11 | from a2a.types import AgentCard # For mock spec
12 |
13 |
14 | # --- A2AFastAPIApplication Tests ---
15 |
16 |
class TestA2AFastAPIApplicationOptionalDeps:
    """Checks A2AFastAPIApplication construction with fastapi present or absent.

    Running tests in this class requires the optional dependency fastapi to
    be present in the test environment.
    """

    @pytest.fixture(scope='class', autouse=True)
    def ensure_pkg_fastapi_is_present(self):
        """Fail (rather than skip) the class if fastapi cannot be imported."""
        try:
            import fastapi as _fastapi  # noqa: F401
        except ImportError:
            pytest.fail(
                f'Running tests in {self.__class__.__name__} requires'
                ' the optional dependency fastapi to be present in the test'
                ' environment. Run `uv sync --dev ...` before running the test'
                ' suite.'
            )

    @pytest.fixture(scope='class')
    def mock_app_params(self) -> dict:
        """Constructor kwargs for A2AFastAPIApplication built from spec'd mocks."""
        # Mock http_handler
        mock_handler = MagicMock(spec=RequestHandler)
        # Mock agent_card with essential attributes accessed in __init__
        mock_agent_card = MagicMock(spec=AgentCard)
        # Ensure 'url' attribute exists on the mock_agent_card, as it's accessed
        # in __init__
        mock_agent_card.url = 'http://example.com'
        # Ensure 'supports_authenticated_extended_card' attribute exists
        mock_agent_card.supports_authenticated_extended_card = False
        return {'agent_card': mock_agent_card, 'http_handler': mock_handler}

    @pytest.fixture
    def mark_pkg_fastapi_not_installed(self, monkeypatch: pytest.MonkeyPatch):
        """Make ``fastapi_app`` believe fastapi is missing for a single test.

        This fixture is function-scoped and ``monkeypatch`` restores the flag
        after each test. The patched value therefore cannot leak into other
        tests in this class, whatever order they run in.
        """
        monkeypatch.setattr(fastapi_app, '_package_fastapi_installed', False)

    def test_create_a2a_fastapi_app_with_present_deps_succeeds(
        self, mock_app_params: dict
    ):
        """With fastapi importable, construction must not raise ImportError."""
        try:
            _app = A2AFastAPIApplication(**mock_app_params)
        except ImportError:
            pytest.fail(
                'With the fastapi package present, creating an'
                ' A2AFastAPIApplication instance should not raise ImportError'
            )

    def test_create_a2a_fastapi_app_with_missing_deps_raises_importerror(
        self,
        mock_app_params: dict,
        mark_pkg_fastapi_not_installed: Any,
    ):
        """With the installed flag cleared, construction raises ImportError."""
        with pytest.raises(
            ImportError,
            match=(
                'The `fastapi` package is required to use the'
                ' `A2AFastAPIApplication`'
            ),
        ):
            _app = A2AFastAPIApplication(**mock_app_params)
77 |
78 |
if __name__ == '__main__':
    # Propagate pytest's exit status so running this file directly fails
    # (non-zero exit) when tests fail.
    raise SystemExit(pytest.main([__file__]))
81 |
--------------------------------------------------------------------------------
/tests/server/apps/jsonrpc/test_starlette_app.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 | from unittest.mock import MagicMock
3 |
4 | import pytest
5 |
6 | from a2a.server.apps.jsonrpc import starlette_app
7 | from a2a.server.apps.jsonrpc.starlette_app import A2AStarletteApplication
8 | from a2a.server.request_handlers.request_handler import (
9 | RequestHandler, # For mock spec
10 | )
11 | from a2a.types import AgentCard # For mock spec
12 |
13 |
14 | # --- A2AStarletteApplication Tests ---
15 |
16 |
class TestA2AStarletteApplicationOptionalDeps:
    """Checks A2AStarletteApplication construction with its deps present or absent.

    Running tests in this class requires optional dependencies starlette and
    sse-starlette to be present in the test environment.
    """

    @pytest.fixture(scope='class', autouse=True)
    def ensure_pkg_starlette_is_present(self):
        """Fail (rather than skip) the class if starlette/sse-starlette are missing."""
        try:
            import sse_starlette as _sse_starlette  # noqa: F401
            import starlette as _starlette  # noqa: F401
        except ImportError:
            pytest.fail(
                f'Running tests in {self.__class__.__name__} requires'
                ' optional dependencies starlette and sse-starlette to be'
                ' present in the test environment. Run `uv sync --dev ...`'
                ' before running the test suite.'
            )

    @pytest.fixture(scope='class')
    def mock_app_params(self) -> dict:
        """Constructor kwargs for A2AStarletteApplication built from spec'd mocks."""
        # Mock http_handler
        mock_handler = MagicMock(spec=RequestHandler)
        # Mock agent_card with essential attributes accessed in __init__
        mock_agent_card = MagicMock(spec=AgentCard)
        # Ensure 'url' attribute exists on the mock_agent_card, as it's accessed
        # in __init__
        mock_agent_card.url = 'http://example.com'
        # Ensure 'supports_authenticated_extended_card' attribute exists
        mock_agent_card.supports_authenticated_extended_card = False
        return {'agent_card': mock_agent_card, 'http_handler': mock_handler}

    @pytest.fixture
    def mark_pkg_starlette_not_installed(
        self, monkeypatch: pytest.MonkeyPatch
    ):
        """Make ``starlette_app`` believe starlette is missing for a single test.

        This fixture is function-scoped and ``monkeypatch`` restores the flag
        after each test. The patched value therefore cannot leak into other
        tests in this class, whatever order they run in.
        """
        monkeypatch.setattr(
            starlette_app, '_package_starlette_installed', False
        )

    def test_create_a2a_starlette_app_with_present_deps_succeeds(
        self, mock_app_params: dict
    ):
        """With both packages importable, construction must not raise ImportError."""
        try:
            _app = A2AStarletteApplication(**mock_app_params)
        except ImportError:
            pytest.fail(
                'With packages starlette and sse-starlette present, creating an'
                ' A2AStarletteApplication instance should not raise ImportError'
            )

    def test_create_a2a_starlette_app_with_missing_deps_raises_importerror(
        self,
        mock_app_params: dict,
        mark_pkg_starlette_not_installed: Any,
    ):
        """With the installed flag cleared, construction raises ImportError."""
        with pytest.raises(
            ImportError,
            match='Packages `starlette` and `sse-starlette` are required',
        ):
            _app = A2AStarletteApplication(**mock_app_params)
79 |
80 |
if __name__ == '__main__':
    # Propagate pytest's exit status so running this file directly fails
    # (non-zero exit) when tests fail.
    raise SystemExit(pytest.main([__file__]))
83 |
--------------------------------------------------------------------------------
/tests/server/events/test_inmemory_queue_manager.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 |
3 | from unittest.mock import MagicMock
4 |
5 | import pytest
6 |
7 | from a2a.server.events import InMemoryQueueManager
8 | from a2a.server.events.event_queue import EventQueue
9 | from a2a.server.events.queue_manager import (
10 | NoTaskQueue,
11 | TaskQueueExists,
12 | )
13 |
14 |
class TestInMemoryQueueManager:
    """Unit tests for InMemoryQueueManager's add/get/tap/close bookkeeping."""

    @pytest.fixture
    def queue_manager(self):
        """A fresh InMemoryQueueManager per test, so no state is shared."""
        return InMemoryQueueManager()

    @pytest.fixture
    def event_queue(self):
        """An EventQueue double whose ``tap()`` returns the same mock."""
        mock_queue = MagicMock(spec=EventQueue)
        # Returning itself lets tests compare tap() results to the fixture.
        mock_queue.tap.return_value = mock_queue
        return mock_queue

    @pytest.mark.asyncio
    async def test_init(self, queue_manager):
        """A new manager starts with no queues and holds an asyncio.Lock."""
        assert queue_manager._task_queue == {}
        assert isinstance(queue_manager._lock, asyncio.Lock)

    @pytest.mark.asyncio
    async def test_add_new_queue(self, queue_manager, event_queue):
        """``add()`` registers the queue under its task id."""
        tid = 'test_task_id'
        await queue_manager.add(tid, event_queue)
        assert queue_manager._task_queue[tid] == event_queue

    @pytest.mark.asyncio
    async def test_add_existing_queue(self, queue_manager, event_queue):
        """A second ``add()`` for the same task id raises TaskQueueExists."""
        tid = 'test_task_id'
        await queue_manager.add(tid, event_queue)

        with pytest.raises(TaskQueueExists):
            await queue_manager.add(tid, event_queue)

    @pytest.mark.asyncio
    async def test_get_existing_queue(self, queue_manager, event_queue):
        """``get()`` returns the queue previously registered for the id."""
        tid = 'test_task_id'
        await queue_manager.add(tid, event_queue)

        fetched = await queue_manager.get(tid)
        assert fetched == event_queue

    @pytest.mark.asyncio
    async def test_get_nonexistent_queue(self, queue_manager):
        """``get()`` for an unknown task id yields None."""
        missing = await queue_manager.get('nonexistent_task_id')
        assert missing is None

    @pytest.mark.asyncio
    async def test_tap_existing_queue(self, queue_manager, event_queue):
        """``tap()`` calls the stored queue's ``tap()`` and returns its result."""
        tid = 'test_task_id'
        await queue_manager.add(tid, event_queue)

        tapped = await queue_manager.tap(tid)
        assert tapped == event_queue
        event_queue.tap.assert_called_once()

    @pytest.mark.asyncio
    async def test_tap_nonexistent_queue(self, queue_manager):
        """``tap()`` for an unknown task id yields None."""
        tapped = await queue_manager.tap('nonexistent_task_id')
        assert tapped is None

    @pytest.mark.asyncio
    async def test_close_existing_queue(self, queue_manager, event_queue):
        """``close()`` removes the task id from the manager."""
        tid = 'test_task_id'
        await queue_manager.add(tid, event_queue)

        await queue_manager.close(tid)
        assert tid not in queue_manager._task_queue

    @pytest.mark.asyncio
    async def test_close_nonexistent_queue(self, queue_manager):
        """``close()`` for an unknown task id raises NoTaskQueue."""
        with pytest.raises(NoTaskQueue):
            await queue_manager.close('nonexistent_task_id')

    @pytest.mark.asyncio
    async def test_create_or_tap_new_queue(self, queue_manager):
        """For an unknown id, ``create_or_tap()`` creates and registers an EventQueue."""
        tid = 'test_task_id'

        created = await queue_manager.create_or_tap(tid)
        assert isinstance(created, EventQueue)
        assert queue_manager._task_queue[tid] == created

    @pytest.mark.asyncio
    async def test_create_or_tap_existing_queue(
        self, queue_manager, event_queue
    ):
        """For a known id, ``create_or_tap()`` taps the existing queue."""
        tid = 'test_task_id'
        await queue_manager.add(tid, event_queue)

        tapped = await queue_manager.create_or_tap(tid)

        assert tapped == event_queue
        event_queue.tap.assert_called_once()

    @pytest.mark.asyncio
    async def test_concurrency(self, queue_manager):
        """Concurrent ``add()`` and ``get()`` calls for distinct ids all succeed."""
        ids = [f'task_{n}' for n in range(10)]

        async def register(tid):
            await queue_manager.add(tid, EventQueue())
            return tid

        async def lookup(tid):
            return await queue_manager.get(tid)

        # Register all ids concurrently; each coroutine echoes its id back.
        registered = await asyncio.gather(*(register(tid) for tid in ids))
        assert set(registered) == set(ids)

        # Look them all up concurrently; none may be missing.
        found = await asyncio.gather(*(lookup(tid) for tid in ids))
        assert all(q is not None for q in found)

        assert all(tid in queue_manager._task_queue for tid in ids)
151 |
--------------------------------------------------------------------------------
/tests/server/tasks/test_inmemory_task_store.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | import pytest
4 |
5 | from a2a.server.tasks import InMemoryTaskStore
6 | from a2a.types import Task
7 |
8 |
# Keyword arguments for a minimal submitted Task, shared by the tests below.
MINIMAL_TASK: dict[str, Any] = {
    'id': 'task-abc',  # key used for store lookups
    'context_id': 'session-xyz',
    'status': {'state': 'submitted'},
    'kind': 'task',
}
15 |
16 |
@pytest.mark.asyncio
async def test_in_memory_task_store_save_and_get() -> None:
    """A saved task is read back unchanged when fetched by its id."""
    original = Task(**MINIMAL_TASK)
    store = InMemoryTaskStore()
    await store.save(original)
    assert await store.get(MINIMAL_TASK['id']) == original
25 |
26 |
@pytest.mark.asyncio
async def test_in_memory_task_store_get_nonexistent() -> None:
    """Looking up an id that was never saved yields None."""
    assert await InMemoryTaskStore().get('nonexistent') is None
33 |
34 |
@pytest.mark.asyncio
async def test_in_memory_task_store_delete() -> None:
    """After delete(), the task can no longer be retrieved."""
    store = InMemoryTaskStore()
    task_id = MINIMAL_TASK['id']
    await store.save(Task(**MINIMAL_TASK))
    await store.delete(task_id)
    assert await store.get(task_id) is None
44 |
45 |
@pytest.mark.asyncio
async def test_in_memory_task_store_delete_nonexistent() -> None:
    """Deleting an id that was never saved is a silent no-op.

    The call must not raise, and the store must still report the id as
    absent afterwards (the original test asserted nothing at all).
    """
    store = InMemoryTaskStore()
    await store.delete('nonexistent')
    assert await store.get('nonexistent') is None
51 |
--------------------------------------------------------------------------------
/tests/server/tasks/test_push_notification_sender.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | from unittest.mock import AsyncMock, MagicMock, patch
4 |
5 | import httpx
6 |
7 | from a2a.server.tasks.base_push_notification_sender import (
8 | BasePushNotificationSender,
9 | )
10 | from a2a.types import (
11 | PushNotificationConfig,
12 | Task,
13 | TaskState,
14 | TaskStatus,
15 | )
16 |
17 |
def create_sample_task(task_id='task123', status_state=TaskState.completed):
    """Build a minimal ``Task`` in context ``ctx456`` with the given state."""
    status = TaskStatus(state=status_state)
    return Task(id=task_id, context_id='ctx456', status=status)
24 |
25 |
def create_sample_push_config(
    url='http://example.com/callback', config_id='cfg1', token=None
):
    """Build a ``PushNotificationConfig`` targeting *url*, optionally with a token."""
    fields = {'id': config_id, 'url': url, 'token': token}
    return PushNotificationConfig(**fields)
30 |
31 |
class TestBasePushNotificationSender(unittest.IsolatedAsyncioTestCase):
    """Tests for ``BasePushNotificationSender``.

    The HTTP client and the push-config store are mocks, so each test checks
    which configs are looked up for a task and which POST requests result.
    """

    def setUp(self):
        self.mock_httpx_client = AsyncMock(spec=httpx.AsyncClient)
        self.mock_config_store = AsyncMock()
        self.sender = BasePushNotificationSender(
            httpx_client=self.mock_httpx_client,
            config_store=self.mock_config_store,
        )

    def _mock_ok_response(self) -> MagicMock:
        """Make ``post`` resolve to a 200 response and return that response.

        Only ``post`` is awaited; the response object itself is not, so a
        ``MagicMock`` is the accurate stand-in (as in the HTTP-error test).
        """
        response = MagicMock(spec=httpx.Response)
        response.status_code = 200
        self.mock_httpx_client.post.return_value = response
        return response

    def test_constructor_stores_client_and_config_store(self):
        self.assertEqual(self.sender._client, self.mock_httpx_client)
        self.assertEqual(self.sender._config_store, self.mock_config_store)

    async def test_send_notification_success(self):
        task_id = 'task_send_success'
        task_data = create_sample_task(task_id=task_id)
        config = create_sample_push_config(url='http://notify.me/here')
        self.mock_config_store.get_info.return_value = [config]
        mock_response = self._mock_ok_response()

        await self.sender.send_notification(task_data)

        # Must actually be *called*: a bare attribute access asserts nothing.
        self.mock_config_store.get_info.assert_awaited_once_with(task_id)
        # Without a token, the task is POSTed with no extra headers.
        self.mock_httpx_client.post.assert_awaited_once_with(
            config.url,
            json=task_data.model_dump(mode='json', exclude_none=True),
            headers=None,
        )
        mock_response.raise_for_status.assert_called_once()

    async def test_send_notification_with_token_success(self):
        task_id = 'task_send_success'
        task_data = create_sample_task(task_id=task_id)
        config = create_sample_push_config(
            url='http://notify.me/here', token='unique_token'
        )
        self.mock_config_store.get_info.return_value = [config]
        mock_response = self._mock_ok_response()

        await self.sender.send_notification(task_data)

        self.mock_config_store.get_info.assert_awaited_once_with(task_id)
        # A configured token is sent in the X-A2A-Notification-Token header.
        self.mock_httpx_client.post.assert_awaited_once_with(
            config.url,
            json=task_data.model_dump(mode='json', exclude_none=True),
            headers={'X-A2A-Notification-Token': 'unique_token'},
        )
        mock_response.raise_for_status.assert_called_once()

    async def test_send_notification_no_config(self):
        task_id = 'task_send_no_config'
        task_data = create_sample_task(task_id=task_id)
        self.mock_config_store.get_info.return_value = []

        await self.sender.send_notification(task_data)

        self.mock_config_store.get_info.assert_awaited_once_with(task_id)
        self.mock_httpx_client.post.assert_not_called()

    @patch('a2a.server.tasks.base_push_notification_sender.logger')
    async def test_send_notification_http_status_error(
        self, mock_logger: MagicMock
    ):
        task_id = 'task_send_http_err'
        task_data = create_sample_task(task_id=task_id)
        config = create_sample_push_config(url='http://notify.me/http_error')
        self.mock_config_store.get_info.return_value = [config]

        mock_response = MagicMock(spec=httpx.Response)
        mock_response.status_code = 404
        mock_response.text = 'Not Found'
        http_error = httpx.HTTPStatusError(
            'Not Found', request=MagicMock(), response=mock_response
        )
        self.mock_httpx_client.post.side_effect = http_error

        # The HTTP failure is logged rather than propagated to the caller.
        await self.sender.send_notification(task_data)

        self.mock_config_store.get_info.assert_awaited_once_with(task_id)
        self.mock_httpx_client.post.assert_awaited_once_with(
            config.url,
            json=task_data.model_dump(mode='json', exclude_none=True),
            headers=None,
        )
        mock_logger.exception.assert_called_once()

    async def test_send_notification_multiple_configs(self):
        task_id = 'task_multiple_configs'
        task_data = create_sample_task(task_id=task_id)
        config1 = create_sample_push_config(
            url='http://notify.me/cfg1', config_id='cfg1'
        )
        config2 = create_sample_push_config(
            url='http://notify.me/cfg2', config_id='cfg2'
        )
        self.mock_config_store.get_info.return_value = [config1, config2]
        mock_response = self._mock_ok_response()

        await self.sender.send_notification(task_data)

        self.mock_config_store.get_info.assert_awaited_once_with(task_id)
        self.assertEqual(self.mock_httpx_client.post.call_count, 2)

        # Every config receives its own POST with the same payload.
        for config in (config1, config2):
            self.mock_httpx_client.post.assert_any_call(
                config.url,
                json=task_data.model_dump(mode='json', exclude_none=True),
                headers=None,
            )
        # Both POSTs return the same response object, so its status must
        # have been checked twice. (Previously this line *assigned* 2.)
        self.assertEqual(mock_response.raise_for_status.call_count, 2)
161 |
--------------------------------------------------------------------------------
/tests/server/test_models.py:
--------------------------------------------------------------------------------
1 | """Tests for a2a.server.models module."""
2 |
3 | from unittest.mock import MagicMock
4 |
5 | from sqlalchemy.orm import DeclarativeBase
6 |
7 | from a2a.server.models import (
8 | PydanticListType,
9 | PydanticType,
10 | create_push_notification_config_model,
11 | create_task_model,
12 | )
13 | from a2a.types import Artifact, TaskState, TaskStatus, TextPart
14 |
15 |
class TestPydanticType:
    """Bind/result conversion checks for the ``PydanticType`` column type."""

    def test_process_bind_param_with_pydantic_model(self):
        column_type = PydanticType(TaskStatus)
        bound = column_type.process_bind_param(
            TaskStatus(state=TaskState.working), MagicMock()
        )
        # TaskStatus may carry further optional fields; only these two are
        # pinned here.
        assert bound['state'] == 'working'
        assert bound['message'] is None

    def test_process_bind_param_with_none(self):
        column_type = PydanticType(TaskStatus)
        assert column_type.process_bind_param(None, MagicMock()) is None

    def test_process_result_value(self):
        column_type = PydanticType(TaskStatus)
        loaded = column_type.process_result_value(
            {'state': 'completed', 'message': None}, MagicMock()
        )
        assert isinstance(loaded, TaskStatus)
        assert loaded.state == 'completed'
45 |
46 |
class TestPydanticListType:
    """Bind/result conversion checks for the ``PydanticListType`` column type."""

    def test_process_bind_param_with_list(self):
        column_type = PydanticListType(Artifact)
        artifacts = [
            Artifact(
                artifact_id=ident,
                parts=[TextPart(type='text', text=word)],
            )
            for ident, word in (('1', 'Hello'), ('2', 'World'))
        ]

        serialized = column_type.process_bind_param(artifacts, MagicMock())

        assert len(serialized) == 2
        # Serialization uses JSON mode, hence the camelCase key.
        assert [row['artifactId'] for row in serialized] == ['1', '2']

    def test_process_result_value_with_list(self):
        column_type = PydanticListType(Artifact)
        rows = [
            {'artifact_id': ident, 'parts': [{'type': 'text', 'text': word}]}
            for ident, word in (('1', 'Hello'), ('2', 'World'))
        ]

        loaded = column_type.process_result_value(rows, MagicMock())

        assert len(loaded) == 2
        assert all(isinstance(item, Artifact) for item in loaded)
        assert [item.artifact_id for item in loaded] == ['1', '2']
80 |
81 |
def test_create_task_model():
    """Dynamically created task models use the table name they are given.

    Each call uses a distinct table name, and the generated class name is
    expected to embed it.
    """

    # Create a fresh base to avoid table conflicts.
    class TestBase(DeclarativeBase):
        pass

    # Both calls pass an explicit table name (no default is exercised here).
    first_model = create_task_model('test_tasks_1', TestBase)
    assert first_model.__tablename__ == 'test_tasks_1'
    assert first_model.__name__ == 'TaskModel_test_tasks_1'

    second_model = create_task_model('test_tasks_2', TestBase)
    assert second_model.__tablename__ == 'test_tasks_2'
    assert second_model.__name__ == 'TaskModel_test_tasks_2'
98 |
99 |
def test_create_push_notification_config_model():
    """Push-config models are created with the requested table name.

    Both models must carry their table name and embed it in the generated
    class name.
    """

    # Create a fresh base to avoid table conflicts.
    class TestBase(DeclarativeBase):
        pass

    # Both calls pass an explicit table name (no default is exercised here).
    first_model = create_push_notification_config_model(
        'test_push_configs_1', TestBase
    )
    assert first_model.__tablename__ == 'test_push_configs_1'
    assert 'test_push_configs_1' in first_model.__name__

    second_model = create_push_notification_config_model(
        'test_push_configs_2', TestBase
    )
    assert second_model.__tablename__ == 'test_push_configs_2'
    assert 'test_push_configs_2' in second_model.__name__
119 |
--------------------------------------------------------------------------------
/tests/utils/test_artifact.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import uuid
3 |
4 | from unittest.mock import patch
5 |
6 | from a2a.types import DataPart, Part, TextPart
7 | from a2a.utils.artifact import (
8 | new_artifact,
9 | new_data_artifact,
10 | new_text_artifact,
11 | )
12 |
13 |
class TestArtifact(unittest.TestCase):
    """Tests for the ``new_artifact`` / ``new_text_artifact`` /
    ``new_data_artifact`` factory helpers."""

    @patch('uuid.uuid4')
    def test_new_artifact_generates_id(self, mock_uuid4):
        fixed = uuid.UUID('abcdef12-1234-5678-1234-567812345678')
        mock_uuid4.return_value = fixed
        created = new_artifact(parts=[], name='test_artifact')
        self.assertEqual(created.artifact_id, str(fixed))

    def test_new_artifact_assigns_parts_name_description(self):
        expected_parts = [Part(root=TextPart(text='Sample text'))]
        label = 'My Artifact'
        blurb = 'This is a test artifact.'
        created = new_artifact(
            parts=expected_parts, name=label, description=blurb
        )
        self.assertEqual(created.parts, expected_parts)
        self.assertEqual(created.name, label)
        self.assertEqual(created.description, blurb)

    def test_new_artifact_empty_description_if_not_provided(self):
        created = new_artifact(
            parts=[Part(root=TextPart(text='Another sample'))],
            name='Artifact_No_Desc',
        )
        self.assertEqual(created.description, '')

    def test_new_text_artifact_creates_single_text_part(self):
        created = new_text_artifact(
            text='This is a text artifact.', name='Text_Artifact'
        )
        self.assertEqual(len(created.parts), 1)
        self.assertIsInstance(created.parts[0].root, TextPart)

    def test_new_text_artifact_part_contains_provided_text(self):
        greeting = 'Hello, world!'
        created = new_text_artifact(text=greeting, name='Greeting_Artifact')
        self.assertEqual(created.parts[0].root.text, greeting)

    def test_new_text_artifact_assigns_name_description(self):
        label = 'Named_Text_Artifact'
        blurb = 'Description for text artifact.'
        created = new_text_artifact(
            text='Some content.', name=label, description=blurb
        )
        self.assertEqual(created.name, label)
        self.assertEqual(created.description, blurb)

    def test_new_data_artifact_creates_single_data_part(self):
        created = new_data_artifact(
            data={'key': 'value', 'number': 123}, name='Data_Artifact'
        )
        self.assertEqual(len(created.parts), 1)
        self.assertIsInstance(created.parts[0].root, DataPart)

    def test_new_data_artifact_part_contains_provided_data(self):
        payload = {'content': 'test_data', 'is_valid': True}
        created = new_data_artifact(
            data=payload, name='Structured_Data_Artifact'
        )
        first_part = created.parts[0].root
        self.assertIsInstance(first_part, DataPart)
        # Compare against the DataPart's ``data`` attribute, not the part.
        self.assertEqual(first_part.data, payload)

    def test_new_data_artifact_assigns_name_description(self):
        label = 'Named_Data_Artifact'
        blurb = 'Description for data artifact.'
        created = new_data_artifact(
            data={'info': 'some details'}, name=label, description=blurb
        )
        self.assertEqual(created.name, label)
        self.assertEqual(created.description, blurb)
84 |
85 |
if __name__ == '__main__':
    # Allow running this file directly (``python test_artifact.py``) in
    # addition to discovery by a test runner.
    unittest.main()
88 |
--------------------------------------------------------------------------------
/tests/utils/test_constants.py:
--------------------------------------------------------------------------------
1 | """Tests for a2a.utils.constants module."""
2 |
3 | from a2a.utils import constants
4 |
5 |
def test_agent_card_constants():
    """The agent-card path constants hold their expected literal values."""
    actual = (
        constants.AGENT_CARD_WELL_KNOWN_PATH,
        constants.PREV_AGENT_CARD_WELL_KNOWN_PATH,
        constants.EXTENDED_AGENT_CARD_PATH,
    )
    assert actual == (
        '/.well-known/agent-card.json',
        '/.well-known/agent.json',
        '/agent/authenticatedExtendedCard',
    )
17 |
18 |
def test_default_rpc_url():
    """The default RPC URL constant is the root path."""
    expected_url = '/'
    assert expected_url == constants.DEFAULT_RPC_URL
22 |
--------------------------------------------------------------------------------
/tests/utils/test_error_handlers.py:
--------------------------------------------------------------------------------
1 | """Tests for a2a.utils.error_handlers module."""
2 |
3 | from unittest.mock import patch
4 |
5 | import pytest
6 |
7 | from a2a.types import (
8 | InternalError,
9 | InvalidRequestError,
10 | MethodNotFoundError,
11 | TaskNotFoundError,
12 | )
13 | from a2a.utils.error_handlers import (
14 | A2AErrorToHttpStatus,
15 | rest_error_handler,
16 | rest_stream_error_handler,
17 | )
18 | from a2a.utils.errors import ServerError
19 |
20 |
class MockJSONResponse:
    """Stand-in for ``JSONResponse`` that simply records its constructor args."""

    def __init__(self, content, status_code):
        self.status_code = status_code
        self.content = content
25 |
26 |
@pytest.mark.asyncio
async def test_rest_error_handler_server_error():
    """A ``ServerError`` becomes a JSON response with the mapped status."""
    bad_request = InvalidRequestError(message='Bad request')

    @rest_error_handler
    async def failing_func():
        raise ServerError(error=bad_request)

    with patch('a2a.utils.error_handlers.JSONResponse', MockJSONResponse):
        response = await failing_func()

    assert isinstance(response, MockJSONResponse)
    assert (response.status_code, response.content) == (
        400,
        {'message': 'Bad request'},
    )
42 |
43 |
@pytest.mark.asyncio
async def test_rest_error_handler_unknown_exception():
    """An unexpected exception becomes a generic 500 JSON response."""

    @rest_error_handler
    async def failing_func():
        raise ValueError('Unexpected error')

    with patch('a2a.utils.error_handlers.JSONResponse', MockJSONResponse):
        response = await failing_func()

    assert isinstance(response, MockJSONResponse)
    assert (response.status_code, response.content) == (
        500,
        {'message': 'unknown exception'},
    )
58 |
59 |
@pytest.mark.asyncio
async def test_rest_stream_error_handler_server_error():
    """The stream wrapper lets a ``ServerError`` escape with its payload intact."""
    internal = InternalError(message='Internal server error')

    @rest_stream_error_handler
    async def failing_stream():
        raise ServerError(error=internal)

    with pytest.raises(ServerError) as raised:
        await failing_stream()

    assert raised.value.error == internal
73 |
74 |
@pytest.mark.asyncio
async def test_rest_stream_error_handler_reraises_exception():
    """Exceptions other than ``ServerError`` propagate out of the wrapper."""

    @rest_stream_error_handler
    async def failing_stream():
        raise RuntimeError('Stream failed')

    expectation = pytest.raises(RuntimeError, match='Stream failed')
    with expectation:
        await failing_stream()
85 |
86 |
def test_a2a_error_to_http_status_mapping():
    """Spot-check the A2A error type to HTTP status table."""
    expected = {
        InvalidRequestError: 400,
        MethodNotFoundError: 404,
        TaskNotFoundError: 404,
        InternalError: 500,
    }
    for error_type, status in expected.items():
        assert A2AErrorToHttpStatus[error_type] == status
93 |
--------------------------------------------------------------------------------
/tests/utils/test_telemetry.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 |
3 | from typing import NoReturn
4 | from unittest import mock
5 |
6 | import pytest
7 |
8 | from a2a.utils.telemetry import trace_class, trace_function
9 |
10 |
@pytest.fixture
def mock_span():
    """Span double that the patched tracer hands out."""
    span = mock.MagicMock()
    return span
14 |
15 |
@pytest.fixture
def mock_tracer(mock_span):
    """Tracer double whose ``start_as_current_span`` context yields *mock_span*.

    ``__exit__`` returns False so exceptions raised inside the span are not
    suppressed by the context manager.
    """
    tracer = mock.MagicMock()
    span_cm = tracer.start_as_current_span.return_value
    span_cm.__enter__.return_value = mock_span
    span_cm.__exit__.return_value = False
    return tracer
22 |
23 |
@pytest.fixture(autouse=True)
def patch_trace_get_tracer(mock_tracer):
    """For every test, make ``opentelemetry.trace.get_tracer`` return *mock_tracer*."""
    patcher = mock.patch(
        'opentelemetry.trace.get_tracer', return_value=mock_tracer
    )
    with patcher:
        yield
28 |
29 |
def test_trace_function_sync_success(mock_span):
    """A traced sync function returns normally and sets a span status."""

    @trace_function
    def foo(x, y):
        return x + y

    assert foo(2, 3) == 5
    # A status is recorded and no exception is attached to the span.
    mock_span.set_status.assert_called()
    mock_span.set_status.assert_any_call(mock.ANY)
    mock_span.record_exception.assert_not_called()
40 |
41 |
def test_trace_function_sync_exception(mock_span):
    """A raising sync function re-raises; the span records the exception and
    a status carrying its message."""

    @trace_function
    def bar() -> NoReturn:
        raise ValueError('fail')

    with pytest.raises(ValueError):
        bar()

    mock_span.record_exception.assert_called()
    mock_span.set_status.assert_any_call(mock.ANY, description='fail')
51 |
52 |
def test_trace_function_sync_attribute_extractor_called(mock_span):
    """The extractor receives the active span, the result and no exception.

    The checks run *outside* the extractor: trace_function catches and logs
    exceptions raised by an attribute_extractor, so an ``assert`` inside it
    would be swallowed and could never fail this test.
    """
    captured = {}

    def attr_extractor(span, args, kwargs, result, exception) -> None:
        # Record only; see the docstring for why nothing is asserted here.
        captured.update(span=span, result=result, exception=exception)

    @trace_function(attribute_extractor=attr_extractor)
    def foo() -> int:
        return 42

    foo()
    assert captured, 'attribute_extractor was not called'
    assert captured['span'] is mock_span
    assert captured['exception'] is None
    assert captured['result'] == 42
68 |
69 |
def test_trace_function_sync_attribute_extractor_error_logged(mock_span):
    """An exception from the attribute extractor is logged, not propagated."""

    def attr_extractor(span, args, kwargs, result, exception) -> NoReturn:
        raise RuntimeError('attr fail')

    with mock.patch('a2a.utils.telemetry.logger') as logger:

        @trace_function(attribute_extractor=attr_extractor)
        def foo() -> int:
            return 1

        # Must return normally despite the failing extractor.
        foo()

    logger.exception.assert_any_call(
        'attribute_extractor error in span %s',
        'test_telemetry.foo',
    )
85 |
86 |
@pytest.mark.asyncio
async def test_trace_function_async_success(mock_span):
    """A traced coroutine function yields its awaited result and sets a status."""

    @trace_function
    async def foo(x):
        await asyncio.sleep(0)
        return x * 2

    assert await foo(4) == 8
    mock_span.set_status.assert_called()
    mock_span.record_exception.assert_not_called()
98 |
99 |
@pytest.mark.asyncio
async def test_trace_function_async_exception(mock_span):
    """A raising coroutine re-raises; the span records the exception and a
    status carrying its message."""

    @trace_function
    async def bar() -> NoReturn:
        await asyncio.sleep(0)
        raise RuntimeError('async fail')

    with pytest.raises(RuntimeError):
        await bar()

    mock_span.record_exception.assert_called()
    mock_span.set_status.assert_any_call(mock.ANY, description='async fail')
111 |
112 |
@pytest.mark.asyncio
async def test_trace_function_async_attribute_extractor_called(mock_span):
    """The extractor sees the coroutine's awaited result and no exception.

    The checks run *outside* the extractor: trace_function catches and logs
    exceptions raised by an attribute_extractor, so an ``assert`` inside it
    would be swallowed and could never fail this test.
    """
    captured = {}

    def attr_extractor(span, args, kwargs, result, exception) -> None:
        # Record only; see the docstring for why nothing is asserted here.
        captured.update(span=span, result=result, exception=exception)

    @trace_function(attribute_extractor=attr_extractor)
    async def foo() -> int:
        return 99

    await foo()
    assert captured, 'attribute_extractor was not called'
    assert captured['exception'] is None
    assert captured['result'] == 99
128 |
129 |
def test_trace_function_with_args_and_attributes(mock_span):
    """Static ``attributes`` given to the decorator are set on the span."""
    static_attributes = {'foo': 'bar'}

    @trace_function(span_name='custom.span', attributes=static_attributes)
    def foo() -> int:
        return 1

    foo()
    mock_span.set_attribute.assert_any_call('foo', 'bar')
137 |
138 |
def test_trace_class_exclude_list(mock_span):
    """``exclude_list`` opts a method out of tracing; dunders stay unwrapped."""

    @trace_class(exclude_list=['skip_me'])
    class MyClass:
        def a(self) -> str:
            return 'a'

        def skip_me(self) -> str:
            return 'skip'

        def __str__(self):
            return 'str'

    obj = MyClass()
    assert obj.a() == 'a'
    assert obj.skip_me() == 'skip'
    assert str(obj) == 'str'
    # Only 'a' is traced: 'skip_me' is excluded explicitly, and dunder
    # methods are not wrapped (previously stated here but never asserted).
    assert hasattr(obj.a, '__wrapped__')
    assert not hasattr(obj.skip_me, '__wrapped__')
    assert not hasattr(MyClass.__str__, '__wrapped__')
157 |
158 |
def test_trace_class_include_list(mock_span):
    """With ``include_list`` only the listed methods get wrapped."""

    @trace_class(include_list=['only_this'])
    class MyClass:
        def only_this(self) -> str:
            return 'yes'

        def not_this(self) -> str:
            return 'no'

    instance = MyClass()
    assert (instance.only_this(), instance.not_this()) == ('yes', 'no')
    assert hasattr(instance.only_this, '__wrapped__')
    assert not hasattr(instance.not_this, '__wrapped__')
173 |
174 |
def test_trace_class_dunder_not_traced(mock_span):
    """trace_class wraps regular methods but leaves dunders such as
    ``__init__`` untouched."""

    @trace_class()
    class MyClass:
        def __init__(self):
            self.x = 1

        def foo(self) -> str:
            return 'foo'

    obj = MyClass()
    assert obj.foo() == 'foo'
    assert hasattr(obj.foo, '__wrapped__')
    # The property the test name promises: __init__ is not wrapped ...
    assert not hasattr(MyClass.__init__, '__wrapped__')
    # ... and it still ran normally.
    assert obj.x == 1
188 |
--------------------------------------------------------------------------------