├── pytest.ini ├── .gitignore ├── .github └── workflows │ ├── test.yml │ └── publish.yml ├── pyproject.toml ├── asyncinject ├── __init__.py └── vendored_graphlib.py ├── README.md ├── tests └── test_asyncinject.py └── LICENSE /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | asyncio_mode = strict 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | uv.lock 2 | .venv 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | venv 7 | .eggs 8 | .pytest_cache 9 | *.egg-info 10 | .DS_Store 11 | dist 12 | build 13 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Test 2 | 3 | on: [push, pull_request] 4 | 5 | permissions: 6 | contents: read 7 | 8 | jobs: 9 | test: 10 | runs-on: ubuntu-latest 11 | strategy: 12 | matrix: 13 | python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"] 14 | steps: 15 | - uses: actions/checkout@v5 16 | - name: Set up Python ${{ matrix.python-version }} 17 | uses: actions/setup-python@v6 18 | with: 19 | python-version: ${{ matrix.python-version }} 20 | cache: pip 21 | cache-dependency-path: pyproject.toml 22 | - name: Install dependencies 23 | run: | 24 | pip install '.[test]' 25 | - name: Run tests 26 | run: | 27 | python -m pytest 28 | 29 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "asyncinject" 3 | version = "0.6.1" 4 | description = "Run async workflows using pytest-fixtures-style dependency injection" 5 | readme = "README.md" 6 | requires-python = ">=3.10" 7 | authors = [{name = "Simon Willison"}] 8 | license = "Apache-2.0" 9 | classifiers = [] 10 | 
dependencies = [] 11 | 12 | [build-system] 13 | requires = ["setuptools"] 14 | build-backend = "setuptools.build_meta" 15 | 16 | [project.urls] 17 | Homepage = "https://github.com/simonw/asyncinject" 18 | Changelog = "https://github.com/simonw/asyncinject/releases" 19 | Issues = "https://github.com/simonw/asyncinject/issues" 20 | CI = "https://github.com/simonw/asyncinject/actions" 21 | 22 | [project.optional-dependencies] 23 | test = ["pytest", "pytest-asyncio"] 24 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish Python Package 2 | 3 | on: 4 | release: 5 | types: [created] 6 | 7 | permissions: 8 | contents: read 9 | 10 | jobs: 11 | test: 12 | runs-on: ubuntu-latest 13 | strategy: 14 | matrix: 15 | python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"] 16 | steps: 17 | - uses: actions/checkout@v5 18 | - name: Set up Python ${{ matrix.python-version }} 19 | uses: actions/setup-python@v6 20 | with: 21 | python-version: ${{ matrix.python-version }} 22 | cache: pip 23 | cache-dependency-path: pyproject.toml 24 | - name: Install dependencies 25 | run: | 26 | pip install '.[test]' 27 | - name: Run tests 28 | run: | 29 | python -m pytest 30 | build: 31 | runs-on: ubuntu-latest 32 | needs: [test] 33 | steps: 34 | - uses: actions/checkout@v5 35 | - name: Set up Python 36 | uses: actions/setup-python@v6 37 | with: 38 | python-version: "3.14" 39 | cache: pip 40 | cache-dependency-path: pyproject.toml 41 | - name: Install dependencies 42 | run: | 43 | pip install setuptools wheel build 44 | - name: Build 45 | run: | 46 | python -m build 47 | - name: Store the distribution packages 48 | uses: actions/upload-artifact@v4 49 | with: 50 | name: python-packages 51 | path: dist/ 52 | publish: 53 | name: Publish to PyPI 54 | runs-on: ubuntu-latest 55 | if: startsWith(github.ref, 'refs/tags/') 56 | needs: [build] 57 | 
environment: release 58 | permissions: 59 | id-token: write 60 | steps: 61 | - name: Download distribution packages 62 | uses: actions/download-artifact@v4 63 | with: 64 | name: python-packages 65 | path: dist/ 66 | - name: Publish to PyPI 67 | uses: pypa/gh-action-pypi-publish@release/v1 68 | 69 | -------------------------------------------------------------------------------- /asyncinject/__init__.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import time 3 | 4 | try: 5 | import graphlib 6 | except ImportError: 7 | from . import vendored_graphlib as graphlib 8 | import asyncio 9 | 10 | 11 | class Registry: 12 | def __init__(self, *fns, parallel=True, timer=None): 13 | self._registry = {} 14 | self._graph = None 15 | self._reversed = None 16 | self.parallel = parallel 17 | self.timer = timer 18 | for fn in fns: 19 | self.register(fn) 20 | 21 | @classmethod 22 | def from_dict(cls, d, parallel=True, timer=None): 23 | instance = cls(parallel=parallel, timer=timer) 24 | for key, fn in d.items(): 25 | instance.register(fn, name=key) 26 | return instance 27 | 28 | def register(self, fn, *, name=None): 29 | self._registry[name or fn.__name__] = fn 30 | # Clear caches: 31 | self._graph = None 32 | self._reversed = None 33 | 34 | def _make_time_logger(self, awaitable): 35 | async def inner(): 36 | start = time.perf_counter() 37 | result = await awaitable 38 | end = time.perf_counter() 39 | self.timer(awaitable.__name__, start, end) 40 | return result 41 | 42 | return inner() 43 | 44 | @property 45 | def graph(self): 46 | if self._graph is None: 47 | self._graph = { 48 | key: set(inspect.signature(fn).parameters.keys()) 49 | for key, fn in self._registry.items() 50 | } 51 | return self._graph 52 | 53 | @property 54 | def reversed(self): 55 | if self._reversed is None: 56 | self._reversed = dict(reversed(pair) for pair in self._registry.items()) 57 | return self._reversed 58 | 59 | async def resolve(self, fn, 
**kwargs): 60 | if not isinstance(fn, str): 61 | # It's a fn - is it a registered one? 62 | name = self.reversed.get(fn) 63 | if name is None: 64 | # Special case - since it is not registered we need to 65 | # introspect its parameters here and use resolve_multi 66 | params = inspect.signature(fn).parameters.keys() 67 | to_resolve = {p for p in params if p not in kwargs} 68 | resolved = await self.resolve_multi(to_resolve, results=kwargs) 69 | result = fn(**{param: resolved[param] for param in params}) 70 | if asyncio.iscoroutine(result): 71 | result = await result 72 | return result 73 | else: 74 | name = fn 75 | results = await self.resolve_multi([name], results=kwargs) 76 | return results[name] 77 | 78 | def _plan(self, names, results=None): 79 | if results is None: 80 | results = {} 81 | 82 | ts = graphlib.TopologicalSorter() 83 | to_do = set(names) 84 | done = set(results.keys()) 85 | while to_do: 86 | item = to_do.pop() 87 | dependencies = self.graph.get(item) or set() 88 | ts.add(item, *dependencies) 89 | done.add(item) 90 | # Add any not-done dependencies to the queue 91 | to_do.update({k for k in dependencies if k not in done}) 92 | 93 | return ts 94 | 95 | def _get_awaitable(self, name, results): 96 | fn = self._registry[name] 97 | kwargs = {k: v for k, v in results.items() if k in self.graph[name]} 98 | 99 | awaitable_fn = fn 100 | 101 | if not inspect.iscoroutinefunction(fn): 102 | 103 | async def _awaitable(*args, **kwargs): 104 | return fn(*args, **kwargs) 105 | 106 | _awaitable.__name__ = fn.__name__ 107 | awaitable_fn = _awaitable 108 | 109 | aw = awaitable_fn(**kwargs) 110 | if self.timer: 111 | aw = self._make_time_logger(aw) 112 | return aw 113 | 114 | async def _execute_sequential(self, results, ts): 115 | for name in ts.static_order(): 116 | if name not in self._registry: 117 | continue 118 | results[name] = await self._get_awaitable(name, results) 119 | 120 | async def _execute_parallel(self, results, ts): 121 | ts.prepare() 122 | tasks = [] 
123 | 124 | def schedule(): 125 | for name in ts.get_ready(): 126 | if name not in self._registry: 127 | ts.done(name) 128 | continue 129 | tasks.append(asyncio.create_task(worker(name))) 130 | 131 | async def worker(name): 132 | res = await self._get_awaitable(name, results) 133 | results[name] = res 134 | ts.done(name) 135 | schedule() 136 | 137 | schedule() 138 | while tasks: 139 | await asyncio.gather(*[tasks.pop() for _ in range(len(tasks))]) 140 | 141 | async def resolve_multi(self, names, results=None): 142 | if results is None: 143 | results = {} 144 | 145 | ts = self._plan(names, results) 146 | 147 | if self.parallel: 148 | await self._execute_parallel(results, ts) 149 | else: 150 | await self._execute_sequential(results, ts) 151 | 152 | return results 153 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # asyncinject 2 | 3 | [![PyPI](https://img.shields.io/pypi/v/asyncinject.svg)](https://pypi.org/project/asyncinject/) 4 | [![Changelog](https://img.shields.io/github/v/release/simonw/asyncinject?include_prereleases&label=changelog)](https://github.com/simonw/asyncinject/releases) 5 | [![License](https://img.shields.io/badge/license-Apache%202.0-blue.svg)](https://github.com/simonw/asyncinject/blob/main/LICENSE) 6 | 7 | Run async workflows using pytest-fixtures-style dependency injection 8 | 9 | ## Installation 10 | 11 | Install this library using `pip`: 12 | 13 | $ pip install asyncinject 14 | 15 | ## Usage 16 | 17 | This library is inspired by [pytest fixtures](https://docs.pytest.org/en/6.2.x/fixture.html). 18 | 19 | The idea is to simplify executing parallel `asyncio` operations by allowing them to be defined using a collection of functions, where the function arguments represent dependent functions that need to be executed first. 
20 | 21 | The library can then create and execute a plan for executing the required functions in parallel in the most efficient sequence possible. 22 | 23 | Here's an example, using the [httpx](https://www.python-httpx.org/) HTTP library. 24 | 25 | ```python 26 | from asyncinject import Registry 27 | import httpx 28 | 29 | 30 | async def get(url): 31 | async with httpx.AsyncClient() as client: 32 | return (await client.get(url)).text 33 | 34 | async def example(): 35 | return await get("http://www.example.com/") 36 | 37 | async def simonwillison(): 38 | return await get("https://simonwillison.net/search/?tag=empty") 39 | 40 | async def both(example, simonwillison): 41 | return example + "\n\n" + simonwillison 42 | 43 | registry = Registry(example, simonwillison, both) 44 | combined = await registry.resolve(both) 45 | print(combined) 46 | ``` 47 | If you run this in `ipython` or `python -m asyncio` (to enable top-level await in the console) you will see output that combines HTML from both of those pages. 48 | 49 | The HTTP requests to `www.example.com` and `simonwillison.net` will be performed in parallel. 50 | 51 | The library notices that `both()` takes two arguments which are the names of other registered `async def` functions, and will construct an execution plan that executes those two functions in parallel, then passes their results to the `both()` function. 52 | 53 | ### Registry.from_dict() 54 | 55 | Passing functions as positional arguments to the `Registry` constructor registers each one under its own function name, as introspected using `fn.__name__`. 56 | 57 | You can set explicit names instead using a dictionary: 58 | 59 | ```python 60 | registry = Registry.from_dict({ 61 | "example": example, 62 | "simonwillison": simonwillison, 63 | "both": both 64 | }) 65 | ``` 66 | Those string names will be used to match parameters, so each function will need to accept parameters named after the keys used in that dictionary.
67 | 68 | ### Registering additional functions 69 | 70 | Functions that are registered can be regular functions or `async def` functions. 71 | 72 | In addition to registering functions by passing them to the constructor, you can also add them to a registry using the `.register()` method: 73 | 74 | ```python 75 | async def another(): 76 | return "another" 77 | 78 | registry.register(another) 79 | ``` 80 | To register them with a name other than the name of the function, pass the `name=` argument: 81 | ```python 82 | async def another(): 83 | return "another 2" 84 | 85 | registry.register(another, name="another_2") 86 | ``` 87 | 88 | ### Resolving an unregistered function 89 | 90 | You don't need to register the final function that you pass to `.resolve()` - if you pass an unregistered function, the library will introspect the function's parameters and resolve them directly. 91 | 92 | This works with both regular and async functions: 93 | 94 | ```python 95 | async def one(): 96 | return 1 97 | 98 | async def two(): 99 | return 2 100 | 101 | registry = Registry(one, two) 102 | 103 | # async def works here too: 104 | def three(one, two): 105 | return one + two 106 | 107 | print(await registry.resolve(three)) 108 | # Prints 3 109 | ``` 110 | 111 | ### Parameters are passed through 112 | 113 | Your dependent functions can require keyword arguments which have been passed to the `.resolve()` call: 114 | 115 | ```python 116 | async def get_param_1(param1): 117 | return await get(param1) 118 | 119 | async def get_param_2(param2): 120 | return await get(param2) 121 | 122 | async def both(get_param_1, get_param_2): 123 | return get_param_1 + "\n\n" + get_param_2 124 | 125 | 126 | combined = await Registry(get_param_1, get_param_2, both).resolve( 127 | both, 128 | param1 = "http://www.example.com/", 129 | param2 = "https://simonwillison.net/search/?tag=empty" 130 | ) 131 | print(combined) 132 | ``` 133 | ### Parameters with default values are ignored 134 | 135 | You can opt a 
parameter out of the dependency injection mechanism by assigning it a default value: 136 | 137 | ```python 138 | async def go(calc1, x=5): 139 | return calc1 + x 140 | 141 | async def calc1(): 142 | return 5 143 | 144 | print(await Registry(calc1, go).resolve(go)) 145 | # Prints 10 146 | ``` 147 | 148 | ### Tracking with a timer 149 | 150 | You can pass a `timer=` callable to the `Registry` constructor to gather timing information about executed tasks. The callable should take three positional arguments: 151 | 152 | - `name` - the name of the function that is being timed 153 | - `start` - the time that it started executing, using `time.perf_counter()` ([perf_counter() docs](https://docs.python.org/3/library/time.html#time.perf_counter)) 154 | - `end` - the time that it finished executing 155 | 156 | You can use `print` here too: 157 | 158 | ```python 159 | combined = await Registry( 160 | get_param_1, get_param_2, both, timer=print 161 | ).resolve( 162 | both, 163 | param1 = "http://www.example.com/", 164 | param2 = "https://simonwillison.net/search/?tag=empty" 165 | ) 166 | ``` 167 | This will output: 168 | ``` 169 | get_param_1 436633.584580685 436633.797921747 170 | get_param_2 436633.641832699 436634.196364347 171 | both 436634.196570217 436634.196575639 172 | ``` 173 | ### Turning off parallel execution 174 | 175 | By default, functions that can run in parallel according to the execution plan will run in parallel using `asyncio.gather()`. 176 | 177 | You can disable this parallel execution by passing `parallel=False` to the `Registry` constructor, or by setting `registry.parallel = False` after the registry object has been created. 178 | 179 | This is mainly useful for benchmarking the difference between parallel and serial execution for your project. 180 | 181 | ## Development 182 | 183 | To contribute to this library, first check out the code.
Then create a new virtual environment: 184 | 185 | cd asyncinject 186 | python -m venv venv 187 | source venv/bin/activate 188 | 189 | Now install the dependencies and test dependencies: 190 | 191 | pip install -e '.[test]' 192 | 193 | To run the tests: 194 | 195 | pytest 196 | -------------------------------------------------------------------------------- /tests/test_asyncinject.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import pytest 3 | from asyncinject import Registry 4 | from random import random 5 | import time 6 | 7 | 8 | @pytest.fixture 9 | def complex_registry(): 10 | async def log(): 11 | return [] 12 | 13 | async def d(log): 14 | await asyncio.sleep(0.1 + random() * 0.5) 15 | log.append("d") 16 | 17 | async def c(log): 18 | await asyncio.sleep(0.1 + random() * 0.5) 19 | log.append("c") 20 | 21 | async def b(log, c, d): 22 | log.append("b") 23 | 24 | async def a(log, b, c): 25 | log.append("a") 26 | 27 | async def go(log, a): 28 | log.append("go") 29 | return log 30 | 31 | return Registry(log, d, c, b, a, go) 32 | 33 | 34 | @pytest.mark.asyncio 35 | async def test_complex(complex_registry): 36 | result = await complex_registry.resolve("go") 37 | # 'c' should only be called once 38 | assert tuple(result) in ( 39 | # c and d could happen in either order 40 | ("c", "d", "b", "a", "go"), 41 | ("d", "c", "b", "a", "go"), 42 | ) 43 | 44 | 45 | @pytest.mark.asyncio 46 | async def test_with_parameters(): 47 | async def go(calc1, calc2, param1): 48 | return param1 + calc1 + calc2 49 | 50 | async def calc1(): 51 | return 5 52 | 53 | async def calc2(): 54 | return 6 55 | 56 | registry = Registry(go, calc1, calc2) 57 | result = await registry.resolve(go, param1=4) 58 | assert result == 15 59 | 60 | # Should throw an error if that parameter is missing 61 | with pytest.raises(TypeError) as e: 62 | result = await registry.resolve(go) 63 | assert "go() missing 1 required positional" in e.args[0] 64 | 65 | 
66 | @pytest.mark.asyncio 67 | async def test_parameters_passed_through(): 68 | async def go(calc1, calc2, param1): 69 | return calc1 + calc2 70 | 71 | async def calc1(): 72 | return 5 73 | 74 | async def calc2(param1): 75 | return 6 + param1 76 | 77 | registry = Registry(go, calc1, calc2) 78 | result = await registry.resolve(go, param1=1) 79 | assert result == 12 80 | 81 | 82 | @pytest.mark.asyncio 83 | async def test_ignore_default_parameters(): 84 | async def go(calc1, x=5): 85 | return calc1 + x 86 | 87 | async def calc1(): 88 | return 5 89 | 90 | registry = Registry(go, calc1) 91 | result = await registry.resolve(go) 92 | assert result == 10 93 | 94 | 95 | @pytest.mark.asyncio 96 | async def test_timer(complex_registry): 97 | collected = [] 98 | complex_registry.timer = lambda name, start, end: collected.append( 99 | (name, start, end) 100 | ) 101 | await complex_registry.resolve("go") 102 | assert len(collected) == 6 103 | names = [c[0] for c in collected] 104 | starts = [c[1] for c in collected] 105 | ends = [c[2] for c in collected] 106 | assert all(isinstance(n, float) for n in starts) 107 | assert all(isinstance(n, float) for n in ends) 108 | assert names[0] == "log" 109 | assert names[5] == "go" 110 | assert sorted(names[1:5]) == ["a", "b", "c", "d"] 111 | 112 | 113 | @pytest.mark.asyncio 114 | async def test_parallel(complex_registry): 115 | collected = [] 116 | complex_registry.timer = lambda name, start, end: collected.append( 117 | (name, start, end) 118 | ) 119 | # Run it once in parallel=True mode 120 | await complex_registry.resolve("go") 121 | parallel_timings = {c[0]: (c[1], c[2]) for c in collected} 122 | # 'c' and 'd' should have started within 0.05s 123 | c_start, d_start = parallel_timings["c"][0], parallel_timings["d"][0] 124 | assert abs(c_start - d_start) < 0.05 125 | 126 | # And again in parallel=False mode 127 | collected.clear() 128 | complex_registry.parallel = False 129 | await complex_registry.resolve("go") 130 | serial_timings = 
{c[0]: (c[1], c[2]) for c in collected} 131 | # 'c' and 'd' should have started at least 0.1s apart 132 | c_start_serial, d_start_serial = serial_timings["c"][0], serial_timings["d"][0] 133 | assert abs(c_start_serial - d_start_serial) > 0.1 134 | 135 | 136 | @pytest.mark.asyncio 137 | async def test_optimal_concurrency(): 138 | # https://github.com/simonw/asyncinject/issues/10 139 | async def a(): 140 | await asyncio.sleep(0.1) 141 | 142 | async def b(): 143 | await asyncio.sleep(0.2) 144 | 145 | async def c(a): 146 | await asyncio.sleep(0.1) 147 | 148 | async def d(b, c): 149 | pass 150 | 151 | registry = Registry(a, b, c, d) 152 | start = time.perf_counter() 153 | await registry.resolve(d) 154 | end = time.perf_counter() 155 | # Should have taken ~0.2s 156 | assert 0.18 < (end - start) < 0.22 157 | 158 | 159 | @pytest.mark.asyncio 160 | @pytest.mark.parametrize("use_async", (True, False)) 161 | async def test_resolve_unregistered_function(use_async): 162 | # https://github.com/simonw/asyncinject/issues/13 163 | async def one(): 164 | return 1 165 | 166 | async def two(): 167 | return 2 168 | 169 | registry = Registry(one, two) 170 | 171 | async def three_async(one, two): 172 | return one + two 173 | 174 | def three_not_async(one, two): 175 | return one + two 176 | 177 | fn = three_async if use_async else three_not_async 178 | result = await registry.resolve(fn) 179 | assert result == 3 180 | 181 | # Test that passing parameters works too 182 | result2 = await registry.resolve(fn, one=2) 183 | assert result2 == 4 184 | 185 | 186 | @pytest.mark.asyncio 187 | async def test_register(): 188 | registry = Registry() 189 | 190 | # Mix in a non-async function too: 191 | def one(): 192 | return "one" 193 | 194 | async def two_(): 195 | return "two" 196 | 197 | async def three(one, two): 198 | return one + two 199 | 200 | registry.register(one) 201 | 202 | # Should raise an error if you don't use name= 203 | with pytest.raises(TypeError): 204 | registry.register(two_, 
"two") 205 | 206 | registry.register(two_, name="two") 207 | 208 | result = await registry.resolve(three) 209 | 210 | assert result == "onetwo" 211 | 212 | 213 | @pytest.mark.asyncio 214 | @pytest.mark.parametrize("parallel", (True, False)) 215 | async def test_just_sync_functions(parallel): 216 | def one(): 217 | return 1 218 | 219 | def two(): 220 | return 2 221 | 222 | def three(one, two): 223 | return one + two 224 | 225 | timed = [] 226 | 227 | registry = Registry( 228 | one, two, three, parallel=parallel, timer=lambda *args: timed.append(args) 229 | ) 230 | result = await registry.resolve(three) 231 | assert result == 3 232 | 233 | assert {t[0] for t in timed} == {"two", "one", "three"} 234 | 235 | 236 | @pytest.mark.asyncio 237 | @pytest.mark.parametrize("use_string_name", (True, False)) 238 | async def test_registry_from_dict(use_string_name): 239 | async def _one(): 240 | return 1 241 | 242 | async def _two(): 243 | return 2 244 | 245 | async def _three(one, two): 246 | return one + two 247 | 248 | registry = Registry.from_dict({"one": _one, "two": _two, "three": _three}) 249 | if use_string_name: 250 | result = await registry.resolve("three") 251 | else: 252 | result = await registry.resolve(_three) 253 | assert result == 3 254 | -------------------------------------------------------------------------------- /asyncinject/vendored_graphlib.py: -------------------------------------------------------------------------------- 1 | # Vendored from https://raw.githubusercontent.com/python/cpython/3.10/Lib/graphlib.py 2 | # Modified to work on Python 3.6 (I removed := operator) 3 | # License: https://github.com/python/cpython/blob/main/LICENSE 4 | 5 | __all__ = ["TopologicalSorter", "CycleError"] 6 | 7 | _NODE_OUT = -1 8 | _NODE_DONE = -2 9 | 10 | 11 | class _NodeInfo: 12 | __slots__ = "node", "npredecessors", "successors" 13 | 14 | def __init__(self, node): 15 | # The node this class is augmenting. 
16 | self.node = node 17 | 18 | # Number of predecessors, generally >= 0. When this value falls to 0, 19 | # and is returned by get_ready(), this is set to _NODE_OUT and when the 20 | # node is marked done by a call to done(), set to _NODE_DONE. 21 | self.npredecessors = 0 22 | 23 | # List of successor nodes. The list can contain duplicated elements as 24 | # long as they're all reflected in the successor's npredecessors attribute. 25 | self.successors = [] 26 | 27 | 28 | class CycleError(ValueError): 29 | """Subclass of ValueError raised by TopologicalSorter.prepare if cycles 30 | exist in the working graph. 31 | 32 | If multiple cycles exist, only one undefined choice among them will be reported 33 | and included in the exception. The detected cycle can be accessed via the second 34 | element in the *args* attribute of the exception instance and consists in a list 35 | of nodes, such that each node is, in the graph, an immediate predecessor of the 36 | next node in the list. In the reported list, the first and the last node will be 37 | the same, to make it clear that it is cyclic. 38 | """ 39 | 40 | pass 41 | 42 | 43 | class TopologicalSorter: 44 | """Provides functionality to topologically sort a graph of hashable nodes""" 45 | 46 | def __init__(self, graph=None): 47 | self._node2info = {} 48 | self._ready_nodes = None 49 | self._npassedout = 0 50 | self._nfinished = 0 51 | 52 | if graph is not None: 53 | for node, predecessors in graph.items(): 54 | self.add(node, *predecessors) 55 | 56 | def _get_nodeinfo(self, node): 57 | result = self._node2info.get(node) 58 | if result is None: 59 | self._node2info[node] = result = _NodeInfo(node) 60 | return result 61 | 62 | def add(self, node, *predecessors): 63 | """Add a new node and its predecessors to the graph. 64 | 65 | Both the *node* and all elements in *predecessors* must be hashable. 
66 | 67 | If called multiple times with the same node argument, the set of dependencies 68 | will be the union of all dependencies passed in. 69 | 70 | It is possible to add a node with no dependencies (*predecessors* is not provided) 71 | as well as provide a dependency twice. If a node that has not been provided before 72 | is included among *predecessors* it will be automatically added to the graph with 73 | no predecessors of its own. 74 | 75 | Raises ValueError if called after "prepare". 76 | """ 77 | if self._ready_nodes is not None: 78 | raise ValueError("Nodes cannot be added after a call to prepare()") 79 | 80 | # Create the node -> predecessor edges 81 | nodeinfo = self._get_nodeinfo(node) 82 | nodeinfo.npredecessors += len(predecessors) 83 | 84 | # Create the predecessor -> node edges 85 | for pred in predecessors: 86 | pred_info = self._get_nodeinfo(pred) 87 | pred_info.successors.append(node) 88 | 89 | def prepare(self): 90 | """Mark the graph as finished and check for cycles in the graph. 91 | 92 | If any cycle is detected, "CycleError" will be raised, but "get_ready" can 93 | still be used to obtain as many nodes as possible until cycles block more 94 | progress. After a call to this function, the graph cannot be modified and 95 | therefore no more nodes can be added using "add". 96 | """ 97 | if self._ready_nodes is not None: 98 | raise ValueError("cannot prepare() more than once") 99 | 100 | self._ready_nodes = [ 101 | i.node for i in self._node2info.values() if i.npredecessors == 0 102 | ] 103 | # ready_nodes is set before we look for cycles on purpose: 104 | # if the user wants to catch the CycleError, that's fine, 105 | # they can continue using the instance to grab as many 106 | # nodes as possible before cycles block more progress 107 | cycle = self._find_cycle() 108 | if cycle: 109 | raise CycleError(f"nodes are in a cycle", cycle) 110 | 111 | def get_ready(self): 112 | """Return a tuple of all the nodes that are ready. 
113 | 114 | Initially it returns all nodes with no predecessors; once those are marked 115 | as processed by calling "done", further calls will return all new nodes that 116 | have all their predecessors already processed. Once no more progress can be made, 117 | empty tuples are returned. 118 | 119 | Raises ValueError if called without calling "prepare" previously. 120 | """ 121 | if self._ready_nodes is None: 122 | raise ValueError("prepare() must be called first") 123 | 124 | # Get the nodes that are ready and mark them 125 | result = tuple(self._ready_nodes) 126 | n2i = self._node2info 127 | for node in result: 128 | n2i[node].npredecessors = _NODE_OUT 129 | 130 | # Clean the list of nodes that are ready and update 131 | # the counter of nodes that we have returned. 132 | self._ready_nodes.clear() 133 | self._npassedout += len(result) 134 | 135 | return result 136 | 137 | def is_active(self): 138 | """Return ``True`` if more progress can be made and ``False`` otherwise. 139 | 140 | Progress can be made if cycles do not block the resolution and either there 141 | are still nodes ready that haven't yet been returned by "get_ready" or the 142 | number of nodes marked "done" is less than the number that have been returned 143 | by "get_ready". 144 | 145 | Raises ValueError if called without calling "prepare" previously. 146 | """ 147 | if self._ready_nodes is None: 148 | raise ValueError("prepare() must be called first") 149 | return self._nfinished < self._npassedout or bool(self._ready_nodes) 150 | 151 | def __bool__(self): 152 | return self.is_active() 153 | 154 | def done(self, *nodes): 155 | """Marks a set of nodes returned by "get_ready" as processed. 156 | 157 | This method unblocks any successor of each node in *nodes* for being returned 158 | in the future by a call to "get_ready". 
159 | 160 | Raises :exec:`ValueError` if any node in *nodes* has already been marked as 161 | processed by a previous call to this method, if a node was not added to the 162 | graph by using "add" or if called without calling "prepare" previously or if 163 | node has not yet been returned by "get_ready". 164 | """ 165 | 166 | if self._ready_nodes is None: 167 | raise ValueError("prepare() must be called first") 168 | 169 | n2i = self._node2info 170 | 171 | for node in nodes: 172 | # Check if we know about this node (it was added previously using add() 173 | nodeinfo = n2i.get(node) 174 | if nodeinfo is None: 175 | raise ValueError(f"node {node!r} was not added using add()") 176 | 177 | # If the node has not being returned (marked as ready) previously, inform the user. 178 | stat = nodeinfo.npredecessors 179 | if stat != _NODE_OUT: 180 | if stat >= 0: 181 | raise ValueError( 182 | f"node {node!r} was not passed out (still not ready)" 183 | ) 184 | elif stat == _NODE_DONE: 185 | raise ValueError(f"node {node!r} was already marked done") 186 | else: 187 | assert False, f"node {node!r}: unknown status {stat}" 188 | 189 | # Mark the node as processed 190 | nodeinfo.npredecessors = _NODE_DONE 191 | 192 | # Go to all the successors and reduce the number of predecessors, collecting all the ones 193 | # that are ready to be returned in the next get_ready() call. 194 | for successor in nodeinfo.successors: 195 | successor_info = n2i[successor] 196 | successor_info.npredecessors -= 1 197 | if successor_info.npredecessors == 0: 198 | self._ready_nodes.append(successor) 199 | self._nfinished += 1 200 | 201 | def _find_cycle(self): 202 | n2i = self._node2info 203 | stack = [] 204 | itstack = [] 205 | seen = set() 206 | node2stacki = {} 207 | 208 | for node in n2i: 209 | if node in seen: 210 | continue 211 | 212 | while True: 213 | if node in seen: 214 | # If we have seen already the node and is in the 215 | # current stack we have found a cycle. 
216 | if node in node2stacki: 217 | return stack[node2stacki[node] :] + [node] 218 | # else go on to get next successor 219 | else: 220 | seen.add(node) 221 | itstack.append(iter(n2i[node].successors).__next__) 222 | node2stacki[node] = len(stack) 223 | stack.append(node) 224 | 225 | # Backtrack to the topmost stack entry with 226 | # at least another successor. 227 | while stack: 228 | try: 229 | node = itstack[-1]() 230 | break 231 | except StopIteration: 232 | del node2stacki[stack.pop()] 233 | itstack.pop() 234 | else: 235 | break 236 | return None 237 | 238 | def static_order(self): 239 | """Returns an iterable of nodes in a topological order. 240 | 241 | The particular order that is returned may depend on the specific 242 | order in which the items were inserted in the graph. 243 | 244 | Using this method does not require to call "prepare" or "done". If any 245 | cycle is detected, :exc:`CycleError` will be raised. 246 | """ 247 | self.prepare() 248 | while self.is_active(): 249 | node_group = self.get_ready() 250 | yield from node_group 251 | self.done(*node_group) 252 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | --------------------------------------------------------------------------------