├── tests ├── __init__.py ├── test_task_node.py └── test_task_manager.py ├── src └── async_dag │ ├── py.typed │ ├── state.py │ ├── __init__.py │ ├── execution_result.py │ ├── task_node.py │ └── task_manager.py ├── .python-version ├── .gitignore ├── .github └── workflows │ ├── workflow.yml │ ├── release.yml │ ├── test.yml │ ├── lint.yml │ └── publish.yml ├── LICENSE ├── pyproject.toml ├── examples └── readme.py ├── README.md └── uv.lock /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/async_dag/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.python-version: -------------------------------------------------------------------------------- 1 | 3.12 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python-generated files 2 | __pycache__/ 3 | *.py[oc] 4 | build/ 5 | dist/ 6 | wheels/ 7 | *.egg-info 8 | 9 | # Virtual environments 10 | .venv 11 | -------------------------------------------------------------------------------- /src/async_dag/state.py: -------------------------------------------------------------------------------- 1 | from enum import IntEnum, auto 2 | 3 | 4 | class State(IntEnum): 5 | UNDISCOVERED = auto() 6 | TEMPORARY = auto() 7 | PERMANENT = auto() 8 | -------------------------------------------------------------------------------- /.github/workflows/workflow.yml: -------------------------------------------------------------------------------- 1 | name: Pull request and main workflow 2 | 3 | on: 4 | pull_request: 5 | types: ["opened", "edited", "reopened", "synchronize"] 6 | push: 7 | branches: 8 | - 'main' 9 | 10 | jobs: 11 | run_lint: 12 | uses: 
./.github/workflows/lint.yml 13 | run_test: 14 | uses: ./.github/workflows/test.yml 15 | -------------------------------------------------------------------------------- /src/async_dag/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | A simple library for running complex DAG of async tasks efficiently. 3 | Take a look at `TaskManager` to get started. 4 | """ 5 | 6 | from .execution_result import ExecutionResult 7 | from .task_manager import TaskManager, build_dag 8 | from .task_node import TaskNode 9 | 10 | __all__ = ["ExecutionResult", "TaskManager", "TaskNode", "build_dag"] 11 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release workflow 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | permissions: 8 | contents: read 9 | 10 | jobs: 11 | run_lint: 12 | uses: ./.github/workflows/lint.yml 13 | run_test: 14 | uses: ./.github/workflows/test.yml 15 | publish: 16 | needs: [run_lint, run_test] 17 | permissions: 18 | id-token: write 19 | contents: read 20 | uses: ./.github/workflows/publish.yml 21 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Test 2 | on: workflow_call 3 | 4 | jobs: 5 | test: 6 | name: Test 7 | runs-on: ubuntu-latest 8 | 9 | steps: 10 | - uses: actions/checkout@v4 11 | 12 | - name: Install uv 13 | uses: astral-sh/setup-uv@v5 14 | with: 15 | enable-cache: true 16 | cache-dependency-glob: "uv.lock" 17 | 18 | - name: "Set up Python" 19 | uses: actions/setup-python@v5 20 | with: 21 | python-version-file: ".python-version" 22 | 23 | - name: Install the project 24 | run: uv sync --group dev 25 | 26 | - name: Run tests 27 | run: uv run pytest 28 | 
-------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | on: workflow_call 3 | 4 | jobs: 5 | lint: 6 | runs-on: ubuntu-latest 7 | steps: 8 | - uses: actions/checkout@v4 9 | 10 | - name: Install uv 11 | uses: astral-sh/setup-uv@v5 12 | with: 13 | enable-cache: true 14 | cache-dependency-glob: "uv.lock" 15 | 16 | - name: "Set up Python" 17 | uses: actions/setup-python@v5 18 | with: 19 | python-version-file: ".python-version" 20 | 21 | - name: Install the project 22 | run: uv sync --group lint 23 | 24 | - name: Run ruff check 25 | run: uv run ruff check 26 | 27 | - name: Run ruff format 28 | run: uv run ruff format --check 29 | 30 | - name: Run mypy 31 | run: uv run mypy . 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Mayrom Rabinovich 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish 2 | 3 | on: workflow_call 4 | 5 | jobs: 6 | build: 7 | name: Build 8 | runs-on: ubuntu-latest 9 | 10 | steps: 11 | - uses: actions/checkout@v4 12 | 13 | - name: Install uv 14 | uses: astral-sh/setup-uv@v5 15 | with: 16 | enable-cache: true 17 | cache-dependency-glob: "uv.lock" 18 | 19 | - name: "Set up Python" 20 | uses: actions/setup-python@v5 21 | with: 22 | python-version-file: ".python-version" 23 | 24 | - name: Install the project 25 | run: uv sync 26 | 27 | - name: Run build 28 | run: uv build 29 | 30 | - name: Upload distributions 31 | uses: actions/upload-artifact@v4 32 | with: 33 | name: release-dists 34 | path: dist/ 35 | publish: 36 | runs-on: ubuntu-latest 37 | 38 | needs: 39 | - build 40 | 41 | permissions: 42 | id-token: write 43 | 44 | environment: 45 | name: pypi 46 | 47 | steps: 48 | - name: Retrieve release distributions 49 | uses: actions/download-artifact@v4 50 | with: 51 | name: release-dists 52 | path: dist/ 53 | 54 | - name: Install uv 55 | uses: astral-sh/setup-uv@v5 56 | 57 | - name: Publish lib 58 | run: uv publish 59 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "async-dag" 3 | version = "0.3.1" 4 | description = "A simple library for running complex DAG of async tasks while ensuring maximum possible parallelism" 5 | readme = "README.md" 6 | authors = [ 7 | { name = "Mayrom Rabinovich", email = 
"nhruo123@gmail.com" } 8 | ] 9 | requires-python = ">=3.12" 10 | dependencies = [] 11 | 12 | [build-system] 13 | requires = ["hatchling"] 14 | build-backend = "hatchling.build" 15 | 16 | [dependency-groups] 17 | dev = [ 18 | "pytest>=8.3.5", 19 | "pytest-asyncio>=0.25.3", 20 | ] 21 | lint = [ 22 | "mypy>=1.15.0", 23 | "ruff>=0.11.2", 24 | ] 25 | 26 | 27 | [tool.hatch.build.targets.sdist] 28 | exclude = [ 29 | "/.github", 30 | "/docs", 31 | "/examples", 32 | ] 33 | 34 | [tool.pytest.ini_options] 35 | asyncio_mode = "auto" 36 | asyncio_default_fixture_loop_scope = "function" 37 | 38 | [tool.ruff] 39 | line-length = 88 40 | extend-exclude = [ 41 | "scripts/cookiecutter/**/*.py", 42 | "*/.venv/*", 43 | "microservices/api/open-approvals-api", 44 | "**/db/migrations/**.py", 45 | ] 46 | 47 | [tool.ruff.lint] 48 | select = [ 49 | "E4", 50 | "E7", 51 | "E9", 52 | "F", 53 | "B", 54 | "Q", 55 | "I", 56 | "UP", 57 | "ASYNC", 58 | "FAST", 59 | "A", 60 | "C4", 61 | "T10", 62 | "ISC", 63 | "LOG", 64 | "PIE", 65 | "PYI", 66 | "SIM", 67 | "RUF", 68 | "DTZ", 69 | "G", 70 | "ANN", 71 | "RET", 72 | "N", 73 | ] 74 | ignore = ["F811"] 75 | 76 | [tool.mypy] 77 | mypy_path = [ 78 | "src", "tests" 79 | ] 80 | strict = true 81 | warn_redundant_casts = true 82 | warn_unused_ignores = true 83 | warn_no_return = true 84 | warn_return_any = true 85 | warn_unreachable = true 86 | install_types = true 87 | non_interactive = true 88 | show_error_code_links = true 89 | disallow_untyped_defs = true 90 | namespace_packages = true 91 | exclude = [ 92 | ".venv/*" 93 | ] 94 | -------------------------------------------------------------------------------- /tests/test_task_node.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from async_dag.execution_result import ExecutionResult 4 | from async_dag.task_manager import TaskManager, build_dag 5 | 6 | 7 | async def imm() -> int: 8 | return 0 9 | 10 | 11 | async def 
test_invoke_with_task_manager_mismatch_errors() -> None: 12 | with pytest.raises(ValueError): 13 | with build_dag() as tm: 14 | node = tm.add_node(imm) 15 | 16 | other_tm = TaskManager[None]() 17 | execution_result = ExecutionResult(other_tm, None) 18 | 19 | await node._invoke(execution_result) 20 | 21 | 22 | async def test_extract_result_with_task_manager_mismatch_errors() -> None: 23 | with pytest.raises(ValueError): 24 | with build_dag() as tm: 25 | node = tm.add_node(imm) 26 | 27 | other_tm = TaskManager[None]() 28 | execution_result = ExecutionResult(other_tm, None) 29 | 30 | node.extract_result(execution_result) 31 | 32 | 33 | async def test_extract_result_before_sort_errors() -> None: 34 | with pytest.raises(ValueError): 35 | tm = TaskManager[None]() 36 | node = tm.add_node(imm) 37 | execution_result = ExecutionResult(tm, None) 38 | 39 | node.extract_result(execution_result) 40 | 41 | 42 | async def test_invoke_before_sort_errors() -> None: 43 | with pytest.raises(ValueError): 44 | tm = TaskManager[None]() 45 | node = tm.add_node(imm) 46 | execution_result = ExecutionResult(tm, None) 47 | 48 | await node._invoke(execution_result) 49 | 50 | 51 | async def test_extract_result_should_return_value_from_index_from_id() -> None: 52 | expected = 999 53 | with build_dag() as tm: 54 | node = tm.add_node(imm) 55 | execution_result = ExecutionResult(tm, None) 56 | execution_result._results[node._id] = expected 57 | 58 | assert expected == node.extract_result(execution_result) 59 | 60 | 61 | async def test_invoke_returns_its_value() -> None: 62 | with build_dag() as tm: 63 | node = tm.add_node(imm) 64 | execution_result = ExecutionResult(tm, None) 65 | 66 | result = await node._invoke(execution_result) 67 | 68 | assert result == await imm() 69 | -------------------------------------------------------------------------------- /src/async_dag/execution_result.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from 
typing import TYPE_CHECKING 3 | 4 | if TYPE_CHECKING: 5 | from .task_manager import TaskManager 6 | from .task_node import TaskNode 7 | 8 | 9 | class ExecutionResult[_ParameterType]: 10 | """ 11 | `ExecutionResult` represents the return values of all the nodes in a DAG for a specific `invoke` call. 12 | This class does not export any API, and should only be used when returned from `invoke` in order to pass to `TaskNode.extract_result` or for type annotations. 13 | **You should never initialize this class by yourself!** 14 | """ 15 | 16 | def __init__( 17 | self, task_manager: "TaskManager[_ParameterType]", parameter: _ParameterType 18 | ) -> None: 19 | self._tasks = task_manager._tasks 20 | self._results: list[object] = [None] * len(self._tasks) 21 | self._task_manager = task_manager 22 | self._tasks_missing_dependencies_count = [ 23 | len(task._dependencies_ids) for task in self._tasks 24 | ] 25 | self._starting_nodes_id = self._task_manager._starting_nodes_id 26 | self._parameter = parameter 27 | 28 | async def _invoke_task( 29 | self, task: "TaskNode[_ParameterType, object]", tg: asyncio.TaskGroup 30 | ) -> None: 31 | self._on_task_completion(task, await task._invoke(self), tg) 32 | 33 | def _on_task_completion( 34 | self, 35 | task: "TaskNode[_ParameterType, object]", 36 | result: object, 37 | tg: asyncio.TaskGroup, 38 | ) -> None: 39 | self._results[task._id] = result 40 | 41 | for dependent_id in task._dependents_ids: 42 | self._tasks_missing_dependencies_count[dependent_id] -= 1 43 | if self._tasks_missing_dependencies_count[dependent_id] <= 0: 44 | tg.create_task(self._invoke_task(self._tasks[dependent_id], tg)) 45 | 46 | async def _invoke(self) -> None: 47 | async with asyncio.TaskGroup() as tg: 48 | for node_id in self._starting_nodes_id: 49 | task = self._tasks[node_id] 50 | if task is self._task_manager._parameter_node: 51 | self._on_task_completion(task, self._parameter, tg) 52 | else: 53 | tg.create_task(self._invoke_task(task, tg)) 54 | 
-------------------------------------------------------------------------------- /src/async_dag/task_node.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Awaitable, Callable, Sequence 2 | from typing import TYPE_CHECKING, cast 3 | 4 | from .execution_result import ExecutionResult 5 | from .state import State 6 | 7 | if TYPE_CHECKING: 8 | from .task_manager import TaskManager 9 | 10 | 11 | class TaskNode[_ParameterType, _ReturnType]: 12 | """ 13 | `TaskNode` represents a task in the DAG. 14 | The only API you should use in this class is the `extract_result` method. 15 | **You should never initialize this class by yourself!** 16 | """ 17 | 18 | def __init__( 19 | self, 20 | callback: Callable[..., Awaitable[_ReturnType]], 21 | task_manager: "TaskManager[_ParameterType]", 22 | dependencies_ids: Sequence[int], 23 | node_id: int, 24 | ) -> None: 25 | self._task_manager = task_manager 26 | self._dependencies_ids = dependencies_ids # ordered like the callback's positional args; may repeat an id 27 | self._callback = callback 28 | self._state = State.UNDISCOVERED 29 | self._depth = 0 30 | self._id = node_id 31 | self._dependents_ids: set[int] = set() # ids of nodes that consume this node's result (deduplicated) 32 | 33 | def extract_result( 34 | self, execution_result: ExecutionResult[_ParameterType] 35 | ) -> _ReturnType: 36 | """ 37 | Returns the value that the `callback` of the task returned for a specific `TaskManager.invoke` call represented by the `execution_result` parameter. 38 | 39 | This function raises a `ValueError` if: 40 | 1. It was called before `TaskManager.sort()` was called. 41 | 2. The `execution_result` passed to it came from a different `TaskManager` than the one that created this node.
42 | """ 43 | self._assert_state(State.PERMANENT) 44 | self._assert_task_manager(execution_result._task_manager) 45 | 46 | return cast(_ReturnType, execution_result._results[self._id]) 47 | 48 | async def _invoke( 49 | self, 50 | execution_result: ExecutionResult[_ParameterType], 51 | ) -> _ReturnType: 52 | self._assert_state(State.PERMANENT) 53 | self._assert_task_manager(execution_result._task_manager) 54 | 55 | return await self._callback( 56 | *[execution_result._results[dep_id] for dep_id in self._dependencies_ids], 57 | ) 58 | 59 | def _assert_state(self, expected_state: State) -> None: 60 | if self._state != expected_state: 61 | raise ValueError( 62 | f"TaskNode in invalid state, current state: {self._state}, expected state: {expected_state}" 63 | ) 64 | 65 | def _assert_task_manager(self, task_manager: "TaskManager[_ParameterType]") -> None: 66 | if self._task_manager is not task_manager: 67 | raise ValueError( 68 | f"Task manager mismatch, expected: {self._task_manager} but got: {task_manager}" 69 | ) 70 | -------------------------------------------------------------------------------- /examples/readme.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | from async_dag import build_dag 4 | 5 | 6 | async def inc_task(n: int, name: str, delay: float) -> int: 7 | print(f"{name} task started...") 8 | await asyncio.sleep(delay) 9 | print(f"{name} task done!") 10 | 11 | return n + 1 12 | 13 | 14 | async def add_task(a: int, b: int, name: str, delay: float) -> int: 15 | print(f"{name} task started...") 16 | await asyncio.sleep(delay) 17 | print(f"{name} task done!") 18 | 19 | return a + b 20 | 21 | 22 | # Define the DAG 23 | with build_dag(int) as tm: 24 | # Each node is made of an async function, and the parameters that will get passed to it at invoke time, a parameter can be either a value or another node. 
25 | # We are essentially creating a partially applied async function, just like `functools.partial`. 26 | 27 | # tm.parameter_node is a special node that will get resolved into the invoke parameter (the value passed to `tm.invoke`) 28 | # you can also pass an immediate value to the node as a constant that will be the same across all invocations 29 | fast_task_a = tm.add_node( 30 | inc_task, 31 | tm.parameter_node, 32 | "fast_task_a", 33 | 0.1, 34 | ) 35 | 36 | # here we pass the result from fast_task_a as the n param to inc_task node 37 | slow_task_b = tm.add_node( 38 | inc_task, 39 | fast_task_a, 40 | "slow_task_b", 41 | 1, 42 | ) 43 | 44 | slow_task_a = tm.add_node( 45 | inc_task, 46 | tm.parameter_node, 47 | "slow_task_a", 48 | 0.5, 49 | ) 50 | fast_task_b = tm.add_node( 51 | inc_task, 52 | tm.parameter_node, 53 | "fast_task_b", 54 | 0.2, 55 | ) 56 | fast_task_c = tm.add_node( 57 | add_task, 58 | slow_task_a, 59 | fast_task_b, 60 | "fast_task_c", 61 | 0.1, 62 | ) 63 | 64 | end_task = tm.add_node(add_task, fast_task_c, slow_task_b, "end_task", 0.1) 65 | 66 | 67 | # Invoke the DAG 68 | async def main() -> None: 69 | # In order to execute our partially applied DAG we call `tm.invoke` and pass in the parameter; the same DAG can be invoked many times once it has been fully built. 70 | # Each run returns an `ExecutionResult` which can be used to extract the return value of each node by calling `extract_result` on the node. 71 | 72 | # prints: 73 | # fast_task_a task started... 74 | # slow_task_a task started... 75 | # fast_task_b task started... 76 | # fast_task_a task done! 77 | # slow_task_b task started... 78 | # fast_task_b task done! 79 | # slow_task_a task done! 80 | # fast_task_c task started... 81 | # fast_task_c task done! 82 | # slow_task_b task done! 83 | # end_task task started... 84 | # end_task task done!
85 | execution_result = await tm.invoke(0) 86 | 87 | # we can extract each node return value 88 | print(fast_task_a.extract_result(execution_result)) # 1 89 | print(end_task.extract_result(execution_result)) # 4 90 | 91 | 92 | if __name__ == "__main__": 93 | asyncio.run(main()) 94 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | async-dag 2 | --- 3 | [![PyPI - Version](https://img.shields.io/pypi/v/async-dag)](https://pypi.org/project/async-dag/) 4 | [![GitHub Actions Workflow Status](https://img.shields.io/github/actions/workflow/status/nhruo123/async-dag/workflow.yml)](https://github.com/nhruo123/async-dag/actions) 5 | 6 | 7 | A simple library for running complex DAGs of async tasks while ensuring maximum possible parallelism. 8 | 9 | ### Use case and example 10 | 11 | Let's assume that you have the following task dependency graph (each task is an async function that could take time to resolve, and some tasks may depend on each other): 12 | ```mermaid 13 | graph TD; 14 | FastTask_A-->SlowTask_B; 15 | SlowTask_B-->EndTask; 16 | 17 | SlowTask_A-->FastTask_C; 18 | FastTask_B-->FastTask_C; 19 | FastTask_C-->EndTask; 20 | ``` 21 | 22 | A naive way to write something like this would be: 23 | ```python 24 | await end_task( 25 | await slow_task_b( 26 | await fast_task_a() 27 | ), 28 | await fast_task_c( 29 | await slow_task_a(), 30 | await fast_task_b() 31 | ) 32 | ) 33 | ``` 34 | But that would be bad, because awaiting each task one after another misses many opportunities to run tasks in parallel.
35 | 36 | A better version would be: 37 | ```python 38 | fast_task_a_res, slow_task_a_res, fast_task_b_res = await asyncio.gather( 39 | fast_task_a(), slow_task_a(), fast_task_b() 40 | ) 41 | slow_task_b_res, fast_task_c_res = await asyncio.gather( 42 | slow_task_b(fast_task_a_res), fast_task_c(slow_task_a_res, fast_task_b_res) 43 | ) 44 | await end_task(slow_task_b_res, fast_task_c_res) 45 | ``` 46 | Here we run `fast_task_a`, `slow_task_a`, and `fast_task_b` in parallel, and only after all of them are done do we run `slow_task_b` and `fast_task_c`. 47 | But this is still suboptimal, because `slow_task_b` could start as soon as `fast_task_a` ends, and `fast_task_c` could start as soon as both `slow_task_a` and `fast_task_b` end. 48 | 49 | The optimal way to run this flow would be: 50 | ```python 51 | async def _left_branch(): 52 | return await slow_task_b(await fast_task_a()) 53 | 54 | 55 | async def _right_branch(): 56 | slow_task_a_res, fast_task_b_res = await asyncio.gather( 57 | slow_task_a(), fast_task_b() 58 | ) 59 | 60 | return await fast_task_c(slow_task_a_res, fast_task_b_res) 61 | 62 | 63 | async def _end_node(): 64 | left_branch_res, right_branch_res = await asyncio.gather( 65 | _left_branch(), _right_branch() 66 | ) 67 | 68 | return await end_task(left_branch_res, right_branch_res) 69 | 70 | 71 | await _end_node() 72 | ``` 73 | This is very cumbersome and error-prone to write by hand.
74 | 75 | 76 | Using `async_dag`, you can simply write: 77 | ```python 78 | from async_dag import build_dag 79 | 80 | # Define your DAG 81 | with build_dag() as tm: 82 | fast_task_a_node = tm.add_node(fast_task_a) 83 | slow_task_b_node = tm.add_node(slow_task_b, fast_task_a_node) 84 | 85 | slow_task_a_node = tm.add_node(slow_task_a) 86 | fast_task_b_node = tm.add_node(fast_task_b) 87 | fast_task_c_node = tm.add_node(fast_task_c, slow_task_a_node, fast_task_b_node) 88 | 89 | end_task_node = tm.add_node(end_task, slow_task_b_node, fast_task_c_node) 90 | 91 | # Execute your DAG 92 | execution_result = await tm.invoke(None) 93 | 94 | # Extract the return value of `end_task` 95 | end_task_result = end_task_node.extract_result(execution_result) 96 | ``` 97 | And enjoy maximum parallelism without the hassle of creating a lot of async functions by hand. 98 | 99 | **For a full example, take a look at [examples/readme.py](https://github.com/nhruo123/async-dag/blob/main/examples/readme.py).** 100 | 101 | ### Docs and more examples 102 | We use docstrings to describe our API; the main concepts you need to know are: 103 | 1. The [TaskManager](https://github.com/nhruo123/async-dag/blob/main/src/async_dag/task_manager.py) class. 104 | 2. The [add_node](https://github.com/nhruo123/async-dag/blob/main/src/async_dag/task_manager.py) method of `TaskManager`. 105 | 3. The [parameter_node](https://github.com/nhruo123/async-dag/blob/main/src/async_dag/task_manager.py) property of `TaskManager`. 106 | 4. The [sort](https://github.com/nhruo123/async-dag/blob/main/src/async_dag/task_manager.py) method of `TaskManager`. 107 | 5. The [invoke](https://github.com/nhruo123/async-dag/blob/main/src/async_dag/task_manager.py) method of `TaskManager`. 108 | 6. The [extract_result](https://github.com/nhruo123/async-dag/blob/main/src/async_dag/task_node.py) method of `TaskNode`. 109 | 7.
The [build_dag](https://github.com/nhruo123/async-dag/blob/main/src/async_dag/task_manager.py) helper function. 110 | -------------------------------------------------------------------------------- /tests/test_task_manager.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from collections.abc import Awaitable, Callable 3 | from dataclasses import dataclass 4 | from typing import Never 5 | 6 | import pytest 7 | 8 | from async_dag.task_manager import TaskManager, build_dag 9 | 10 | 11 | @dataclass 12 | class Input: 13 | starting_number: int 14 | 15 | 16 | async def imm() -> int: 17 | return 1 18 | 19 | 20 | async def from_args(args: Input) -> int: 21 | return args.starting_number 22 | 23 | 24 | async def inc(n: int) -> int: 25 | return n + 1 26 | 27 | 28 | async def to_str(n: int) -> str: 29 | return str(n) 30 | 31 | 32 | async def to_int(n: str, _: Input) -> int: 33 | return int(n) 34 | 35 | 36 | async def inc_str(n: str) -> int: 37 | return int(n) + 1 38 | 39 | 40 | async def inc_max(a: int, b: int) -> int: 41 | return max(a, b) + 1 42 | 43 | 44 | async def raise_exception(exception: Exception, _: int) -> Never: 45 | raise exception 46 | 47 | 48 | async def test_task_manager_sanity() -> None: 49 | tm = TaskManager[Input]() 50 | starting_node = tm.add_node(from_args, tm.parameter_node) 51 | 52 | inc_1 = tm.add_node(inc, starting_node) 53 | 54 | str_node = tm.add_node(to_str, starting_node) 55 | str_to_int_node = tm.add_node(to_int, str_node, tm.parameter_node) 56 | int_node = tm.add_node(inc_str, str_node) 57 | inc_2 = tm.add_node(inc, int_node) 58 | 59 | end_1 = tm.add_node(inc_max, inc_2, inc_1) 60 | 61 | end_2 = tm.add_node(inc_max, inc_1, starting_node) 62 | 63 | end_3 = tm.add_node(inc_max, str_to_int_node, starting_node) 64 | 65 | tm.sort() 66 | 67 | result_1 = await tm.invoke(Input(0)) 68 | 69 | result_2 = await tm.invoke(Input(999)) 70 | 71 | assert end_1.extract_result(result_1) == 3 72 | assert 
end_2.extract_result(result_1) == 2 73 | assert end_3.extract_result(result_1) == 1 74 | 75 | assert end_1.extract_result(result_2) == 1002 76 | assert end_2.extract_result(result_2) == 1001 77 | assert end_3.extract_result(result_2) == 1000 78 | 79 | 80 | async def test_build_dag_sanity() -> None: 81 | with build_dag(Input) as tm: 82 | starting_node = tm.add_node(from_args, tm.parameter_node) 83 | 84 | inc_1 = tm.add_node(inc, starting_node) 85 | 86 | str_node = tm.add_node(to_str, starting_node) 87 | str_to_int_node = tm.add_node(to_int, str_node, tm.parameter_node) 88 | int_node = tm.add_node(inc_str, str_node) 89 | inc_2 = tm.add_node(inc, int_node) 90 | 91 | end_1 = tm.add_node(inc_max, inc_2, inc_1) 92 | 93 | end_2 = tm.add_node(inc_max, inc_1, starting_node) 94 | 95 | end_3 = tm.add_node(inc_max, str_to_int_node, starting_node) 96 | 97 | result_1 = await tm.invoke(Input(0)) 98 | 99 | result_2 = await tm.invoke(Input(999)) 100 | 101 | assert end_1.extract_result(result_1) == 3 102 | assert end_2.extract_result(result_1) == 2 103 | assert end_3.extract_result(result_1) == 1 104 | 105 | assert end_1.extract_result(result_2) == 1002 106 | assert end_2.extract_result(result_2) == 1001 107 | assert end_3.extract_result(result_2) == 1000 108 | 109 | 110 | async def test_multiple_invocation() -> None: 111 | with build_dag(Input) as tm: 112 | from_args_node = tm.add_node(from_args, tm.parameter_node) 113 | 114 | expected_1 = 0 115 | result_1 = await tm.invoke(Input(expected_1)) 116 | 117 | expected_2 = 999 118 | result_2 = await tm.invoke(Input(expected_2)) 119 | 120 | assert from_args_node.extract_result(result_1) == expected_1 121 | assert from_args_node.extract_result(result_2) == expected_2 122 | 123 | 124 | async def test_dag_with_single_node() -> None: 125 | with build_dag() as tm: 126 | starting_node = tm.add_node(imm) 127 | result = await tm.invoke(None) 128 | 129 | assert starting_node.extract_result(result) == await imm() 130 | 131 | 132 | async def 
test_dag_with_zero_nodes() -> None: 133 | with build_dag() as tm: 134 | pass 135 | result = await tm.invoke(None) 136 | assert result is not None 137 | 138 | 139 | async def test_sorting_twice_should_error() -> None: 140 | with pytest.raises(ValueError), build_dag() as tm: 141 | tm.sort() 142 | 143 | with pytest.raises(ValueError): 144 | tm = TaskManager[None]() 145 | tm.sort() 146 | tm.sort() 147 | 148 | 149 | async def test_adding_nodes_after_sort_should_error() -> None: 150 | with pytest.raises(ValueError): 151 | with build_dag() as tm: 152 | pass 153 | tm.add_node(imm) 154 | 155 | with pytest.raises(ValueError): 156 | tm = TaskManager[None]() 157 | tm.sort() 158 | tm.add_node(imm) 159 | 160 | 161 | async def test_invoke_before_sort_should_error() -> None: 162 | with pytest.raises(ValueError), build_dag() as tm: 163 | await tm.invoke(None) 164 | 165 | with pytest.raises(ValueError): 166 | tm = TaskManager[None]() 167 | await tm.invoke(None) 168 | 169 | 170 | async def test_calling_order_of_dag() -> None: 171 | def define_step( 172 | expected_state: int, delay: float 173 | ) -> Callable[[None, Input], Awaitable[None]]: 174 | async def inner(_: None, global_state: Input) -> None: 175 | await asyncio.sleep(delay) 176 | assert expected_state == global_state.starting_number 177 | global_state.starting_number += 1 178 | return 179 | 180 | return inner 181 | 182 | def define_merge_step( 183 | expected_state: int, delay: float 184 | ) -> Callable[[None, None, Input], Awaitable[None]]: 185 | async def inner(_a: None, _b: None, global_state: Input) -> None: 186 | await asyncio.sleep(delay) 187 | assert expected_state == global_state.starting_number 188 | global_state.starting_number += 1 189 | return 190 | 191 | return inner 192 | 193 | with build_dag(Input) as tm: 194 | starting_node = tm.add_immediate_node(None) 195 | 196 | single_path_to_end_node_1 = tm.add_node( 197 | define_step(1, 0), starting_node, tm.parameter_node 198 | ) 199 | single_path_to_end_node_2 = 
tm.add_node( 200 | define_step(2, 0), single_path_to_end_node_1, tm.parameter_node 201 | ) 202 | 203 | merge_path_to_end_node_1_1 = tm.add_node( 204 | define_step(3, 0.1), starting_node, tm.parameter_node 205 | ) 206 | merge_path_to_end_node_1_2 = tm.add_node( 207 | define_step(4, 0.15), starting_node, tm.parameter_node 208 | ) 209 | 210 | merge_path_to_end_node_2 = tm.add_node( 211 | define_merge_step(5, 0), 212 | merge_path_to_end_node_1_1, 213 | merge_path_to_end_node_1_2, 214 | tm.parameter_node, 215 | ) 216 | 217 | tm.add_node( 218 | define_merge_step(6, 0), 219 | single_path_to_end_node_2, 220 | merge_path_to_end_node_2, 221 | tm.parameter_node, 222 | ) 223 | context = Input(1) 224 | await tm.invoke(context) 225 | 226 | assert context.starting_number == 7 227 | 228 | 229 | async def test_add_node_with_mixed_managers_errors() -> None: 230 | with pytest.raises(ValueError), build_dag() as tm_1, build_dag() as tm_2: 231 | starting_node = tm_1.add_node(imm) 232 | tm_2.add_node(inc, starting_node) 233 | 234 | 235 | async def test_sort_with_non_dag_graph_errors() -> None: 236 | with pytest.raises(ValueError), build_dag() as tm: 237 | node_1 = tm.add_node(imm) 238 | node_2 = tm.add_node(imm) 239 | 240 | node_1._dependencies_ids = [node_2._id] 241 | node_2._dependencies_ids = [node_1._id] 242 | 243 | 244 | async def test_immediate_node_should_return_its_value() -> None: 245 | expected = "Hello World" 246 | with build_dag() as tm: 247 | imm_node = tm.add_immediate_node(expected) 248 | 249 | result = await tm.invoke(None) 250 | 251 | assert imm_node.extract_result(result) == expected 252 | 253 | 254 | async def test_duplicate_task_node_input() -> None: 255 | expected_value = 0 256 | 257 | async def merge(_a: None, _b: None, input_: Input) -> None: 258 | assert input_.starting_number == expected_value 259 | input_.starting_number += 1 260 | return 261 | 262 | with build_dag(Input) as tm: 263 | starting_node = tm.add_immediate_node(None) 264 | 265 | tm.add_node(merge, 
starting_node, starting_node, tm.parameter_node) 266 | 267 | await tm.invoke(Input(expected_value)) 268 | 269 | 270 | async def test_parameter_node_should_be_a_singleton() -> None: 271 | with build_dag(Input) as tm: 272 | assert tm.parameter_node is tm.parameter_node 273 | 274 | 275 | async def test_parameter_node_result_should_be_the_parameter() -> None: 276 | expected = "FOO BAR" 277 | with build_dag(str) as tm: 278 | node = tm.parameter_node 279 | result = await tm.invoke(expected) 280 | 281 | assert node.extract_result(result) == expected 282 | 283 | 284 | async def test_parameter_node_should_never_be_called() -> None: 285 | with build_dag() as tm: 286 | assert tm.parameter_node is tm.parameter_node 287 | 288 | await tm.invoke(None) 289 | 290 | 291 | async def test_immediate_value_get_converted_to_node() -> None: 292 | expected_value = 1 293 | with build_dag() as tm: 294 | node = tm.add_node(inc, 0) 295 | 296 | result = await tm.invoke(None) 297 | assert node.extract_result(result) == expected_value 298 | 299 | 300 | async def test_raised_error_should_reach_the_caller() -> None: 301 | expected = ValueError("FOO BAR") 302 | with build_dag() as tm: 303 | node_1 = tm.add_node(inc, 0) 304 | node_2 = tm.add_node(inc, node_1) 305 | node_3 = tm.add_node(inc, node_2) 306 | tm.add_node(raise_exception, expected, node_3) 307 | 308 | try: 309 | await tm.invoke(None) 310 | except ExceptionGroup as actual: 311 | assert actual.exceptions[0] == expected 312 | -------------------------------------------------------------------------------- /uv.lock: -------------------------------------------------------------------------------- 1 | version = 1 2 | revision = 1 3 | requires-python = ">=3.12" 4 | 5 | [[package]] 6 | name = "async-dag" 7 | version = "0.3.1" 8 | source = { editable = "." 
} 9 | 10 | [package.dev-dependencies] 11 | dev = [ 12 | { name = "pytest" }, 13 | { name = "pytest-asyncio" }, 14 | ] 15 | lint = [ 16 | { name = "mypy" }, 17 | { name = "ruff" }, 18 | ] 19 | 20 | [package.metadata] 21 | 22 | [package.metadata.requires-dev] 23 | dev = [ 24 | { name = "pytest", specifier = ">=8.3.5" }, 25 | { name = "pytest-asyncio", specifier = ">=0.25.3" }, 26 | ] 27 | lint = [ 28 | { name = "mypy", specifier = ">=1.15.0" }, 29 | { name = "ruff", specifier = ">=0.11.2" }, 30 | ] 31 | 32 | [[package]] 33 | name = "colorama" 34 | version = "0.4.6" 35 | source = { registry = "https://pypi.org/simple" } 36 | sdist = { url = "https://files.pythonhosted.org/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697 } 37 | wheels = [ 38 | { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335 }, 39 | ] 40 | 41 | [[package]] 42 | name = "iniconfig" 43 | version = "2.1.0" 44 | source = { registry = "https://pypi.org/simple" } 45 | sdist = { url = "https://files.pythonhosted.org/packages/f2/97/ebf4da567aa6827c909642694d71c9fcf53e5b504f2d96afea02718862f3/iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7", size = 4793 } 46 | wheels = [ 47 | { url = "https://files.pythonhosted.org/packages/2c/e1/e6716421ea10d38022b952c159d5161ca1193197fb744506875fbb87ea7b/iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760", size = 6050 }, 48 | ] 49 | 50 | [[package]] 51 | name = "mypy" 52 | version = "1.15.0" 53 | source = { registry = "https://pypi.org/simple" } 54 | dependencies = [ 55 | { name = "mypy-extensions" }, 56 | { name = 
"typing-extensions" }, 57 | ] 58 | sdist = { url = "https://files.pythonhosted.org/packages/ce/43/d5e49a86afa64bd3839ea0d5b9c7103487007d728e1293f52525d6d5486a/mypy-1.15.0.tar.gz", hash = "sha256:404534629d51d3efea5c800ee7c42b72a6554d6c400e6a79eafe15d11341fd43", size = 3239717 } 59 | wheels = [ 60 | { url = "https://files.pythonhosted.org/packages/98/3a/03c74331c5eb8bd025734e04c9840532226775c47a2c39b56a0c8d4f128d/mypy-1.15.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:aea39e0583d05124836ea645f412e88a5c7d0fd77a6d694b60d9b6b2d9f184fd", size = 10793981 }, 61 | { url = "https://files.pythonhosted.org/packages/f0/1a/41759b18f2cfd568848a37c89030aeb03534411eef981df621d8fad08a1d/mypy-1.15.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2f2147ab812b75e5b5499b01ade1f4a81489a147c01585cda36019102538615f", size = 9749175 }, 62 | { url = "https://files.pythonhosted.org/packages/12/7e/873481abf1ef112c582db832740f4c11b2bfa510e829d6da29b0ab8c3f9c/mypy-1.15.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ce436f4c6d218a070048ed6a44c0bbb10cd2cc5e272b29e7845f6a2f57ee4464", size = 11455675 }, 63 | { url = "https://files.pythonhosted.org/packages/b3/d0/92ae4cde706923a2d3f2d6c39629134063ff64b9dedca9c1388363da072d/mypy-1.15.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8023ff13985661b50a5928fc7a5ca15f3d1affb41e5f0a9952cb68ef090b31ee", size = 12410020 }, 64 | { url = "https://files.pythonhosted.org/packages/46/8b/df49974b337cce35f828ba6fda228152d6db45fed4c86ba56ffe442434fd/mypy-1.15.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:1124a18bc11a6a62887e3e137f37f53fbae476dc36c185d549d4f837a2a6a14e", size = 12498582 }, 65 | { url = "https://files.pythonhosted.org/packages/13/50/da5203fcf6c53044a0b699939f31075c45ae8a4cadf538a9069b165c1050/mypy-1.15.0-cp312-cp312-win_amd64.whl", hash = "sha256:171a9ca9a40cd1843abeca0e405bc1940cd9b305eaeea2dda769ba096932bb22", size = 
9366614 }, 66 | { url = "https://files.pythonhosted.org/packages/6a/9b/fd2e05d6ffff24d912f150b87db9e364fa8282045c875654ce7e32fffa66/mypy-1.15.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:93faf3fdb04768d44bf28693293f3904bbb555d076b781ad2530214ee53e3445", size = 10788592 }, 67 | { url = "https://files.pythonhosted.org/packages/74/37/b246d711c28a03ead1fd906bbc7106659aed7c089d55fe40dd58db812628/mypy-1.15.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:811aeccadfb730024c5d3e326b2fbe9249bb7413553f15499a4050f7c30e801d", size = 9753611 }, 68 | { url = "https://files.pythonhosted.org/packages/a6/ac/395808a92e10cfdac8003c3de9a2ab6dc7cde6c0d2a4df3df1b815ffd067/mypy-1.15.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:98b7b9b9aedb65fe628c62a6dc57f6d5088ef2dfca37903a7d9ee374d03acca5", size = 11438443 }, 69 | { url = "https://files.pythonhosted.org/packages/d2/8b/801aa06445d2de3895f59e476f38f3f8d610ef5d6908245f07d002676cbf/mypy-1.15.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c43a7682e24b4f576d93072216bf56eeff70d9140241f9edec0c104d0c515036", size = 12402541 }, 70 | { url = "https://files.pythonhosted.org/packages/c7/67/5a4268782eb77344cc613a4cf23540928e41f018a9a1ec4c6882baf20ab8/mypy-1.15.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:baefc32840a9f00babd83251560e0ae1573e2f9d1b067719479bfb0e987c6357", size = 12494348 }, 71 | { url = "https://files.pythonhosted.org/packages/83/3e/57bb447f7bbbfaabf1712d96f9df142624a386d98fb026a761532526057e/mypy-1.15.0-cp313-cp313-win_amd64.whl", hash = "sha256:b9378e2c00146c44793c98b8d5a61039a048e31f429fb0eb546d93f4b000bedf", size = 9373648 }, 72 | { url = "https://files.pythonhosted.org/packages/09/4e/a7d65c7322c510de2c409ff3828b03354a7c43f5a8ed458a7a131b41c7b9/mypy-1.15.0-py3-none-any.whl", hash = "sha256:5469affef548bd1895d86d3bf10ce2b44e33d86923c29e4d675b3e323437ea3e", size = 2221777 }, 73 | ] 74 | 75 | 
[[package]] 76 | name = "mypy-extensions" 77 | version = "1.0.0" 78 | source = { registry = "https://pypi.org/simple" } 79 | sdist = { url = "https://files.pythonhosted.org/packages/98/a4/1ab47638b92648243faf97a5aeb6ea83059cc3624972ab6b8d2316078d3f/mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782", size = 4433 } 80 | wheels = [ 81 | { url = "https://files.pythonhosted.org/packages/2a/e2/5d3f6ada4297caebe1a2add3b126fe800c96f56dbe5d1988a2cbe0b267aa/mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d", size = 4695 }, 82 | ] 83 | 84 | [[package]] 85 | name = "packaging" 86 | version = "24.2" 87 | source = { registry = "https://pypi.org/simple" } 88 | sdist = { url = "https://files.pythonhosted.org/packages/d0/63/68dbb6eb2de9cb10ee4c9c14a0148804425e13c4fb20d61cce69f53106da/packaging-24.2.tar.gz", hash = "sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f", size = 163950 } 89 | wheels = [ 90 | { url = "https://files.pythonhosted.org/packages/88/ef/eb23f262cca3c0c4eb7ab1933c3b1f03d021f2c48f54763065b6f0e321be/packaging-24.2-py3-none-any.whl", hash = "sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759", size = 65451 }, 91 | ] 92 | 93 | [[package]] 94 | name = "pluggy" 95 | version = "1.5.0" 96 | source = { registry = "https://pypi.org/simple" } 97 | sdist = { url = "https://files.pythonhosted.org/packages/96/2d/02d4312c973c6050a18b314a5ad0b3210edb65a906f868e31c111dede4a6/pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1", size = 67955 } 98 | wheels = [ 99 | { url = "https://files.pythonhosted.org/packages/88/5f/e351af9a41f866ac3f1fac4ca0613908d9a41741cfcf2228f4ad853b697d/pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669", size = 20556 }, 100 | ] 101 | 102 | [[package]] 103 | name = "pytest" 
104 | version = "8.3.5" 105 | source = { registry = "https://pypi.org/simple" } 106 | dependencies = [ 107 | { name = "colorama", marker = "sys_platform == 'win32'" }, 108 | { name = "iniconfig" }, 109 | { name = "packaging" }, 110 | { name = "pluggy" }, 111 | ] 112 | sdist = { url = "https://files.pythonhosted.org/packages/ae/3c/c9d525a414d506893f0cd8a8d0de7706446213181570cdbd766691164e40/pytest-8.3.5.tar.gz", hash = "sha256:f4efe70cc14e511565ac476b57c279e12a855b11f48f212af1080ef2263d3845", size = 1450891 } 113 | wheels = [ 114 | { url = "https://files.pythonhosted.org/packages/30/3d/64ad57c803f1fa1e963a7946b6e0fea4a70df53c1a7fed304586539c2bac/pytest-8.3.5-py3-none-any.whl", hash = "sha256:c69214aa47deac29fad6c2a4f590b9c4a9fdb16a403176fe154b79c0b4d4d820", size = 343634 }, 115 | ] 116 | 117 | [[package]] 118 | name = "pytest-asyncio" 119 | version = "0.25.3" 120 | source = { registry = "https://pypi.org/simple" } 121 | dependencies = [ 122 | { name = "pytest" }, 123 | ] 124 | sdist = { url = "https://files.pythonhosted.org/packages/f2/a8/ecbc8ede70921dd2f544ab1cadd3ff3bf842af27f87bbdea774c7baa1d38/pytest_asyncio-0.25.3.tar.gz", hash = "sha256:fc1da2cf9f125ada7e710b4ddad05518d4cee187ae9412e9ac9271003497f07a", size = 54239 } 125 | wheels = [ 126 | { url = "https://files.pythonhosted.org/packages/67/17/3493c5624e48fd97156ebaec380dcaafee9506d7e2c46218ceebbb57d7de/pytest_asyncio-0.25.3-py3-none-any.whl", hash = "sha256:9e89518e0f9bd08928f97a3482fdc4e244df17529460bc038291ccaf8f85c7c3", size = 19467 }, 127 | ] 128 | 129 | [[package]] 130 | name = "ruff" 131 | version = "0.11.2" 132 | source = { registry = "https://pypi.org/simple" } 133 | sdist = { url = "https://files.pythonhosted.org/packages/90/61/fb87430f040e4e577e784e325351186976516faef17d6fcd921fe28edfd7/ruff-0.11.2.tar.gz", hash = "sha256:ec47591497d5a1050175bdf4e1a4e6272cddff7da88a2ad595e1e326041d8d94", size = 3857511 } 134 | wheels = [ 135 | { url = 
"https://files.pythonhosted.org/packages/62/99/102578506f0f5fa29fd7e0df0a273864f79af044757aef73d1cae0afe6ad/ruff-0.11.2-py3-none-linux_armv6l.whl", hash = "sha256:c69e20ea49e973f3afec2c06376eb56045709f0212615c1adb0eda35e8a4e477", size = 10113146 }, 136 | { url = "https://files.pythonhosted.org/packages/74/ad/5cd4ba58ab602a579997a8494b96f10f316e874d7c435bcc1a92e6da1b12/ruff-0.11.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:2c5424cc1c4eb1d8ecabe6d4f1b70470b4f24a0c0171356290b1953ad8f0e272", size = 10867092 }, 137 | { url = "https://files.pythonhosted.org/packages/fc/3e/d3f13619e1d152c7b600a38c1a035e833e794c6625c9a6cea6f63dbf3af4/ruff-0.11.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:ecf20854cc73f42171eedb66f006a43d0a21bfb98a2523a809931cda569552d9", size = 10224082 }, 138 | { url = "https://files.pythonhosted.org/packages/90/06/f77b3d790d24a93f38e3806216f263974909888fd1e826717c3ec956bbcd/ruff-0.11.2-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0c543bf65d5d27240321604cee0633a70c6c25c9a2f2492efa9f6d4b8e4199bb", size = 10394818 }, 139 | { url = "https://files.pythonhosted.org/packages/99/7f/78aa431d3ddebfc2418cd95b786642557ba8b3cb578c075239da9ce97ff9/ruff-0.11.2-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:20967168cc21195db5830b9224be0e964cc9c8ecf3b5a9e3ce19876e8d3a96e3", size = 9952251 }, 140 | { url = "https://files.pythonhosted.org/packages/30/3e/f11186d1ddfaca438c3bbff73c6a2fdb5b60e6450cc466129c694b0ab7a2/ruff-0.11.2-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:955a9ce63483999d9f0b8f0b4a3ad669e53484232853054cc8b9d51ab4c5de74", size = 11563566 }, 141 | { url = "https://files.pythonhosted.org/packages/22/6c/6ca91befbc0a6539ee133d9a9ce60b1a354db12c3c5d11cfdbf77140f851/ruff-0.11.2-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:86b3a27c38b8fce73bcd262b0de32e9a6801b76d52cdb3ae4c914515f0cef608", size = 12208721 }, 142 | { url = 
"https://files.pythonhosted.org/packages/19/b0/24516a3b850d55b17c03fc399b681c6a549d06ce665915721dc5d6458a5c/ruff-0.11.2-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a3b66a03b248c9fcd9d64d445bafdf1589326bee6fc5c8e92d7562e58883e30f", size = 11662274 }, 143 | { url = "https://files.pythonhosted.org/packages/d7/65/76be06d28ecb7c6070280cef2bcb20c98fbf99ff60b1c57d2fb9b8771348/ruff-0.11.2-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0397c2672db015be5aa3d4dac54c69aa012429097ff219392c018e21f5085147", size = 13792284 }, 144 | { url = "https://files.pythonhosted.org/packages/ce/d2/4ceed7147e05852876f3b5f3fdc23f878ce2b7e0b90dd6e698bda3d20787/ruff-0.11.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:869bcf3f9abf6457fbe39b5a37333aa4eecc52a3b99c98827ccc371a8e5b6f1b", size = 11327861 }, 145 | { url = "https://files.pythonhosted.org/packages/c4/78/4935ecba13706fd60ebe0e3dc50371f2bdc3d9bc80e68adc32ff93914534/ruff-0.11.2-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:2a2b50ca35457ba785cd8c93ebbe529467594087b527a08d487cf0ee7b3087e9", size = 10276560 }, 146 | { url = "https://files.pythonhosted.org/packages/81/7f/1b2435c3f5245d410bb5dc80f13ec796454c21fbda12b77d7588d5cf4e29/ruff-0.11.2-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:7c69c74bf53ddcfbc22e6eb2f31211df7f65054bfc1f72288fc71e5f82db3eab", size = 9945091 }, 147 | { url = "https://files.pythonhosted.org/packages/39/c4/692284c07e6bf2b31d82bb8c32f8840f9d0627d92983edaac991a2b66c0a/ruff-0.11.2-py3-none-musllinux_1_2_i686.whl", hash = "sha256:6e8fb75e14560f7cf53b15bbc55baf5ecbe373dd5f3aab96ff7aa7777edd7630", size = 10977133 }, 148 | { url = "https://files.pythonhosted.org/packages/94/cf/8ab81cb7dd7a3b0a3960c2769825038f3adcd75faf46dd6376086df8b128/ruff-0.11.2-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:842a472d7b4d6f5924e9297aa38149e5dcb1e628773b70e6387ae2c97a63c58f", size = 11378514 }, 149 | { url = 
"https://files.pythonhosted.org/packages/d9/3a/a647fa4f316482dacf2fd68e8a386327a33d6eabd8eb2f9a0c3d291ec549/ruff-0.11.2-py3-none-win32.whl", hash = "sha256:aca01ccd0eb5eb7156b324cfaa088586f06a86d9e5314b0eb330cb48415097cc", size = 10319835 }, 150 | { url = "https://files.pythonhosted.org/packages/86/54/3c12d3af58012a5e2cd7ebdbe9983f4834af3f8cbea0e8a8c74fa1e23b2b/ruff-0.11.2-py3-none-win_amd64.whl", hash = "sha256:3170150172a8f994136c0c66f494edf199a0bbea7a409f649e4bc8f4d7084080", size = 11373713 }, 151 | { url = "https://files.pythonhosted.org/packages/d6/d4/dd813703af8a1e2ac33bf3feb27e8a5ad514c9f219df80c64d69807e7f71/ruff-0.11.2-py3-none-win_arm64.whl", hash = "sha256:52933095158ff328f4c77af3d74f0379e34fd52f175144cefc1b192e7ccd32b4", size = 10441990 }, 152 | ] 153 | 154 | [[package]] 155 | name = "typing-extensions" 156 | version = "4.12.2" 157 | source = { registry = "https://pypi.org/simple" } 158 | sdist = { url = "https://files.pythonhosted.org/packages/df/db/f35a00659bc03fec321ba8bce9420de607a1d37f8342eee1863174c69557/typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8", size = 85321 } 159 | wheels = [ 160 | { url = "https://files.pythonhosted.org/packages/26/9f/ad63fc0248c5379346306f8668cda6e2e2e9c95e01216d2b8ffd9ff037d0/typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d", size = 37438 }, 161 | ] 162 | -------------------------------------------------------------------------------- /src/async_dag/task_manager.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Awaitable, Callable, Iterator 2 | from contextlib import contextmanager 3 | from typing import Never, overload 4 | 5 | from .execution_result import ExecutionResult 6 | from .state import State 7 | from .task_node import TaskNode 8 | 9 | type TaskCallback[_ReturnType, *_Inputs] = Callable[[*_Inputs], 
Awaitable[_ReturnType]] 10 | type TaskNodeOrImmediate[_ParameterType, _ReturnType] = ( 11 | TaskNode[_ParameterType, _ReturnType] | _ReturnType 12 | ) 13 | 14 | 15 | async def _unreachable(*_: object) -> Never: 16 | raise ValueError("unreachable") 17 | 18 | 19 | class TaskManager[_ParameterType]: 20 | """ 21 | ### TaskManager 22 | `TaskManager` is the main building block of `async_dag`; it provides an interface to build and run DAGs. 23 | The five main APIs that are relevant to you are: 24 | 1. `add_node(async_task, param_a, param_b, ...)` - adds a new node to the graph; see the `add_node` overloads for the supported call signatures. 25 | 2. `parameter_node` - a special property of `TaskManager` that, when passed as a dependency of a task, resolves to the parameter passed to `invoke`. 26 | 3. `sort()` - sorts the DAG and readies the `TaskManager` for upcoming `invoke` calls. 27 | 4. `invoke(parameter)` - executes the tasks in the DAG with the given parameter and returns an `ExecutionResult`. 28 | 5. `TaskNode.extract_result(execution_result)` - extracts the result returned from the task passed to the node. 29 | 30 | You can also use the helper function `build_dag`, which provides a context manager and handles calling `sort` for you. 31 | 32 | All of the methods, properties and functions listed here have docstrings with a deeper explanation.
33 | 34 | #### Example 35 | ```python 36 | async def add(n: int) -> int: 37 | return n + 1 38 | 39 | tm = TaskManager[int]() # Define a dag that receives an int as a parameter 40 | 41 | # Build the DAG 42 | node_1 = tm.add_node(add, tm.parameter_node) # use the parameter passed to `tm.invoke` 43 | node_2 = tm.add_node(add, node_1) # use the value returned from node_1 44 | node_3 = tm.add_node(add, node_2) # use the value returned from node_2 45 | 46 | tm.sort() # sorts the DAG and ready the `TaskManager` for `invoke` calls 47 | 48 | execution_result = await tm.invoke(0) # Invoke our DAG 49 | 50 | # Extract the result from one of the nodes 51 | print(node_3.extract_result(execution_result)) # prints 3 52 | ``` 53 | """ 54 | 55 | def __init__(self) -> None: 56 | # NOTE: the tasks in _tasks must be a contiguous array of sorted by _id 57 | self._tasks: list[TaskNode[_ParameterType, object]] = [] 58 | self._max_depth = 0 59 | self._starting_nodes_id: list[int] = [] 60 | self._is_sorted: bool = False 61 | self._parameter_node: TaskNode[_ParameterType, _ParameterType] | None = None 62 | 63 | async def invoke( 64 | self, parameter: _ParameterType 65 | ) -> ExecutionResult[_ParameterType]: 66 | """ 67 | Execute the DAG, this functions `parameter` argument will be passed to each node that is depending on the spacial `parameter_node`. 68 | 69 | This function should only be called after `sort` was called, any calls to it before `sort` was called raise a `ValueError`. 70 | 71 | If any task raises an exception this function will raise a `ExceptionGroup` with that exception and any other exceptions raised during the cancellation of the rest of the tasks. 
72 | """ 73 | if not self._is_sorted: 74 | raise ValueError("'invoke' can not be called before 'sort'") 75 | 76 | execution_result = ExecutionResult(self, parameter) 77 | await execution_result._invoke() 78 | return execution_result 79 | 80 | def sort(self) -> None: 81 | """ 82 | Ready the `TaskManager` for `invoke` calls, this method check that the created graph is indeed a DAG, 83 | If a cycle is detected a `ValueError` will be raised. 84 | 85 | This function should only be called once, any call after the first will raise a `ValueError`. 86 | 87 | Do not call this function if you are using the `build_dag` helper function. 88 | """ 89 | if self._is_sorted: 90 | raise ValueError("'sort' can only be called once") 91 | 92 | def visit(node: TaskNode[_ParameterType, object]) -> None: 93 | if node._state == State.PERMANENT: 94 | return 95 | if node._state == State.TEMPORARY: 96 | raise ValueError("Cycle detected, graph is not a DAG") 97 | 98 | node._state = State.TEMPORARY 99 | for dep_task in [self._tasks[dep_id] for dep_id in node._dependencies_ids]: 100 | if node._depth <= dep_task._depth: 101 | node._depth = dep_task._depth + 1 102 | self._max_depth = max(node._depth, self._max_depth) 103 | visit(dep_task) 104 | 105 | node._state = State.PERMANENT 106 | 107 | for task in self._tasks: 108 | if task._state == State.UNDISCOVERED: 109 | visit(task) 110 | 111 | for task in self._tasks: 112 | if len(task._dependencies_ids) == 0: 113 | self._starting_nodes_id.append(task._id) 114 | for dep in task._dependencies_ids: 115 | self._tasks[dep]._dependents_ids.add(task._id) 116 | 117 | self._is_sorted = True 118 | 119 | @property 120 | def parameter_node(self) -> TaskNode[_ParameterType, _ParameterType]: 121 | """A spacial node that represents the parameter value, if a node `TaskNode` depends on this node it will receive the value passed to `invoke`.""" 122 | if self._parameter_node is None: 123 | self._parameter_node = self._add_node(_unreachable) 124 | return 
self._parameter_node 125 | 126 | def add_immediate_node[_ReturnType]( 127 | self, value: _ReturnType 128 | ) -> TaskNode[_ParameterType, _ReturnType]: 129 | async def get_value() -> _ReturnType: 130 | return value 131 | 132 | return self._add_node(get_value) 133 | 134 | @overload 135 | def add_node[_ReturnType]( 136 | self, 137 | task: TaskCallback[_ReturnType], 138 | ) -> TaskNode[_ParameterType, _ReturnType]: ... 139 | 140 | @overload 141 | def add_node[_ReturnType, _I_1]( 142 | self, 143 | task: TaskCallback[_ReturnType, _I_1], 144 | arg_1: TaskNodeOrImmediate[_ParameterType, _I_1], 145 | ) -> TaskNode[_ParameterType, _ReturnType]: ... 146 | 147 | @overload 148 | def add_node[_ReturnType, _I_1, _I_2]( 149 | self, 150 | task: TaskCallback[_ReturnType, _I_1, _I_2], 151 | arg_1: TaskNodeOrImmediate[_ParameterType, _I_1], 152 | arg_2: TaskNodeOrImmediate[_ParameterType, _I_2], 153 | ) -> TaskNode[_ParameterType, _ReturnType]: ... 154 | 155 | @overload 156 | def add_node[_ReturnType, _I_1, _I_2, _I_3]( 157 | self, 158 | task: TaskCallback[_ReturnType, _I_1, _I_2, _I_3], 159 | arg_1: TaskNodeOrImmediate[_ParameterType, _I_1], 160 | arg_2: TaskNodeOrImmediate[_ParameterType, _I_2], 161 | arg_3: TaskNodeOrImmediate[_ParameterType, _I_3], 162 | ) -> TaskNode[_ParameterType, _ReturnType]: ... 163 | 164 | @overload 165 | def add_node[_ReturnType, _I_1, _I_2, _I_3, _I_4]( 166 | self, 167 | task: TaskCallback[_ReturnType, _I_1, _I_2, _I_3, _I_4], 168 | arg_1: TaskNodeOrImmediate[_ParameterType, _I_1], 169 | arg_2: TaskNodeOrImmediate[_ParameterType, _I_2], 170 | arg_3: TaskNodeOrImmediate[_ParameterType, _I_3], 171 | arg_4: TaskNodeOrImmediate[_ParameterType, _I_4], 172 | ) -> TaskNode[_ParameterType, _ReturnType]: ... 
173 | 174 | @overload 175 | def add_node[_ReturnType, _I_1, _I_2, _I_3, _I_4, _I_5]( 176 | self, 177 | task: TaskCallback[_ReturnType, _I_1, _I_2, _I_3, _I_4, _I_5], 178 | arg_1: TaskNodeOrImmediate[_ParameterType, _I_1], 179 | arg_2: TaskNodeOrImmediate[_ParameterType, _I_2], 180 | arg_3: TaskNodeOrImmediate[_ParameterType, _I_3], 181 | arg_4: TaskNodeOrImmediate[_ParameterType, _I_4], 182 | arg_5: TaskNodeOrImmediate[_ParameterType, _I_5], 183 | ) -> TaskNode[_ParameterType, _ReturnType]: ... 184 | 185 | @overload 186 | def add_node[_ReturnType, _I_1, _I_2, _I_3, _I_4, _I_5, _I_6]( 187 | self, 188 | task: TaskCallback[_ReturnType, _I_1, _I_2, _I_3, _I_4, _I_5, _I_6], 189 | arg_1: TaskNodeOrImmediate[_ParameterType, _I_1], 190 | arg_2: TaskNodeOrImmediate[_ParameterType, _I_2], 191 | arg_3: TaskNodeOrImmediate[_ParameterType, _I_3], 192 | arg_4: TaskNodeOrImmediate[_ParameterType, _I_4], 193 | arg_5: TaskNodeOrImmediate[_ParameterType, _I_5], 194 | arg_6: TaskNodeOrImmediate[_ParameterType, _I_6], 195 | ) -> TaskNode[_ParameterType, _ReturnType]: ... 196 | 197 | @overload 198 | def add_node[_ReturnType, _I_1, _I_2, _I_3, _I_4, _I_5, _I_6, _I_7]( 199 | self, 200 | task: TaskCallback[_ReturnType, _I_1, _I_2, _I_3, _I_4, _I_5, _I_6, _I_7], 201 | arg_1: TaskNodeOrImmediate[_ParameterType, _I_1], 202 | arg_2: TaskNodeOrImmediate[_ParameterType, _I_2], 203 | arg_3: TaskNodeOrImmediate[_ParameterType, _I_3], 204 | arg_4: TaskNodeOrImmediate[_ParameterType, _I_4], 205 | arg_5: TaskNodeOrImmediate[_ParameterType, _I_5], 206 | arg_6: TaskNodeOrImmediate[_ParameterType, _I_6], 207 | arg_7: TaskNodeOrImmediate[_ParameterType, _I_7], 208 | ) -> TaskNode[_ParameterType, _ReturnType]: ... 
209 | 210 | @overload 211 | def add_node[_ReturnType, _I_1, _I_2, _I_3, _I_4, _I_5, _I_6, _I_7, _I_8]( 212 | self, 213 | task: TaskCallback[_ReturnType, _I_1, _I_2, _I_3, _I_4, _I_5, _I_6, _I_7, _I_8], 214 | arg_1: TaskNodeOrImmediate[_ParameterType, _I_1], 215 | arg_2: TaskNodeOrImmediate[_ParameterType, _I_2], 216 | arg_3: TaskNodeOrImmediate[_ParameterType, _I_3], 217 | arg_4: TaskNodeOrImmediate[_ParameterType, _I_4], 218 | arg_5: TaskNodeOrImmediate[_ParameterType, _I_5], 219 | arg_6: TaskNodeOrImmediate[_ParameterType, _I_6], 220 | arg_7: TaskNodeOrImmediate[_ParameterType, _I_7], 221 | arg_8: TaskNodeOrImmediate[_ParameterType, _I_8], 222 | ) -> TaskNode[_ParameterType, _ReturnType]: ... 223 | 224 | @overload 225 | def add_node[_ReturnType, _I_1, _I_2, _I_3, _I_4, _I_5, _I_6, _I_7, _I_8, _I_9]( 226 | self, 227 | task: TaskCallback[ 228 | _ReturnType, 229 | _I_1, 230 | _I_2, 231 | _I_3, 232 | _I_4, 233 | _I_5, 234 | _I_6, 235 | _I_7, 236 | _I_8, 237 | _I_9, 238 | ], 239 | arg_1: TaskNodeOrImmediate[_ParameterType, _I_1], 240 | arg_2: TaskNodeOrImmediate[_ParameterType, _I_2], 241 | arg_3: TaskNodeOrImmediate[_ParameterType, _I_3], 242 | arg_4: TaskNodeOrImmediate[_ParameterType, _I_4], 243 | arg_5: TaskNodeOrImmediate[_ParameterType, _I_5], 244 | arg_6: TaskNodeOrImmediate[_ParameterType, _I_6], 245 | arg_7: TaskNodeOrImmediate[_ParameterType, _I_7], 246 | arg_8: TaskNodeOrImmediate[_ParameterType, _I_8], 247 | arg_9: TaskNodeOrImmediate[_ParameterType, _I_9], 248 | ) -> TaskNode[_ParameterType, _ReturnType]: ... 
249 | 250 | @overload 251 | def add_node[ 252 | _ReturnType, 253 | _I_1, 254 | _I_2, 255 | _I_3, 256 | _I_4, 257 | _I_5, 258 | _I_6, 259 | _I_7, 260 | _I_8, 261 | _I_9, 262 | _I_10, 263 | ]( 264 | self, 265 | task: TaskCallback[ 266 | _ReturnType, 267 | _I_1, 268 | _I_2, 269 | _I_3, 270 | _I_4, 271 | _I_5, 272 | _I_6, 273 | _I_7, 274 | _I_8, 275 | _I_9, 276 | _I_10, 277 | ], 278 | arg_1: TaskNodeOrImmediate[_ParameterType, _I_1], 279 | arg_2: TaskNodeOrImmediate[_ParameterType, _I_2], 280 | arg_3: TaskNodeOrImmediate[_ParameterType, _I_3], 281 | arg_4: TaskNodeOrImmediate[_ParameterType, _I_4], 282 | arg_5: TaskNodeOrImmediate[_ParameterType, _I_5], 283 | arg_6: TaskNodeOrImmediate[_ParameterType, _I_6], 284 | arg_7: TaskNodeOrImmediate[_ParameterType, _I_7], 285 | arg_8: TaskNodeOrImmediate[_ParameterType, _I_8], 286 | arg_9: TaskNodeOrImmediate[_ParameterType, _I_9], 287 | arg_10: TaskNodeOrImmediate[_ParameterType, _I_10], 288 | ) -> TaskNode[_ParameterType, _ReturnType]: ... 289 | 290 | @overload 291 | def add_node[ 292 | _ReturnType, 293 | _I_1, 294 | _I_2, 295 | _I_3, 296 | _I_4, 297 | _I_5, 298 | _I_6, 299 | _I_7, 300 | _I_8, 301 | _I_9, 302 | _I_10, 303 | _I_11, 304 | ]( 305 | self, 306 | task: TaskCallback[ 307 | _ReturnType, 308 | _I_1, 309 | _I_2, 310 | _I_3, 311 | _I_4, 312 | _I_5, 313 | _I_6, 314 | _I_7, 315 | _I_8, 316 | _I_9, 317 | _I_10, 318 | _I_11, 319 | ], 320 | arg_1: TaskNodeOrImmediate[_ParameterType, _I_1], 321 | arg_2: TaskNodeOrImmediate[_ParameterType, _I_2], 322 | arg_3: TaskNodeOrImmediate[_ParameterType, _I_3], 323 | arg_4: TaskNodeOrImmediate[_ParameterType, _I_4], 324 | arg_5: TaskNodeOrImmediate[_ParameterType, _I_5], 325 | arg_6: TaskNodeOrImmediate[_ParameterType, _I_6], 326 | arg_7: TaskNodeOrImmediate[_ParameterType, _I_7], 327 | arg_8: TaskNodeOrImmediate[_ParameterType, _I_8], 328 | arg_9: TaskNodeOrImmediate[_ParameterType, _I_9], 329 | arg_10: TaskNodeOrImmediate[_ParameterType, _I_10], 330 | arg_11: 
TaskNodeOrImmediate[_ParameterType, _I_11], 331 | ) -> TaskNode[_ParameterType, _ReturnType]: ... 332 | 333 | @overload 334 | def add_node[ 335 | _ReturnType, 336 | _I_1, 337 | _I_2, 338 | _I_3, 339 | _I_4, 340 | _I_5, 341 | _I_6, 342 | _I_7, 343 | _I_8, 344 | _I_9, 345 | _I_10, 346 | _I_11, 347 | _I_12, 348 | ]( 349 | self, 350 | task: TaskCallback[ 351 | _ReturnType, 352 | _I_1, 353 | _I_2, 354 | _I_3, 355 | _I_4, 356 | _I_5, 357 | _I_6, 358 | _I_7, 359 | _I_8, 360 | _I_9, 361 | _I_10, 362 | _I_11, 363 | _I_12, 364 | ], 365 | arg_1: TaskNodeOrImmediate[_ParameterType, _I_1], 366 | arg_2: TaskNodeOrImmediate[_ParameterType, _I_2], 367 | arg_3: TaskNodeOrImmediate[_ParameterType, _I_3], 368 | arg_4: TaskNodeOrImmediate[_ParameterType, _I_4], 369 | arg_5: TaskNodeOrImmediate[_ParameterType, _I_5], 370 | arg_6: TaskNodeOrImmediate[_ParameterType, _I_6], 371 | arg_7: TaskNodeOrImmediate[_ParameterType, _I_7], 372 | arg_8: TaskNodeOrImmediate[_ParameterType, _I_8], 373 | arg_9: TaskNodeOrImmediate[_ParameterType, _I_9], 374 | arg_10: TaskNodeOrImmediate[_ParameterType, _I_10], 375 | arg_11: TaskNodeOrImmediate[_ParameterType, _I_11], 376 | arg_12: TaskNodeOrImmediate[_ParameterType, _I_12], 377 | ) -> TaskNode[_ParameterType, _ReturnType]: ... 
    # Fixed-arity overload of `add_node` for a `task` taking 13 inputs.
    # Each `arg_k` feeds the k-th input of `task` and may be either a `TaskNode`
    # producing `_I_k` or an immediate `_I_k` value (the implementation wraps
    # non-`TaskNode` values with `add_immediate_node`).
    # NOTE(review): `TaskCallback[_ReturnType, _I_1, ...]` is defined elsewhere;
    # presumably an async callable of those inputs resolving to `_ReturnType`.
    @overload
    def add_node[
        _ReturnType,
        _I_1,
        _I_2,
        _I_3,
        _I_4,
        _I_5,
        _I_6,
        _I_7,
        _I_8,
        _I_9,
        _I_10,
        _I_11,
        _I_12,
        _I_13,
    ](
        self,
        task: TaskCallback[
            _ReturnType,
            _I_1,
            _I_2,
            _I_3,
            _I_4,
            _I_5,
            _I_6,
            _I_7,
            _I_8,
            _I_9,
            _I_10,
            _I_11,
            _I_12,
            _I_13,
        ],
        arg_1: TaskNodeOrImmediate[_ParameterType, _I_1],
        arg_2: TaskNodeOrImmediate[_ParameterType, _I_2],
        arg_3: TaskNodeOrImmediate[_ParameterType, _I_3],
        arg_4: TaskNodeOrImmediate[_ParameterType, _I_4],
        arg_5: TaskNodeOrImmediate[_ParameterType, _I_5],
        arg_6: TaskNodeOrImmediate[_ParameterType, _I_6],
        arg_7: TaskNodeOrImmediate[_ParameterType, _I_7],
        arg_8: TaskNodeOrImmediate[_ParameterType, _I_8],
        arg_9: TaskNodeOrImmediate[_ParameterType, _I_9],
        arg_10: TaskNodeOrImmediate[_ParameterType, _I_10],
        arg_11: TaskNodeOrImmediate[_ParameterType, _I_11],
        arg_12: TaskNodeOrImmediate[_ParameterType, _I_12],
        arg_13: TaskNodeOrImmediate[_ParameterType, _I_13],
    ) -> TaskNode[_ParameterType, _ReturnType]: ...

    # Fixed-arity overload of `add_node` for a `task` taking 14 inputs
    # (same argument scheme as the 13-input overload above).
    @overload
    def add_node[
        _ReturnType,
        _I_1,
        _I_2,
        _I_3,
        _I_4,
        _I_5,
        _I_6,
        _I_7,
        _I_8,
        _I_9,
        _I_10,
        _I_11,
        _I_12,
        _I_13,
        _I_14,
    ](
        self,
        task: TaskCallback[
            _ReturnType,
            _I_1,
            _I_2,
            _I_3,
            _I_4,
            _I_5,
            _I_6,
            _I_7,
            _I_8,
            _I_9,
            _I_10,
            _I_11,
            _I_12,
            _I_13,
            _I_14,
        ],
        arg_1: TaskNodeOrImmediate[_ParameterType, _I_1],
        arg_2: TaskNodeOrImmediate[_ParameterType, _I_2],
        arg_3: TaskNodeOrImmediate[_ParameterType, _I_3],
        arg_4: TaskNodeOrImmediate[_ParameterType, _I_4],
        arg_5: TaskNodeOrImmediate[_ParameterType, _I_5],
        arg_6: TaskNodeOrImmediate[_ParameterType, _I_6],
        arg_7: TaskNodeOrImmediate[_ParameterType, _I_7],
        arg_8: TaskNodeOrImmediate[_ParameterType, _I_8],
        arg_9: TaskNodeOrImmediate[_ParameterType, _I_9],
        arg_10: TaskNodeOrImmediate[_ParameterType, _I_10],
        arg_11: TaskNodeOrImmediate[_ParameterType, _I_11],
        arg_12: TaskNodeOrImmediate[_ParameterType, _I_12],
        arg_13: TaskNodeOrImmediate[_ParameterType, _I_13],
        arg_14: TaskNodeOrImmediate[_ParameterType, _I_14],
    ) -> TaskNode[_ParameterType, _ReturnType]: ...

    # Fixed-arity overload of `add_node` for a `task` taking 15 inputs
    # (same argument scheme as the 13-input overload above).
    @overload
    def add_node[
        _ReturnType,
        _I_1,
        _I_2,
        _I_3,
        _I_4,
        _I_5,
        _I_6,
        _I_7,
        _I_8,
        _I_9,
        _I_10,
        _I_11,
        _I_12,
        _I_13,
        _I_14,
        _I_15,
    ](
        self,
        task: TaskCallback[
            _ReturnType,
            _I_1,
            _I_2,
            _I_3,
            _I_4,
            _I_5,
            _I_6,
            _I_7,
            _I_8,
            _I_9,
            _I_10,
            _I_11,
            _I_12,
            _I_13,
            _I_14,
            _I_15,
        ],
        arg_1: TaskNodeOrImmediate[_ParameterType, _I_1],
        arg_2: TaskNodeOrImmediate[_ParameterType, _I_2],
        arg_3: TaskNodeOrImmediate[_ParameterType, _I_3],
        arg_4: TaskNodeOrImmediate[_ParameterType, _I_4],
        arg_5: TaskNodeOrImmediate[_ParameterType, _I_5],
        arg_6: TaskNodeOrImmediate[_ParameterType, _I_6],
        arg_7: TaskNodeOrImmediate[_ParameterType, _I_7],
        arg_8: TaskNodeOrImmediate[_ParameterType, _I_8],
        arg_9: TaskNodeOrImmediate[_ParameterType, _I_9],
        arg_10: TaskNodeOrImmediate[_ParameterType, _I_10],
        arg_11: TaskNodeOrImmediate[_ParameterType, _I_11],
        arg_12: TaskNodeOrImmediate[_ParameterType, _I_12],
        arg_13: TaskNodeOrImmediate[_ParameterType, _I_13],
        arg_14: TaskNodeOrImmediate[_ParameterType, _I_14],
        arg_15: TaskNodeOrImmediate[_ParameterType, _I_15],
    ) -> TaskNode[_ParameterType, _ReturnType]: ...

    # Fixed-arity overload of `add_node` for a `task` taking 16 inputs
    # (same argument scheme as the 13-input overload above).
    @overload
    def add_node[
        _ReturnType,
        _I_1,
        _I_2,
        _I_3,
        _I_4,
        _I_5,
        _I_6,
        _I_7,
        _I_8,
        _I_9,
        _I_10,
        _I_11,
        _I_12,
        _I_13,
        _I_14,
        _I_15,
        _I_16,
    ](
        self,
        task: TaskCallback[
            _ReturnType,
            _I_1,
            _I_2,
            _I_3,
            _I_4,
            _I_5,
            _I_6,
            _I_7,
            _I_8,
            _I_9,
            _I_10,
            _I_11,
            _I_12,
            _I_13,
            _I_14,
            _I_15,
            _I_16,
        ],
        arg_1: TaskNodeOrImmediate[_ParameterType, _I_1],
        arg_2: TaskNodeOrImmediate[_ParameterType, _I_2],
        arg_3: TaskNodeOrImmediate[_ParameterType, _I_3],
        arg_4: TaskNodeOrImmediate[_ParameterType, _I_4],
        arg_5: TaskNodeOrImmediate[_ParameterType, _I_5],
        arg_6: TaskNodeOrImmediate[_ParameterType, _I_6],
        arg_7: TaskNodeOrImmediate[_ParameterType, _I_7],
        arg_8: TaskNodeOrImmediate[_ParameterType, _I_8],
        arg_9: TaskNodeOrImmediate[_ParameterType, _I_9],
        arg_10: TaskNodeOrImmediate[_ParameterType, _I_10],
        arg_11: TaskNodeOrImmediate[_ParameterType, _I_11],
        arg_12: TaskNodeOrImmediate[_ParameterType, _I_12],
        arg_13: TaskNodeOrImmediate[_ParameterType, _I_13],
        arg_14: TaskNodeOrImmediate[_ParameterType, _I_14],
        arg_15: TaskNodeOrImmediate[_ParameterType, _I_15],
        arg_16: TaskNodeOrImmediate[_ParameterType, _I_16],
    ) -> TaskNode[_ParameterType, _ReturnType]: ...

    # Fixed-arity overload of `add_node` for a `task` taking 17 inputs
    # (same argument scheme as the 13-input overload above).
    @overload
    def add_node[
        _ReturnType,
        _I_1,
        _I_2,
        _I_3,
        _I_4,
        _I_5,
        _I_6,
        _I_7,
        _I_8,
        _I_9,
        _I_10,
        _I_11,
        _I_12,
        _I_13,
        _I_14,
        _I_15,
        _I_16,
        _I_17,
    ](
        self,
        task: TaskCallback[
            _ReturnType,
            _I_1,
            _I_2,
            _I_3,
            _I_4,
            _I_5,
            _I_6,
            _I_7,
            _I_8,
            _I_9,
            _I_10,
            _I_11,
            _I_12,
            _I_13,
            _I_14,
            _I_15,
            _I_16,
            _I_17,
        ],
        arg_1: TaskNodeOrImmediate[_ParameterType, _I_1],
        arg_2: TaskNodeOrImmediate[_ParameterType, _I_2],
        arg_3: TaskNodeOrImmediate[_ParameterType, _I_3],
        arg_4: TaskNodeOrImmediate[_ParameterType, _I_4],
        arg_5: TaskNodeOrImmediate[_ParameterType, _I_5],
        arg_6: TaskNodeOrImmediate[_ParameterType, _I_6],
        arg_7: TaskNodeOrImmediate[_ParameterType, _I_7],
        arg_8: TaskNodeOrImmediate[_ParameterType, _I_8],
        arg_9: TaskNodeOrImmediate[_ParameterType, _I_9],
        arg_10: TaskNodeOrImmediate[_ParameterType, _I_10],
        arg_11: TaskNodeOrImmediate[_ParameterType, _I_11],
        arg_12: TaskNodeOrImmediate[_ParameterType, _I_12],
        arg_13: TaskNodeOrImmediate[_ParameterType, _I_13],
        arg_14: TaskNodeOrImmediate[_ParameterType, _I_14],
        arg_15: TaskNodeOrImmediate[_ParameterType, _I_15],
        arg_16: TaskNodeOrImmediate[_ParameterType, _I_16],
        arg_17: TaskNodeOrImmediate[_ParameterType, _I_17],
    ) -> TaskNode[_ParameterType, _ReturnType]: ...

    # Fixed-arity overload of `add_node` for a `task` taking 18 inputs
    # (same argument scheme as the 13-input overload above).
    @overload
    def add_node[
        _ReturnType,
        _I_1,
        _I_2,
        _I_3,
        _I_4,
        _I_5,
        _I_6,
        _I_7,
        _I_8,
        _I_9,
        _I_10,
        _I_11,
        _I_12,
        _I_13,
        _I_14,
        _I_15,
        _I_16,
        _I_17,
        _I_18,
    ](
        self,
        task: TaskCallback[
            _ReturnType,
            _I_1,
            _I_2,
            _I_3,
            _I_4,
            _I_5,
            _I_6,
            _I_7,
            _I_8,
            _I_9,
            _I_10,
            _I_11,
            _I_12,
            _I_13,
            _I_14,
            _I_15,
            _I_16,
            _I_17,
            _I_18,
        ],
        arg_1: TaskNodeOrImmediate[_ParameterType, _I_1],
        arg_2: TaskNodeOrImmediate[_ParameterType, _I_2],
        arg_3: TaskNodeOrImmediate[_ParameterType, _I_3],
        arg_4: TaskNodeOrImmediate[_ParameterType, _I_4],
        arg_5: TaskNodeOrImmediate[_ParameterType, _I_5],
        arg_6: TaskNodeOrImmediate[_ParameterType, _I_6],
        arg_7: TaskNodeOrImmediate[_ParameterType, _I_7],
        arg_8: TaskNodeOrImmediate[_ParameterType, _I_8],
        arg_9: TaskNodeOrImmediate[_ParameterType, _I_9],
        arg_10: TaskNodeOrImmediate[_ParameterType, _I_10],
        arg_11: TaskNodeOrImmediate[_ParameterType, _I_11],
        arg_12: TaskNodeOrImmediate[_ParameterType, _I_12],
        arg_13: TaskNodeOrImmediate[_ParameterType, _I_13],
        arg_14: TaskNodeOrImmediate[_ParameterType, _I_14],
        arg_15: TaskNodeOrImmediate[_ParameterType, _I_15],
        arg_16: TaskNodeOrImmediate[_ParameterType, _I_16],
        arg_17: TaskNodeOrImmediate[_ParameterType, _I_17],
        arg_18: TaskNodeOrImmediate[_ParameterType, _I_18],
    ) -> TaskNode[_ParameterType, _ReturnType]: ...

    # Fixed-arity overload of `add_node` for a `task` taking 19 inputs
    # (same argument scheme as the 13-input overload above).
    @overload
    def add_node[
        _ReturnType,
        _I_1,
        _I_2,
        _I_3,
        _I_4,
        _I_5,
        _I_6,
        _I_7,
        _I_8,
        _I_9,
        _I_10,
        _I_11,
        _I_12,
        _I_13,
        _I_14,
        _I_15,
        _I_16,
        _I_17,
        _I_18,
        _I_19,
    ](
        self,
        task: TaskCallback[
            _ReturnType,
            _I_1,
            _I_2,
            _I_3,
            _I_4,
            _I_5,
            _I_6,
            _I_7,
            _I_8,
            _I_9,
            _I_10,
            _I_11,
            _I_12,
            _I_13,
            _I_14,
            _I_15,
            _I_16,
            _I_17,
            _I_18,
            _I_19,
        ],
        arg_1: TaskNodeOrImmediate[_ParameterType, _I_1],
        arg_2: TaskNodeOrImmediate[_ParameterType, _I_2],
        arg_3: TaskNodeOrImmediate[_ParameterType, _I_3],
        arg_4: TaskNodeOrImmediate[_ParameterType, _I_4],
        arg_5: TaskNodeOrImmediate[_ParameterType, _I_5],
        arg_6: TaskNodeOrImmediate[_ParameterType, _I_6],
        arg_7: TaskNodeOrImmediate[_ParameterType, _I_7],
        arg_8: TaskNodeOrImmediate[_ParameterType, _I_8],
        arg_9: TaskNodeOrImmediate[_ParameterType, _I_9],
        arg_10: TaskNodeOrImmediate[_ParameterType, _I_10],
        arg_11: TaskNodeOrImmediate[_ParameterType, _I_11],
        arg_12: TaskNodeOrImmediate[_ParameterType, _I_12],
        arg_13: TaskNodeOrImmediate[_ParameterType, _I_13],
        arg_14: TaskNodeOrImmediate[_ParameterType, _I_14],
        arg_15: TaskNodeOrImmediate[_ParameterType, _I_15],
        arg_16: TaskNodeOrImmediate[_ParameterType, _I_16],
        arg_17: TaskNodeOrImmediate[_ParameterType, _I_17],
        arg_18: TaskNodeOrImmediate[_ParameterType, _I_18],
        arg_19: TaskNodeOrImmediate[_ParameterType, _I_19],
    ) -> TaskNode[_ParameterType, _ReturnType]: ...

    # Fixed-arity overload of `add_node` for a `task` taking 20 inputs
    # (same argument scheme as the 13-input overload above). Tasks with more
    # inputs fall through to the variadic implementation signature.
    @overload
    def add_node[
        _ReturnType,
        _I_1,
        _I_2,
        _I_3,
        _I_4,
        _I_5,
        _I_6,
        _I_7,
        _I_8,
        _I_9,
        _I_10,
        _I_11,
        _I_12,
        _I_13,
        _I_14,
        _I_15,
        _I_16,
        _I_17,
        _I_18,
        _I_19,
        _I_20,
    ](
        self,
        task: TaskCallback[
            _ReturnType,
            _I_1,
            _I_2,
            _I_3,
            _I_4,
            _I_5,
            _I_6,
            _I_7,
            _I_8,
            _I_9,
            _I_10,
            _I_11,
            _I_12,
            _I_13,
            _I_14,
            _I_15,
            _I_16,
            _I_17,
            _I_18,
            _I_19,
            _I_20,
        ],
        arg_1: TaskNodeOrImmediate[_ParameterType, _I_1],
        arg_2: TaskNodeOrImmediate[_ParameterType, _I_2],
        arg_3: TaskNodeOrImmediate[_ParameterType, _I_3],
        arg_4: TaskNodeOrImmediate[_ParameterType, _I_4],
        arg_5: TaskNodeOrImmediate[_ParameterType, _I_5],
        arg_6: TaskNodeOrImmediate[_ParameterType, _I_6],
        arg_7: TaskNodeOrImmediate[_ParameterType, _I_7],
        arg_8: TaskNodeOrImmediate[_ParameterType, _I_8],
        arg_9: TaskNodeOrImmediate[_ParameterType, _I_9],
        arg_10: TaskNodeOrImmediate[_ParameterType, _I_10],
        arg_11: TaskNodeOrImmediate[_ParameterType, _I_11],
        arg_12: TaskNodeOrImmediate[_ParameterType, _I_12],
        arg_13: TaskNodeOrImmediate[_ParameterType, _I_13],
        arg_14: TaskNodeOrImmediate[_ParameterType, _I_14],
        arg_15: TaskNodeOrImmediate[_ParameterType, _I_15],
        arg_16: TaskNodeOrImmediate[_ParameterType, _I_16],
        arg_17: TaskNodeOrImmediate[_ParameterType, _I_17],
        arg_18: TaskNodeOrImmediate[_ParameterType, _I_18],
        arg_19: TaskNodeOrImmediate[_ParameterType, _I_19],
        arg_20: TaskNodeOrImmediate[_ParameterType, _I_20],
    ) -> TaskNode[_ParameterType, _ReturnType]: ...
854 | 855 | # TODO: remove all the @overload functions once https://github.com/python/typing/issues/1216 get solved 856 | def add_node[_ReturnType, *_InputsType]( # type: ignore 857 | self, 858 | task: Callable[[*_InputsType], Awaitable[_ReturnType]], 859 | *dependencies: TaskNodeOrImmediate[_ParameterType, object], 860 | ) -> TaskNode[_ParameterType, _ReturnType]: 861 | """ 862 | This functions is the heart of this library, each call to this function adds a new node in our execution DAG. 863 | 864 | This function receives a `task` which is a partially applied coroutine (uncalled) which will get called once all the `dependencies` passed in the following arguments are satisfied. 865 | 866 | The `dependencies` parameters could either be an immediate value (not a `TaskNode`) that will get resolved immediately upon calling `invoke` (hance the name immediate), 867 | or a `TaskNode` which will act as a dependency to that task. 868 | You can look at this function the same way as you look at `functools.partial`, but in addition to parameters we can also pass the task dependencies. 
869 | """ 870 | return self._add_node( 871 | task, 872 | *( 873 | dep if isinstance(dep, TaskNode) else self.add_immediate_node(dep) 874 | for dep in dependencies 875 | ), 876 | ) 877 | 878 | def _add_node[_ReturnType, *_InputsType]( 879 | self, 880 | task: Callable[[*_InputsType], Awaitable[_ReturnType]], 881 | *dependencies: TaskNode[_ParameterType, object], 882 | ) -> TaskNode[_ParameterType, _ReturnType]: 883 | if self._is_sorted: 884 | raise ValueError("'add_node' can not be called after 'sort'") 885 | 886 | for dep in dependencies: 887 | if dep._task_manager is not self: 888 | raise ValueError( 889 | f"Task manager mismatch, expected: {self} but got: {dep._task_manager}" 890 | ) 891 | 892 | task_node = TaskNode( 893 | task, 894 | self, 895 | [dep_id._id for dep_id in dependencies], 896 | len(self._tasks), 897 | ) 898 | self._tasks.append(task_node) 899 | return task_node 900 | 901 | 902 | @overload 903 | @contextmanager 904 | def build_dag() -> Iterator[TaskManager[None]]: ... 905 | 906 | 907 | @overload 908 | @contextmanager 909 | def build_dag[T](parameter_type: type[T]) -> Iterator[TaskManager[T]]: ... 910 | 911 | 912 | @contextmanager # type: ignore 913 | def build_dag[T](_: type[T] | None = None) -> Iterator[TaskManager[T]]: 914 | """ 915 | A helper function that returns a context manager that calls sort for you on the `TaskManager` it creates. 916 | This is useful for creating an indented section that defines your DAG. 917 | 918 | The first parameter defines the type T of `TaskManager[T]` which sets the `invoke(parameter: T)` parameter type. 
919 | 920 | #### Example: 921 | ```python 922 | async def add(n: int) -> int: 923 | return n + 1 924 | 925 | # Define a dag that receives an int as a parameter 926 | with build_dag(int) as tm: 927 | # Build the DAG 928 | node_1 = tm.add_node(add, tm.parameter_node) # use the parameter passed to `tm.invoke` 929 | node_2 = tm.add_node(add, node_1) # use the value returned from node_1 930 | node_3 = tm.add_node(add, node_2) # use the value returned from node_2 931 | 932 | # After we exited the `with` block we can already call `tm.invoke` because the context manager handled the `sort` call for us. 933 | 934 | execution_result = await tm.invoke(0) # Invoke our DAG 935 | 936 | # Extract the result from one of the nodes 937 | print(node_3.extract_result(execution_result)) # prints 3 938 | ``` 939 | """ 940 | task_manager = TaskManager[T]() 941 | 942 | yield task_manager 943 | 944 | task_manager.sort() 945 | --------------------------------------------------------------------------------