├── .codespell_ignore ├── .flake8 ├── .gitignore ├── .gitpod.Dockerfile ├── .gitpod.yml ├── .isort.cfg ├── .pre-commit-config.yaml ├── .travis.yml ├── LICENSE.md ├── README.md ├── computation_graph ├── __init__.py ├── base_types.py ├── composers │ ├── __init__.py │ ├── composers_test.py │ ├── condition.py │ ├── condition_test.py │ ├── debug.py │ ├── duplication.py │ ├── lift.py │ ├── logic.py │ ├── memory.py │ └── memory_test.py ├── graph.py ├── graph_runners.py ├── graph_test.py ├── legacy.py ├── run.py ├── signature.py └── trace │ ├── __init__.py │ ├── ascii.py │ ├── graphviz.py │ ├── graphviz_test.py │ ├── mermaid.py │ └── trace_utils.py ├── pyproject.toml ├── pytest.ini └── setup.py /.codespell_ignore: -------------------------------------------------------------------------------- 1 | juxt -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = E501,W503,E203,E231 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 105 | __pypackages__/ 106 | 107 | # Celery stuff 108 | celerybeat-schedule 109 | celerybeat.pid 110 | 111 | # SageMath parsed files 112 | *.sage.py 113 | 114 | # Environments 115 | .env 116 | .venv 117 | env/ 118 | venv/ 119 | ENV/ 120 | env.bak/ 121 | venv.bak/ 122 | 123 | # Spyder project settings 124 | .spyderproject 125 | .spyproject 126 | 127 | # Rope project settings 128 | .ropeproject 129 | 130 | # mkdocs documentation 131 | /site 132 | 133 | # mypy 134 | .mypy_cache/ 135 | .dmypy.json 136 | dmypy.json 137 | 138 | # Pyre type checker 139 | .pyre/ 140 | 141 | # pytype static type analyzer 142 | .pytype/ 143 | 144 | # Cython debug symbols 145 | cython_debug/ 146 | 147 | # PyCharm 148 | # JetBrains specific template is maintainted in a separate JetBrains.gitignore that can 149 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 150 | # and can be added to the global gitignore or merged into this file. For a more nuclear 151 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
152 | #.idea/ -------------------------------------------------------------------------------- /.gitpod.Dockerfile: -------------------------------------------------------------------------------- 1 | FROM gitpod/workspace-full 2 | 3 | USER gitpod 4 | 5 | RUN sudo apt-get install -yq python3-dev graphviz graphviz-dev && \ 6 | pyenv update && \ 7 | pyenv install 3.9.10 && \ 8 | pyenv global 3.9.10 && \ 9 | python -m pip install --no-cache-dir --upgrade pip && \ 10 | echo "alias pip='python -m pip'" >> ~/.bash_aliases 11 | -------------------------------------------------------------------------------- /.gitpod.yml: -------------------------------------------------------------------------------- 1 | tasks: 2 | - before: | 3 | pip install pre-commit 4 | pre-commit install --install-hooks 5 | pip install -e .[test] 6 | image: 7 | file: .gitpod.Dockerfile 8 | -------------------------------------------------------------------------------- /.isort.cfg: -------------------------------------------------------------------------------- 1 | [settings] 2 | float_to_top=true 3 | atomic = true 4 | line_length = 88 5 | multi_line_output = 3 6 | include_trailing_comma = true 7 | known_third_party = gamla,immutables,pygraphviz,pytest,setuptools,termcolor,toposort,typeguard 8 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.1.0 4 | hooks: 5 | - id: check-added-large-files 6 | - id: debug-statements 7 | - id: check-json 8 | - id: pretty-format-json 9 | args: [ "--autofix", "--no-ensure-ascii", "--no-sort-keys" ] 10 | 11 | - repo: https://github.com/asottile/seed-isort-config 12 | rev: v2.2.0 13 | hooks: 14 | - id: seed-isort-config 15 | 16 | - repo: https://github.com/PyCQA/isort 17 | rev: 5.12.0 18 | hooks: 19 | - id: isort 20 | 21 | - repo: 
https://github.com/ambv/black 22 | rev: 22.3.0 23 | hooks: 24 | - id: black 25 | 26 | - repo: https://github.com/pycqa/flake8 27 | rev: 3.9.2 28 | hooks: 29 | - id: flake8 30 | additional_dependencies: 31 | [ 32 | "flake8-assertive", 33 | "flake8-comprehensions", 34 | "flake8-mutable", 35 | "flake8-print", 36 | "flake8-self", 37 | "pep8-naming", 38 | ] 39 | 40 | - repo: https://github.com/pre-commit/mirrors-mypy 41 | rev: v0.942 42 | hooks: 43 | - id: mypy 44 | additional_dependencies: [types-termcolor] 45 | 46 | - repo: https://github.com/hyroai/lint 47 | rev: 96a6998defb11e34d1057cc79616481565f568e3 48 | hooks: 49 | - id: static-analysis 50 | 51 | - repo: https://github.com/codespell-project/codespell 52 | rev: v2.1.0 53 | hooks: 54 | - id: codespell 55 | entry: codespell --ignore-words=.codespell_ignore --quiet-level=4 --check-filenames 56 | exclude: \.(csv|json|txt)$ 57 | 58 | - repo: https://github.com/myint/autoflake 59 | rev: v1.4 60 | hooks: 61 | - id: autoflake 62 | entry: autoflake -i --remove-all-unused-imports 63 | 64 | - repo: https://github.com/alan-turing-institute/CleverCSV-pre-commit 65 | rev: v0.7.5 66 | hooks: 67 | - id: clevercsv-standardize 68 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | python: 3 | - '3.11' 4 | dist: focal 5 | jobs: 6 | include: 7 | - stage: test 8 | name: pytest 9 | install: 10 | - sudo apt update && sudo apt install graphviz graphviz-dev 11 | - pip install .[test] 12 | script: pytest 13 | - stage: test 14 | name: pre-commit 15 | install: pip install pre-commit 16 | script: pre-commit run --all-files 17 | deploy: 18 | provider: pypi 19 | username: __token__ 20 | skip_existing: true 21 | password: 22 | secure: 
DtG3+MJ2ZaZDM+Jre0LJakgV6F4LgMYBFjF7Prb7tqA9GE2xuOZegsjtADrZeegiI5nNMZfS3DAcL46g4+gbH33rqTUAoUnaddnML4YPpc12hnDSml7MYVCjv9UYn8oLlquzjasDBGMstIvUWmypPYLkHEkenZ9uj7n8zDh0trlnFcEDQu3p+yOKZ3hk9ysNi9opA8Spu1hix7PvW9iuXW7dJiJlSvUxxiphkKywJj4WTOWwuXWThmXPD5z3rfo7qMn4QZXdf8DnfW/rlk8vZaH9LXQmBeYhyKqBnvLfgFR0DlksVloPVjhLJCaiti+meLINkY6YbjsTc5Yv++6Ezi/SgeAeLT4m3oj2TjhnIiM6Qt/0eKPgpqZEJV7p0/NOwdz3B2xU18jrgZKBqkkveQnonBaZjBxWcubXiRdmrcqNuMxmQqWYCInEKmK2ngeMIZR3MsI4N3DK93vzMWtk+X4gFvghuc9uhZLfqooxwuMsFNxUVJ86TBZKPKyFjvNmdluZszbCTY38FjonIsu0wsOS81dyHJfGrtMKIfpwp5HwidAeWTPHEZS8kpgGUi9d8SUukqYAdJPKifV720biT73pH3I92NQi7iIpnJvuiYh5ramrrqD9CozlNpZW9zD1jT2rExGMXkGF+U8C5tUsKxOpv7EkdsoPg0Ch5ac8zkg= 23 | cache: 24 | directories: 25 | - $HOME/.cache/pre-commit -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 hyroai 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Build Status](https://travis-ci.com/hyroai/computation-graph.svg?branch=master)](https://travis-ci.com/hyroai/computation-graph) 2 | 3 | A function composition framework that supports: 4 | 5 | 1. State - functions which retain state for their next turn of action. 6 | 2. Prioritized paths - lazily attempt overloaded composition paths according to priorities. 7 | 3. Deep dependency injection - compose a function to a variadic function at the end of an arbitrarily long pipeline. 8 | 4. Non cancerous `asyncio` support. 9 | 10 | `pip install computation-graph` 11 | 12 | To deploy: `python setup.py sdist bdist_wheel; twine upload dist/*; rm -rf dist/;` 13 | 14 | ### Type checking 15 | 16 | The runner will type check all outputs for nodes with return type annotations. In case of a wrong typing, it will log the node at fault. 17 | 18 | ### Debugging 19 | 20 | #### Computation trace 21 | 22 | Available computation trace visualizers: 23 | 24 | 1. `graphviz.computation_trace` 25 | 1. `mermaid.computation_trace` 26 | 1. `ascii.computation_trace` 27 | 28 | To use, replace `to_callable` with `run.to_callable_with_side_effect` with your selected style as the first argument. 29 | 30 | #### Graphviz debugger 31 | 32 | This debugger will save a file on each graph execution to current working directory. 33 | 34 | You can use this file in a graph viewer like [gephi](https://gephi.org/). 35 | Nodes colored red are part of the 'winning' computation path. 36 | Each of these nodes has the attributes 'result' and 'state'. 
37 | 'result' is the output of the node, and 'state' is the _new_ state of the node. 38 | 39 | In gephi you can filter for the nodes participating in calculation of final result by filtering on result != null. 40 | -------------------------------------------------------------------------------- /computation_graph/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hyroai/computation-graph/091e1c3b8817ae7f0d4dfb12f1c84d40a655554a/computation_graph/__init__.py -------------------------------------------------------------------------------- /computation_graph/base_types.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import dataclasses 4 | import functools 5 | import os 6 | import typing 7 | from typing import Callable, FrozenSet, Hashable, Optional, Tuple, Union 8 | 9 | import gamla 10 | 11 | COMPUTATION_GRAPH_DEBUG_ENV_KEY = "COMPUTATION_GRAPH_DEBUG" 12 | 13 | Result = Hashable 14 | 15 | 16 | def pretty_print_function_name(f: Callable) -> str: 17 | return f"{f.__code__.co_filename}:{f.__code__.co_firstlineno}:{f.__name__}" 18 | 19 | 20 | _get_unary_input_typing = gamla.compose_left( 21 | typing.get_type_hints, 22 | gamla.when(gamla.inside("return"), gamla.remove_key("return")), 23 | dict.values, 24 | gamla.head, 25 | ) 26 | 27 | 28 | def _mismatch_message(key, source: Callable, destination: Callable) -> str: 29 | return "\n".join( 30 | [ 31 | "", 32 | f"source: {pretty_print_function_name(source)}", 33 | f"destination: {pretty_print_function_name(destination)}", 34 | f"key: {key}", 35 | str(typing.get_type_hints(source)["return"]), 36 | str( 37 | typing.get_type_hints(destination)[key] 38 | if key is not None 39 | else _get_unary_input_typing(destination) 40 | ), 41 | ] 42 | ) 43 | 44 | 45 | class ComputationGraphTypeError(Exception): 46 | pass 47 | 48 | 49 | class SkipComputationError(Exception): 50 | pass 
51 | 52 | 53 | @dataclasses.dataclass(frozen=True) 54 | class ComputationEdge: 55 | destination: ComputationNode 56 | priority: int 57 | key: str 58 | source: Optional[ComputationNode] 59 | args: Tuple[ComputationNode, ...] 60 | is_future: bool 61 | 62 | def __post_init__(self): 63 | assert bool(self.args) != bool( 64 | self.source 65 | ), f"Edge must have a source or args, not both: {self}" 66 | if ( 67 | not self.args 68 | # TODO(uri): doesn't support `functools.partial`, suggested to drop support for it entirely. 69 | and not isinstance(self.source.func, functools.partial) 70 | and not isinstance(self.destination.func, functools.partial) 71 | ): 72 | if not gamla.composable(self.destination.func, self.source.func, self.key): 73 | raise ComputationGraphTypeError( 74 | _mismatch_message(self.key, self.source.func, self.destination.func) 75 | ) 76 | 77 | def __repr__(self): 78 | source_str = ( 79 | "".join(map(str, self.args)) if self.source is None else str(self.source) 80 | ) 81 | line = "...." if self.is_future else "----" 82 | return source_str + line + self.key + line + ">" + str(self.destination) 83 | 84 | 85 | @dataclasses.dataclass(frozen=True) 86 | class NodeSignature: 87 | is_args: bool 88 | kwargs: Tuple[str, ...] 89 | optional_kwargs: Tuple[str, ...] 
90 | is_kwargs: bool 91 | 92 | 93 | @dataclasses.dataclass(frozen=True) 94 | class ComputationNode: 95 | name: str 96 | func: Callable 97 | signature: NodeSignature 98 | is_terminal: bool 99 | computed_hash: int = dataclasses.field(init=False) 100 | 101 | def __post_init__(self): 102 | object.__setattr__(self, "computed_hash", hash(self.func)) 103 | 104 | def __hash__(self): 105 | return self.computed_hash 106 | 107 | def __repr__(self): 108 | return self.name 109 | 110 | 111 | node_implementation = gamla.attrgetter("func") 112 | node_is_terminal = gamla.attrgetter("is_terminal") 113 | 114 | edge_args = gamla.attrgetter("args") 115 | edge_destination = gamla.attrgetter("destination") 116 | edge_key = gamla.attrgetter("key") 117 | edge_priority = gamla.attrgetter("priority") 118 | edge_source = gamla.attrgetter("source") 119 | edge_is_future = gamla.attrgetter("is_future") 120 | 121 | 122 | def edge_sources(edge: ComputationEdge) -> Tuple[ComputationNode, ...]: 123 | return edge.args or (edge.source,) # type: ignore 124 | 125 | 126 | is_computation_graph = gamla.alljuxt( 127 | # Note that this cannot be set to `GraphType` (due to `is_instance` limitation). 
128 | gamla.is_instance((tuple, set, frozenset)), 129 | gamla.len_greater(0), 130 | gamla.allmap(gamla.is_instance(ComputationEdge)), 131 | ) 132 | 133 | 134 | ambiguity_groups = gamla.compose( 135 | frozenset, 136 | gamla.filter(gamla.len_greater(1)), 137 | dict.values, 138 | gamla.keyfilter(gamla.compose_left(gamla.head, gamla.complement(node_is_terminal))), 139 | gamla.groupby(gamla.juxt(edge_destination, edge_key, edge_priority)), 140 | ) 141 | 142 | assert_no_unwanted_ambiguity = gamla.side_effect( 143 | gamla.compose_left( 144 | ambiguity_groups, 145 | gamla.assert_that_with_message( 146 | gamla.wrap_str( 147 | "There are multiple edges with the same destination, key and priority in the computation graph!: {}" 148 | f"\n To get a more relevant stacktrace please set the environment variable with %env {COMPUTATION_GRAPH_DEBUG_ENV_KEY}=1 and rebuild." 149 | ), 150 | gamla.len_equals(0), 151 | ), 152 | ) 153 | ) 154 | 155 | 156 | @gamla.side_effect 157 | def _assert_no_unwanted_ambiguity_when_debug_set(graph): 158 | if os.getenv(COMPUTATION_GRAPH_DEBUG_ENV_KEY) is not None: 159 | assert_no_unwanted_ambiguity(graph) 160 | 161 | 162 | def merge_graphs(*graphs): 163 | s = frozenset({}) 164 | s = s.union(*graphs) 165 | 166 | return _assert_no_unwanted_ambiguity_when_debug_set(s) 167 | 168 | 169 | # We use a tuple to generate a unique id for each node based on the order of edges. 
170 | GraphType = FrozenSet[ComputationEdge] 171 | GraphOrCallable = Union[Callable, GraphType] 172 | CallableOrNode = Union[Callable, ComputationNode] 173 | CallableOrNodeOrGraph = Union[CallableOrNode, GraphType] 174 | NodeOrGraph = Union[ComputationNode, GraphType] 175 | EMPTY_GRAPH: GraphType = frozenset() 176 | -------------------------------------------------------------------------------- /computation_graph/composers/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Any, Callable, Dict, Iterable, Optional, Sequence 4 | 5 | import gamla 6 | 7 | from computation_graph import base_types, graph, signature 8 | 9 | 10 | class _ComputationError: 11 | pass 12 | 13 | 14 | _callable_or_graph_type_to_node_or_graph_type = gamla.unless( 15 | gamla.is_instance((tuple, set, frozenset)), graph.make_computation_node 16 | ) 17 | 18 | 19 | def _get_edges_from_node_or_graph( 20 | node_or_graph: base_types.NodeOrGraph, 21 | ) -> base_types.GraphType: 22 | if isinstance(node_or_graph, base_types.ComputationNode): 23 | return base_types.EMPTY_GRAPH 24 | return node_or_graph 25 | 26 | 27 | @gamla.curry 28 | def make_optional( 29 | func: base_types.CallableOrNodeOrGraph, default_value: Any 30 | ) -> base_types.GraphType: 31 | return make_first(func, lambda: default_value) 32 | 33 | 34 | def make_and( 35 | funcs: Iterable[base_types.CallableOrNodeOrGraph], 36 | merge_fn: base_types.CallableOrNodeOrGraph, 37 | ) -> base_types.GraphType: 38 | """Aggregate funcs' output into `merge_fn`. 39 | * merge_fn should have only 1 argument named `args`. 40 | * All funcs must not raise an exception in order for merge_fn to run. 41 | >>>make_and(composers.make_and([gamla.just(1), gamla.just(2), gamla.just(3)], lambda args: sum(args))) 42 | (justjustjust----*args---->args_to_tuple, args_to_tuple----args---->) 43 | Will return 6. 
44 | """ 45 | 46 | def args_to_tuple(*args): 47 | return args 48 | 49 | merge_node = graph.make_computation_node(args_to_tuple) 50 | 51 | return gamla.sync.pipe( 52 | funcs, 53 | gamla.sync.map(_callable_or_graph_type_to_node_or_graph_type), 54 | tuple, 55 | gamla.sync.juxtcat( 56 | gamla.sync.map(_get_edges_from_node_or_graph), 57 | gamla.sync.compose_left( 58 | gamla.sync.map(_infer_sink), 59 | tuple, 60 | lambda nodes: ( 61 | ( 62 | base_types.ComputationEdge( 63 | is_future=False, 64 | priority=0, 65 | source=None, 66 | args=nodes, 67 | destination=merge_node, 68 | key="*args", 69 | ), 70 | ), 71 | make_compose(merge_fn, merge_node, key="args"), 72 | ), 73 | ), 74 | ), 75 | gamla.sync.star(base_types.merge_graphs), 76 | ) 77 | 78 | 79 | def make_or( 80 | funcs: Sequence[base_types.CallableOrNodeOrGraph], 81 | merge_fn: base_types.CallableOrNodeOrGraph, 82 | ) -> base_types.GraphType: 83 | """Aggregate funcs' output into `merge_fn`. 84 | * merge_fn should have only 1 argument named `args`. 85 | * merge_fn will run anyway, with the outputs of the funcs that didn't raise an exception. 
86 | """ 87 | 88 | def filter_computation_errors(*args): 89 | return gamla.pipe( 90 | args, gamla.remove(gamla.is_instance(_ComputationError)), tuple 91 | ) 92 | 93 | filter_node = graph.make_computation_node(filter_computation_errors) 94 | 95 | return gamla.sync.pipe( 96 | funcs, 97 | gamla.sync.map(make_optional(default_value=_ComputationError())), 98 | tuple, 99 | gamla.sync.pair_right( 100 | gamla.sync.compose_left( 101 | gamla.sync.map(_infer_sink), 102 | tuple, 103 | lambda sinks: ( 104 | ( 105 | base_types.ComputationEdge( 106 | is_future=False, 107 | priority=0, 108 | source=None, 109 | args=sinks, 110 | destination=filter_node, 111 | key="*args", 112 | ), 113 | ), 114 | make_compose(merge_fn, filter_node, key="args"), 115 | ), 116 | ) 117 | ), 118 | gamla.concat, 119 | gamla.sync.star(base_types.merge_graphs), 120 | ) 121 | 122 | 123 | _destinations = gamla.compose(set, gamla.map(base_types.edge_destination)) 124 | 125 | 126 | def _infer_sink(graph_or_node: base_types.NodeOrGraph) -> base_types.ComputationNode: 127 | if isinstance(graph_or_node, base_types.ComputationNode): 128 | return graph_or_node 129 | graph_without_future_edges = graph.remove_future_edges(graph_or_node) 130 | if graph_without_future_edges: 131 | try: 132 | return graph.sink_excluding_terminals(graph_without_future_edges) 133 | except AssertionError: 134 | # If we reached here we can try again without sources of future edges. 
135 | sources_of_future_edges = gamla.sync.pipe( 136 | graph_or_node, 137 | gamla.sync.filter(base_types.edge_is_future), 138 | gamla.sync.map(base_types.edge_source), 139 | frozenset, 140 | ) 141 | result = ( 142 | graph.get_leaves(graph_without_future_edges) - sources_of_future_edges 143 | ) 144 | assert len(result) == 1 145 | return gamla.head(result) 146 | 147 | assert len(_destinations(graph_or_node)) == 1, graph_or_node 148 | return base_types.edge_destination(gamla.head(graph_or_node)) 149 | 150 | 151 | def make_first(*graphs: base_types.CallableOrNodeOrGraph) -> base_types.GraphType: 152 | """Returns a graph that when run, returns the first value that doesn't raise an exception. 153 | >>> def raise_some_exception(): 154 | ... raise Exception 155 | ... make_first(raise_some_exception, gamla.just(1)) 156 | (raise_exception----constituent_of_first---->first_sink, just----constituent_of_first---->first_sink) 157 | Will return 1. 158 | """ 159 | graph_or_nodes = tuple(map(_callable_or_graph_type_to_node_or_graph_type, graphs)) 160 | 161 | def first_sink(constituent_of_first): 162 | return constituent_of_first 163 | 164 | return base_types.merge_graphs( 165 | *map(_get_edges_from_node_or_graph, graph_or_nodes), 166 | gamla.pipe( 167 | graph_or_nodes, 168 | gamla.map(_infer_sink), 169 | enumerate, 170 | gamla.map( 171 | gamla.star( 172 | lambda i, g: base_types.ComputationEdge( 173 | source=g, 174 | destination=graph.make_computation_node(first_sink), 175 | key="constituent_of_first", 176 | args=(), 177 | priority=i, 178 | is_future=False, 179 | ) 180 | ) 181 | ), 182 | tuple, 183 | ), 184 | ) 185 | 186 | 187 | def last(*args) -> base_types.GraphType: 188 | return make_first(*reversed(args)) 189 | 190 | 191 | @gamla.curry 192 | def _try_connect( 193 | source: base_types.ComputationNode, 194 | key: Optional[str], 195 | priority: int, 196 | is_future: bool, 197 | destination: base_types.ComputationNode, 198 | unbound_destination_signature: base_types.NodeSignature, 
199 | ) -> base_types.ComputationEdge: 200 | if key is None and signature.is_unary(unbound_destination_signature): 201 | key = gamla.head(signature.parameters(unbound_destination_signature)) 202 | assert key is not None and key in signature.parameters( 203 | unbound_destination_signature 204 | ), f"Expecting a graph with key `{key}` but got `{destination}`" 205 | return base_types.ComputationEdge( 206 | source=source, 207 | destination=destination, 208 | key=key, 209 | args=(), 210 | priority=priority, 211 | is_future=is_future, 212 | ) 213 | 214 | 215 | @gamla.curry 216 | def _infer_composition_edges( 217 | priority: int, 218 | key: Optional[str], 219 | is_future: bool, 220 | source: base_types.NodeOrGraph, 221 | destination: base_types.NodeOrGraph, 222 | ) -> base_types.GraphType: 223 | try_connect = _try_connect(_infer_sink(source), key, priority, is_future) 224 | 225 | if isinstance(destination, base_types.ComputationNode): 226 | return base_types.merge_graphs( 227 | (try_connect(destination, destination.signature),), 228 | _get_edges_from_node_or_graph(source), 229 | ) 230 | 231 | unbound_signature = graph.unbound_signature( 232 | graph.get_incoming_edges_for_node(destination) 233 | ) 234 | return base_types.merge_graphs( 235 | gamla.sync.pipe( 236 | destination, 237 | gamla.sync.mapcat(graph.get_edge_nodes), 238 | gamla.unique, 239 | gamla.sync.filter( 240 | gamla.sync.compose_left( 241 | unbound_signature, 242 | lambda sig: key in sig.kwargs 243 | or (key is None and signature.is_unary(sig)), 244 | ) 245 | ), 246 | # Do not add edges to nodes from source that are already present in destination (cycle). 
247 | gamla.sync.filter( 248 | lambda node: isinstance(source, base_types.ComputationNode) 249 | or node not in graph.get_all_nodes(source) 250 | ), 251 | gamla.sync.map( 252 | lambda destination: try_connect( 253 | destination=destination, 254 | unbound_destination_signature=unbound_signature(destination), 255 | ) 256 | ), 257 | tuple, 258 | gamla.sync.check( 259 | gamla.identity, 260 | AssertionError( 261 | f"Cannot compose, destination signature does not contain key '{key}'" 262 | ), 263 | ), 264 | ), 265 | destination, 266 | _get_edges_from_node_or_graph(source), 267 | ) 268 | 269 | 270 | def _make_compose_inner( 271 | *funcs: base_types.CallableOrNodeOrGraph, 272 | key: Optional[str], 273 | is_future, 274 | priority: int, 275 | ) -> base_types.GraphType: 276 | assert ( 277 | len(funcs) > 1 278 | ), f"Only {len(funcs)} function passed to compose, need at least 2, funcs={funcs}" 279 | return gamla.sync.pipe( 280 | funcs, 281 | reversed, 282 | gamla.sync.map(_callable_or_graph_type_to_node_or_graph_type), 283 | gamla.sliding_window(2), 284 | gamla.sync.map(gamla.star(_infer_composition_edges(priority, key, is_future))), 285 | gamla.sync.star(base_types.merge_graphs), 286 | ) 287 | 288 | 289 | def make_compose( 290 | *funcs: base_types.CallableOrNodeOrGraph, key: Optional[str] = None 291 | ) -> base_types.GraphType: 292 | return _make_compose_inner(*funcs, key=key, is_future=False, priority=0) 293 | 294 | 295 | def compose_unary(*funcs: base_types.CallableOrNodeOrGraph) -> base_types.GraphType: 296 | """Returns a graph of the funcs composed from right to left. All functions need to be unary (have 1 argument). 
297 | >>> compose_unary(gamla.add(1), gamla.multiply(2), gamla.divide_by(2)) 298 | (divide_by----y---->multiply, multiply----y---->add) 299 | """ 300 | return _make_compose_inner(*funcs, key=None, is_future=False, priority=0) 301 | 302 | 303 | def make_compose_future( 304 | destination: base_types.CallableOrNodeOrGraph, 305 | source: base_types.CallableOrNodeOrGraph, 306 | key: Optional[str], 307 | default: base_types.Result, 308 | ) -> base_types.GraphType: 309 | def when_memory_unavailable(): 310 | return default 311 | 312 | return base_types.merge_graphs( 313 | _make_compose_inner(destination, source, key=key, is_future=True, priority=0), 314 | _make_compose_inner( 315 | destination, when_memory_unavailable, key=key, is_future=False, priority=1 316 | ), 317 | ) 318 | 319 | 320 | def compose_unary_future( 321 | destination: base_types.CallableOrNodeOrGraph, 322 | source: base_types.CallableOrNodeOrGraph, 323 | default: base_types.Result, 324 | ) -> base_types.GraphType: 325 | return make_compose_future(destination, source, None, default) 326 | 327 | 328 | def compose_source( 329 | destination: base_types.CallableOrNodeOrGraph, 330 | key: str, 331 | source: base_types.CallableOrNodeOrGraph, 332 | ) -> base_types.GraphType: 333 | return _make_compose_inner(destination, source, key=key, is_future=True, priority=0) 334 | 335 | 336 | @gamla.curry 337 | def compose_left_source( 338 | source: base_types.CallableOrNodeOrGraph, 339 | key: str, 340 | destination: base_types.CallableOrNodeOrGraph, 341 | ): 342 | return compose_source(destination, key, source) 343 | 344 | 345 | def compose_source_unary( 346 | destination: base_types.CallableOrNodeOrGraph, 347 | source: base_types.CallableOrNodeOrGraph, 348 | ) -> base_types.GraphType: 349 | return _make_compose_inner( 350 | destination, source, key=None, is_future=True, priority=0 351 | ) 352 | 353 | 354 | def compose_left(*args, key: Optional[str] = None) -> base_types.GraphType: 355 | """Compose a function onto another 
def make_raise_exception(exception):
    """Return a nullary function that raises `exception` when called.

    NOTE: the inner function's `__name__` ("inner") may surface as a node
    name if the result is composed into a graph.
    """

    def inner():
        raise exception

    return inner
@gamla.curry
def aggregation(
    aggregation: Callable[[Iterable], Any], graphs: Iterable[base_types.GraphOrCallable]
) -> base_types.GraphType:
    """Same as `compose_many_to_one`, but takes care to duplicate `aggregation`, and allows it to have any arg name."""
    # NOTE: the parameter deliberately shadows the function name; only the
    # parameter is referenced below.
    return make_and(
        graphs,
        # It is important that `aggregation` is duplicated here.
        # If it weren't for the `compose_left` we would need to do it explicitly.
        # (`gamla.compose_left` wraps it in a fresh callable, so the same
        # aggregation can be reused across several graphs.)
        gamla.compose_left(lambda args: args, aggregation),
    )
def when(
    condition: base_types.GraphOrCallable,
    transformation: base_types.GraphOrCallable,
    source: base_types.GraphOrCallable,
) -> base_types.GraphType:
    """If `condition`'s output is truthy, emit `transformation` applied to
    `source`'s output; otherwise pass `source`'s output through unchanged."""
    transformed = composers.compose_unary(transformation, source)
    return if_then_else(condition, transformed, source)
def if_then_else(
    condition: base_types.GraphOrCallable,
    if_truthy: base_types.GraphOrCallable,
    if_falsy: base_types.GraphOrCallable,
) -> base_types.GraphType:
    """Select between the outputs of `if_truthy` and `if_falsy` according to
    the truthiness of `condition`'s output."""

    def if_then_else(condition_value, true_value, false_value):
        return true_value if condition_value else false_value

    return composers.compose_dict(
        if_then_else,
        {
            "condition_value": condition,
            "true_value": if_truthy,
            "false_value": if_falsy,
        },
    )
return x 16 | 17 | return composers.compose_unary(d, f) 18 | 19 | return debug 20 | 21 | 22 | def _debug_inner(x, frame, *optional_message): 23 | logging.info( 24 | f"Debug prompt for {frame.f_code.co_filename}:{frame.f_lineno}. Hit x+enter to see current value." 25 | ) 26 | builtins.breakpoint() 27 | 28 | 29 | #: Makes a pdb breakpoint with the node output (prints the line number!). 30 | debug = _debug_with_frame(_debug_inner) 31 | 32 | 33 | def _debug_log_inner(x, frame, *optional_message: str): 34 | if not optional_message: 35 | logging.info(f"{frame.f_code.co_filename}:{frame.f_lineno} output: {x}") 36 | logging.info( 37 | f"{optional_message[0]} at {frame.f_code.co_filename}:{frame.f_lineno} output: {x}" 38 | ) 39 | 40 | 41 | #: Prints a debug log with the node output (with a line number!). 42 | debug_log = _debug_with_frame(_debug_log_inner) 43 | 44 | 45 | def name_callable(f: Callable, name: str) -> Callable: 46 | f.__code__ = f.__code__.replace(co_name=name) 47 | f.__name__ = name 48 | return f 49 | -------------------------------------------------------------------------------- /computation_graph/composers/duplication.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import functools 3 | import inspect 4 | from typing import Dict 5 | 6 | import gamla 7 | 8 | from computation_graph import base_types, graph 9 | from computation_graph.composers import debug 10 | 11 | 12 | def duplicate_function(func): 13 | if asyncio.iscoroutinefunction(func): 14 | 15 | @functools.wraps(func) 16 | async def inner(*args, **kwargs): 17 | return await func(*args, **kwargs) 18 | 19 | else: 20 | 21 | @functools.wraps(func) 22 | def inner(*args, **kwargs): 23 | return func(*args, **kwargs) 24 | 25 | return debug.name_callable(inner, f"duplicate of {func.__name__}") 26 | 27 | 28 | def _duplicate_computation_edge(get_duplicated_node): 29 | return gamla.compose_left( 30 | gamla.dataclass_transform("source", 
#: Copy a graph: every non-terminal node's implementation is duplicated
#: (via `_node_to_duplicated_node`), and each edge's source/destination/args
#: are remapped to the duplicates. Terminal nodes are left as-is.
duplicate_graph = gamla.compose_left(
    # Pair the graph with a node->duplicate getter built from all its nodes.
    gamla.pair_with(gamla.compose_left(graph.get_all_nodes, _node_to_duplicated_node)),
    gamla.star(
        lambda get_duplicated_node, graph: gamla.pipe(
            graph,
            gamla.map(
                _duplicate_computation_edge(
                    # Getter returns None for unmapped (terminal) nodes;
                    # `gamla.when` then keeps the original node.
                    gamla.when(get_duplicated_node, get_duplicated_node)
                )
            ),
            tuple,
        )
    ),
)
def always(x):
    """Lift the constant `x` into a nullary function returning it."""

    def constant():
        return x

    # The node name embeds (a prefix of) the value for readable graphs;
    # `__name__` is assigned explicitly, so the inner def's name is irrelevant.
    constant.__name__ = f"always {x!s:.30}"
    return constant
def complement(function: base_types.GraphOrCallable) -> base_types.GraphType:
    """Graph computing the boolean negation of `function`'s output."""

    # Inner def keeps the name `complement` so the graph node is named after
    # the operation.
    def complement(value) -> bool:
        if value:
            return False
        return True

    return composers.compose_unary(complement, function)
def accumulate(f: base_types.GraphOrCallable) -> base_types.GraphType:
    """Accumulate `f`'s results into a `tuple`."""

    @with_state("state", None)
    def accumulate(state, x):
        # `state` is the tuple from the previous turn (None on the very
        # first), threaded back by `with_state`'s future self-edge.
        return *(state or ()), x

    return composers.make_compose(accumulate, f, key="x")
def make_computation_node(
    func: base_types.CallableOrNode,
) -> base_types.ComputationNode:
    """Wrap a callable in a `ComputationNode`; idempotent on existing nodes."""
    if isinstance(func, base_types.ComputationNode):
        return func

    return base_types.ComputationNode(
        name=signature.name(func),
        func=func,
        signature=gamla.pipe(
            func,
            signature.from_callable,
            # Fail fast (with the callable's repr) on unsupported signatures.
            gamla.assert_that_with_message(
                gamla.just(str(func)), signature.is_supported
            ),
        ),
        is_terminal=False,
    )
def sink_excluding_terminals(edges: base_types.GraphType) -> base_types.ComputationNode:
    """Return the unique non-terminal leaf of `edges`; assert exactly one exists."""
    leaves_not_terminal = frozenset(
        node for node in get_leaves(edges) if not node.is_terminal
    )
    assert (
        len(leaves_not_terminal) == 1
    ), f"Expected exactly one non-terminal sink, got {leaves_not_terminal}"
    return gamla.head(leaves_not_terminal)
@gamla.curry
def replace_source(
    original: base_types.CallableOrNode,
    replacement: base_types.CallableOrNodeOrGraph,
    current_graph: base_types.GraphType,
) -> base_types.GraphType:
    """Replace `original` wherever it appears as an edge source in `current_graph`.

    `replacement` may be a callable/node (plain substitution) or a whole
    graph, in which case the graph's non-terminal sink becomes the new source
    and the replacement graph's edges are merged in. Returns `current_graph`
    unchanged when `original` is not a node of it.

    Raises:
        RuntimeError: if `replacement` is neither a callable/node nor a graph.
    """
    if make_computation_node(original) not in get_all_nodes(current_graph):
        return current_graph

    if gamla.is_instance(base_types.CallableOrNode)(replacement):
        return _replace_source_in_edges(original, replacement)(current_graph)  # type: ignore

    if base_types.is_computation_graph(replacement):
        return gamla.pipe(
            current_graph,
            _replace_source_in_edges(original, sink_excluding_terminals(replacement)),  # type: ignore
            gamla.concat_with(replacement),
            tuple,
        )

    # Fixed typo in the error message ("relacement" -> "replacement").
    raise RuntimeError(f"Unsupported replacement graph {replacement}")
def edge_destination_equals(
    x: base_types.CallableOrNode,
) -> Callable[[base_types.ComputationEdge], bool]:
    """Predicate factory: does an edge's destination equal `x` (as a node)?"""
    node = make_computation_node(x)

    def edge_destination_equals(edge):
        return node == edge.destination

    return edge_destination_equals
def _infer_graph_sink(edges: base_types.GraphType) -> base_types.ComputationNode:
    """Return the unique leaf (sink) node of `edges`; assert exactly one exists."""
    assert edges, "Empty graphs have no sink."
    leaves = graph.get_leaves(edges)
    assert len(leaves) == 1, f"Cannot determine sink for {edges}, got: {tuple(leaves)}."
    return gamla.head(leaves)
def variadic_bare(g):
    """Compile `g` and return a driver folding it over per-turn inputs,
    returning the final result mapping."""
    compiled = run.to_callable_strict(g)

    def inner(*turns):
        state = {}
        for turn_input in turns:
            state = compiled(state, turn_input)
        return state

    return inner
sink = _infer_graph_sink(g) 146 | 147 | def inner(*expectations): 148 | prev = {} 149 | for expectation in expectations: 150 | prev = f(prev, {}) 151 | assert prev[sink] == expectation 152 | 153 | return inner 154 | -------------------------------------------------------------------------------- /computation_graph/graph_test.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from typing import Dict 3 | 4 | import gamla 5 | import pytest 6 | 7 | from computation_graph import base_types, composers, graph, graph_runners, legacy, run 8 | from computation_graph.composers import duplication, memory 9 | 10 | 11 | def _infer_graph_sink_excluding_terminals( 12 | edges: base_types.GraphType, 13 | ) -> base_types.ComputationNode: 14 | leaves = gamla.pipe( 15 | edges, graph.get_leaves, gamla.remove(base_types.node_is_terminal), tuple 16 | ) 17 | assert len(leaves) == 1, f"computation graph has more than one sink: {leaves}" 18 | return gamla.head(leaves) 19 | 20 | 21 | def _node1(arg1): 22 | return f"node1({arg1})" 23 | 24 | 25 | async def _node1_async(arg1): 26 | await asyncio.sleep(0.1) 27 | return f"node1({arg1})" 28 | 29 | 30 | def _node2(arg1): 31 | return f"node2({arg1})" 32 | 33 | 34 | def _node3(arg1, arg2): 35 | return f"node3(arg1={arg1}, arg2={arg2})" 36 | 37 | 38 | def _node_that_raises(): 39 | raise base_types.SkipComputationError 40 | 41 | 42 | def _merger(args, side_effects): 43 | return "[" + ",".join(args) + f"], side_effects={side_effects}" 44 | 45 | 46 | def _next_int(x): 47 | if x is None: 48 | return 0 49 | return x + 1 50 | 51 | 52 | def _reducer_node(arg1, cur_int): 53 | return arg1 + f" cur_int={cur_int + 1}" 54 | 55 | 56 | def test_simple(): 57 | for v in ["root", None]: 58 | assert ( 59 | graph_runners.unary( 60 | composers.compose_left_unary(_node1, _node2), _node1, _node2 61 | )(v) 62 | == f"node2(node1({v}))" 63 | ) 64 | 65 | 66 | async def test_async_run_as_soon_as_possible(capsys): 67 | # 
async def test_simple_async():
    # Same shape as test_simple, but the first node is a coroutine.
    assert (
        await graph_runners.unary(
            composers.compose_left_unary(_node1_async, _node2), _node1_async, _node2
        )("hi")
        == "node2(node1(hi))"
    )


def test_kwargs():
    # One source fans out to two nodes, whose results are joined by keyword
    # into _node3 via compose_dict.
    def source(x):
        return x

    assert (
        graph_runners.unary(
            base_types.merge_graphs(
                composers.compose_unary(_node1, source),
                composers.compose_unary(_node2, source),
                composers.compose_dict(_node3, {"arg1": _node1, "arg2": _node2}),
            ),
            source,
            _node3,
        )("hi")
        == "node3(arg1=node1(hi), arg2=node2(hi))"
    )


# Legacy-style stateful node: handle_state threads `state` between turns and
# unwraps the ComputationResult for downstream consumers.
@legacy.handle_state("state", None)
def _node_with_state_as_arg(arg1, state):
    if state is None:
        state = 0
    return legacy.ComputationResult(
        result=arg1 + f" state={state + 1}", state=state + 1
    )


def test_state():
    f = graph_runners.unary_with_state(
        base_types.merge_graphs(
            composers.compose_left(_node1, _node_with_state_as_arg, key="arg1"),
            composers.compose_left_unary(_node_with_state_as_arg, _node2),
        ),
        _node1,
        _node2,
    )
    # Three turns => the state counter reaches 3.
    assert f("x", "y", "z") == "node2(node1(z) state=3)"


def test_self_future_edge():
    # _next_int feeds itself through a future edge, acting as a turn counter.
    f = graph_runners.unary_with_state(
        base_types.merge_graphs(
            composers.compose_dict(
                _reducer_node, {"arg1": _node1, "cur_int": _next_int}
            ),
            composers.compose_unary(_node2, _reducer_node),
            composers.compose_left_future(_next_int, _next_int, "x", None),
        ),
        _node1,
        _node2,
    )
    assert f("x", "y", "z") == "node2(node1(z) cur_int=3)"


def test_empty_result():
    # A skipped sink leaves no entry in the results mapping => KeyError.
    def raises(x):
        del x
        raise base_types.SkipComputationError

    with pytest.raises(KeyError):
        graph_runners.nullary(composers.compose_unary(raises, lambda: "hi"), raises)


def test_optional():
    # make_optional turns a skipped node into its default (None here).
    def raises():
        raise base_types.SkipComputationError

    def sink(x):
        return x

    assert (
        graph_runners.nullary(
            composers.compose_unary(sink, composers.make_optional(raises, None)), sink
        )
        is None
    )


def test_optional_with_future_edge():
    def output(x):
        return x

    def input(x):
        return x

    f = graph_runners.unary_with_state(
        base_types.merge_graphs(
            composers.make_compose(_reducer_node, input, key="arg1"),
            composers.compose_unary(
                output, composers.make_optional(_reducer_node, None)
            ),
            composers.make_compose(_reducer_node, _next_int, key="cur_int"),
            composers.compose_left_future(_next_int, _next_int, None, None),
        ),
        input,
        output,
    )
    assert f("x", "y", "z") == "z cur_int=3"


def test_first():
    # make_first picks the first actionable option; the raiser is skipped.
    def raises():
        raise base_types.SkipComputationError

    assert (
        graph_runners.nullary_infer_sink(
            composers.make_first(raises, lambda: 1, lambda: 2)
        )
        == 1
    )


async def test_raise_handled_from_async():
    async def raises():
        raise base_types.SkipComputationError

    assert (
        await graph_runners.nullary_infer_sink(
            composers.make_first(raises, lambda: 1, lambda: 2)
        )
        == 1
    )


async def test_no_result_for_node_that_raised_handled_exception():
    # A skipped node and everything downstream of it are absent from the
    # results mapping, while independent upstream results remain.
    async def raises(x):
        raise base_types.SkipComputationError

    def sink(x):
        return x

    @graph.make_computation_node
    def source():
        return 4

    res = await run.to_callable_strict(composers.compose_left(source, raises, sink))(
        {}, {}
    )

    assert res[source] == 4
    assert raises not in res
    assert sink not in res


def test_raise_unhandled_exception():
    # Exceptions other than SkipComputationError propagate to the caller.
    class MyExceptionError(Exception):
        ...

    def raises():
        raise MyExceptionError("BAD")

    with pytest.raises(MyExceptionError, match="BAD"):
        graph_runners.nullary_infer_sink(composers.make_first(raises, lambda: 1))


async def test_raise_unhandled_exception_async():
    class MyExceptionError(Exception):
        ...

    @composers.compose_left_dict(
        {"x": composers.compose_left_unary(lambda: 1, lambda x: 1)}
    )
    def raises(x):
        raise MyExceptionError("BAD")

    @composers.compose_left_dict({"x": lambda: 1})
    async def not_awaited(x):
        raise base_types.SkipComputationError("DAB")

    with pytest.raises(MyExceptionError, match="BAD"):
        await graph_runners.nullary_infer_sink(
            composers.compose_left_dict({"y": raises, "x": not_awaited}, lambda x, y: 1)
        )
    # No stray tasks left behind: only the test's own task is alive.
    assert len(asyncio.all_tasks()) == 1
async def test_raise_exception_in_sync_after_async():
    # A sync node raising after an async upstream still propagates.
    def raises(x):
        raise TypeError("BAD")

    async def async_source():
        return 4

    with pytest.raises(TypeError, match="BAD"):
        await run.to_callable_strict(
            composers.compose_left(async_source, raises, lambda x: 1)
        )({}, {})


async def test_raise_exception_in_async_after_sync():
    async def raises(x):
        raise TypeError("BAD")

    def source():
        return 4

    with pytest.raises(TypeError, match="BAD"):
        await run.to_callable_strict(
            composers.compose_left(source, raises, lambda x: 1)
        )({}, {})


def test_first_all_unactionable():
    # When every option of make_first is skipped there is no sink result.
    def raises():
        raise base_types.SkipComputationError

    with pytest.raises(KeyError):
        graph_runners.nullary_infer_sink(composers.make_first(raises))


def test_first_with_future_edge():
    def input_node(x):
        return x

    f = graph_runners.unary_with_state(
        base_types.merge_graphs(
            composers.make_compose(_reducer_node, input_node, key="arg1"),
            composers.make_compose(_node1, input_node, key="arg1"),
            composers.make_first(_node_that_raises, _reducer_node, _node1),
            composers.make_compose(_reducer_node, _next_int, key="cur_int"),
            composers.make_compose_future(_next_int, _next_int, "x", None),
        ),
        input_node,
        _reducer_node,
    )
    assert f("x", "y", "z") == "z cur_int=3"


def test_and_with_future():
    # make_and gathers all three nodes' results into _merger's `args`.
    source1 = graph.make_source()
    source2 = graph.make_source()
    g = base_types.merge_graphs(
        composers.make_and((_reducer_node, _node2, _node1), _merger),
        composers.compose_source(_merger, key="side_effects", source=source2),
        composers.compose_source(_node1, key="arg1", source=source1),
        composers.compose_source(_node2, key="arg1", source=source1),
        composers.compose_source(_reducer_node, key="arg1", source=source1),
        composers.make_compose(_reducer_node, _next_int, key="cur_int"),
        composers.compose_unary_future(_next_int, _next_int, None),
    )
    assert (
        graph_runners.variadic_stateful_infer_sink(g)(
            {source1: "root", source2: "bla"},
            {source1: "root", source2: "bla"},
            {source1: "root", source2: "bla"},
        )
        == "[root cur_int=3,node2(root),node1(root)], side_effects=bla"
    )


def test_and_with_unactionable():
    # make_and requires all of its inputs; one skipped node kills the sink.
    source1 = graph.make_source()
    source2 = graph.make_source()
    g = base_types.merge_graphs(
        composers.make_and((_reducer_node, _node_that_raises), _merger),
        composers.compose_source(_merger, key="side_effects", source=source2),
        composers.compose_source(_reducer_node, key="arg1", source=source1),
        composers.make_compose(_reducer_node, _next_int, key="cur_int"),
        composers.compose_unary_future(_next_int, _next_int, None),
    )
    with pytest.raises(KeyError):
        graph_runners.variadic_infer_sink(g)({source1: "root", source2: "bla"})


def test_or():
    # make_or tolerates skipped inputs; only actionable results are merged.
    def merger(args):
        return " ".join(map(str, args))

    f = graph_runners.variadic_stateful_infer_sink(
        base_types.merge_graphs(
            composers.make_or((_next_int, lambda: "node1", _node_that_raises), merger),
            composers.compose_unary_future(_next_int, _next_int, 0),
        )
    )

    assert f({}, {}, {}) == "3 node1"


def test_compose():
    assert (
        graph_runners.nullary_infer_sink(composers.make_compose(_node1, lambda: "hi"))
        == "node1(hi)"
    )


def test_compose_with_future_edge():
    f = graph_runners.unary_with_state(
        base_types.merge_graphs(
            composers.make_compose(_node1, _node2),
            composers.make_compose(_reducer_node, _node1, key="arg1"),
            composers.make_compose(_reducer_node, _next_int, key="cur_int"),
            composers.make_compose_future(_next_int, _next_int, None, None),
        ),
        _node2,
        _reducer_node,
    )
    assert f("hi", "hi", "hi") == "node1(node2(hi)) cur_int=3"


def test_optional_memory_sometimes_raises():
    # The counter keeps advancing via the future edge even on the turn where
    # the optional node is skipped ("fail").
    def sometimes_raises(x, cur_int):
        if x == "fail":
            raise base_types.SkipComputationError
        return x + f" state={cur_int + 1}"

    def input_source(x):
        return x

    f = graph_runners.unary_with_state_infer_sink(
        base_types.merge_graphs(
            composers.make_compose(sometimes_raises, input_source, key="x"),
            composers.make_optional(sometimes_raises, None),
            composers.make_compose(sometimes_raises, _next_int, key="cur_int"),
            composers.compose_unary_future(_next_int, _next_int, None),
        ),
        input_source,
    )
    assert f("hi", "fail", "hi") == "hi state=3"


def test_first_first():
    # Nested make_first: the inner first resolves to node1, the outer picks it.
    def node1():
        return "node1"

    assert (
        graph_runners.nullary_infer_sink(
            composers.make_first(
                _node_that_raises,
                composers.make_first(_node_that_raises, node1, lambda: "node2"),
                node1,
            )
        )
        == "node1"
    )


def test_compose_with_node_already_in_graph():
    # node1 appears both as sink input and inside the make_and; it is
    # computed once and shared.
    def node1():
        return "node1"

    def sink(x, y):
        return f"x={x} y={y}"

    def merger(args):
        return " ".join(args)

    assert (
        graph_runners.nullary_infer_sink(
            composers.make_compose(
                composers.make_compose(sink, node1, key="x"),
                composers.make_and((lambda: "node2", node1), merge_fn=merger),
                key="y",
            )
        )
        == "x=node1 y=node2 node1"
    )
def test_first_with_subgraph_that_raises():
    # A whole option subgraph becomes unactionable when its source raises.
    def raises():
        raise base_types.SkipComputationError

    def node2(x):
        return x

    def node1():
        return "node1"

    assert (
        graph_runners.nullary_infer_sink(
            composers.make_first(composers.compose_unary(node2, raises), node1)
        )
        == "node1"
    )


def test_or_with_sink_that_raises():
    def raises():
        raise base_types.SkipComputationError

    def merge(args):
        if not args:
            raise base_types.SkipComputationError
        return ",".join(args)

    assert (
        graph_runners.nullary_infer_sink(
            composers.make_or((raises, lambda: "node1"), merge_fn=merge)
        )
        == "node1"
    )


def test_unambiguous_composition_using_terminal():
    # Two leaves make an anonymous compose ambiguous (AssertionError); routing
    # one branch into a terminal removes the ambiguity.
    terminal = graph.make_terminal("1", lambda x: x[0])

    def source():
        return 1

    with pytest.raises(AssertionError):
        composers.compose_unary(
            lambda x: x + 1,
            base_types.merge_graphs(
                composers.compose_unary(lambda x: x + 1, source),
                composers.compose_unary(lambda x: x, source),
            ),
        )

    g = composers.compose_unary(
        lambda x: x + 1,
        base_types.merge_graphs(
            composers.compose_unary(lambda x: x + 1, source),
            composers.compose_unary(terminal, source),
        ),
    )
    x = run.to_callable_strict(g)({}, {})
    assert x[terminal] == 1
    assert x[_infer_graph_sink_excluding_terminals(g)] == 3


def test_two_terminals():
    # Terminal results are sequences (hence the [0] indexing below).
    terminal1 = graph.make_terminal("1", lambda x: x)
    terminal2 = graph.make_terminal("2", lambda x: x)
    result = graph_runners.unary_bare(
        base_types.merge_graphs(
            composers.compose_unary(terminal1, composers.make_compose(_node2, _node1)),
            composers.compose_unary(terminal2, _node1),
        ),
        _node1,
    )("hi")
    assert result[terminal1][0] == "node2(node1(hi))"
    assert result[terminal2][0] == "node1(hi)"


def test_two_paths_succeed():
    source = graph.make_source()
    terminal1 = graph.make_terminal("1", lambda x: x)
    terminal2 = graph.make_terminal("2", lambda x: x)
    result = graph_runners.variadic_bare(
        base_types.merge_graphs(
            composers.make_first(
                composers.compose_source_unary(_node1, source),
                composers.compose_unary(
                    terminal1, composers.compose_source_unary(_node2, source)
                ),
            ),
            composers.compose_unary(terminal2, _node1),
        )
    )({source: "hi"})
    assert result[terminal1][0] == "node2(hi)"
    assert result[terminal2][0] == "node1(hi)"


def test_double_star_signature_considered_unary():
    # gamla.juxt produces a **kwargs-style callable; the runner must treat it
    # as unary rather than trying to match argument names.
    sink = gamla.juxt(
        lambda some_argname: some_argname + 1,
        lambda different_argname: different_argname * 2,
    )
    assert graph_runners.nullary(composers.make_compose(sink, lambda: 3), sink) == (
        4,
        6,
    )


def test_type_safety_messages(caplog, monkeypatch):
    # Debug mode (any value in the env key) enables runtime return-type checks,
    # which log (not raise) on mismatch.
    monkeypatch.setenv(base_types.COMPUTATION_GRAPH_DEBUG_ENV_KEY, "j")

    def f(x) -> int:  # Bad typing!
        return "hello " + x

    assert (
        graph_runners.nullary(composers.make_compose(f, lambda: "world"), f)
        == "hello world"
    )
    assert "TypeError" in caplog.text


def test_type_safety_messages_no_overtrigger(caplog, monkeypatch):
    monkeypatch.setenv(base_types.COMPUTATION_GRAPH_DEBUG_ENV_KEY, "h")

    def f(x) -> str:
        return "hello " + x

    assert (
        graph_runners.nullary(composers.make_compose(f, lambda: "world"), f)
        == "hello world"
    )
    assert "TypeError" not in caplog.text


def test_anonymous_composition_type_safety():
    # Composition-time (static) type check: str output into int input fails.
    def f() -> str:
        pass

    def g(x: int):
        pass

    with pytest.raises(base_types.ComputationGraphTypeError):
        composers.make_compose(g, f)


def test_named_composition_type_safety():
    def f() -> str:
        pass

    def g(x: int):
        pass

    with pytest.raises(base_types.ComputationGraphTypeError):
        composers.make_compose(g, f, key="x")


# Arithmetic helper nodes for the future-edge tests below.
def _multiply(a, b):
    # b is None on the first turn (future-edge default) => pass a through.
    if b:
        return a * b
    return a


def _plus_1(y):
    return y + 1


async def _plus_1_async(y):
    await asyncio.sleep(0.1)
    return y + 1


def _times_2(x):
    return x * 2


def _sum(args):
    return sum(args)


def test_future_edges():
    graph_runners.unary_with_state_and_expectations(
        base_types.merge_graphs(
            composers.compose_unary(_plus_1, _times_2),
            composers.make_compose(_multiply, _plus_1, key="a"),
            composers.make_compose_future(_multiply, _times_2, "b", None),
        ),
        _times_2,
        _multiply,
    )([[3, 7], [3, 42]])


def test_future_edges_with_circuit():
    def some_input(x):
        return x

    f = graph_runners.unary_with_state(
        base_types.merge_graphs(
            composers.make_compose(_plus_1, _multiply),
            composers.make_compose(_times_2, _plus_1),
            composers.make_compose(_multiply, some_input, key="a"),
            composers.make_compose_future(_multiply, _times_2, "b", None),
        ),
        some_input,
        _times_2,
    )
    # Turn 1: b=None => (3+1)*2 = 8; turn 2: b=8 => (3*8+1)*2 = 50.
    assert f(3, 3) == 50


def test_sink_with_incoming_future_edge():
    def f(x):
        return x

    def g(x, y):
        if y is None:
            y = 4
        return f"x={x}, y={y}"

    f = graph_runners.unary(
        base_types.merge_graphs(
            composers.make_compose(g, f, key="x"),
            composers.make_compose_future(g, g, "y", None),
        ),
        f,
        g,
    )
    assert f(3) == "x=3, y=4"


def test_compose_future():
    a = graph.make_source()
    b = graph.make_source()
    c = graph.make_source()
    graph_runners.variadic_with_state_and_expectations(
        base_types.merge_graphs(
            composers.compose_source_unary(_plus_1, c),
            composers.compose_source_unary(_times_2, b),
            composers.compose_source(_multiply, "a", a),
            composers.make_compose_future(
                _multiply,
                composers.make_and([_plus_1, _times_2, _multiply], merge_fn=_sum),
                "b",
                None,
            ),
        ),
        _sum,
    )(([[{a: 2, b: 2, c: 2}, 9], [{a: 2, b: 2, c: 2}, 25]]))


async def test_compose_future_async():
    a = graph.make_source()
    b = graph.make_source()
    c = graph.make_source()
    await graph_runners.variadic_with_state_and_expectations(
        base_types.merge_graphs(
            composers.compose_source_unary(_plus_1_async, c),
            composers.compose_source_unary(_times_2, b),
            composers.compose_source(_multiply, "a", a),
            composers.make_compose_future(
                _multiply,
                composers.make_and([_plus_1_async, _times_2, _multiply], merge_fn=_sum),
                "b",
                None,
            ),
        ),
        _sum,
    )(([[{a: 2, b: 2, c: 2}, 9], [{a: 2, b: 2, c: 2}, 25]]))


def test_dont_duplicate_sources():
    # Duplicating a graph must reuse its source nodes, not clone them.
    a = graph.make_source()
    assert (
        graph_runners.variadic_infer_sink(
            duplication.duplicate_function_or_graph(
                composers.compose_source_unary(_plus_1, a)
            )
        )({a: 2})
        == 3
    )
def test_badly_composed_graph_raises():
    # Compiling a graph with an unsatisfied parameter (y) is rejected.
    with pytest.raises(AssertionError):
        run.to_callable_strict(
            composers.make_compose(lambda x, y: x + y, lambda: 1, key="x")
        )


def test_memory_persists_when_unactionable():
    # When the upstream is skipped ("skip state"), memory.with_state must keep
    # its previously remembered value for the following turn.
    def input_node(x):
        return x

    def output_node(x):
        return x

    def skipper(upstream, x):
        return x or upstream

    remember_first = memory.with_state("x", None, skipper)
    skip_or_passthrough = (
        lambda input: input
        if input != "skip state"
        else gamla.just_raise(base_types.SkipComputationError)
    )
    graph_runners.unary_with_state_and_expectations(
        composers.compose_left(
            composers.make_first(
                composers.make_compose(
                    composers.compose_left(
                        skip_or_passthrough, remember_first, key="upstream"
                    ),
                    input_node,
                    key="input",
                ),
                lambda: "state skipped",
            ),
            output_node,
        ),
        input_node,
        output_node,
    )(
        [
            ["remember this", "remember this"],
            ["skip state", "state skipped"],
            ["recall", "remember this"],
        ]
    )


def test_replace_source():
    a = graph.make_source()

    # Only edges where _node1 is the *source* are rewritten; the edge where it
    # is the destination is untouched.
    assert graph.replace_source(_node1, _node1_async)(
        base_types.merge_graphs(
            composers.compose_source_unary(_node1, a),
            composers.compose_left_unary(_node1, _node2),
        )
    ) == base_types.merge_graphs(
        composers.compose_source_unary(_node1, a),
        composers.compose_left_unary(_node1_async, _node2),
    )


def test_replace_source_with_args():
    # Replacement also applies inside an edge's `args` tuple (vararg edges).
    assert graph.replace_source(_node1, _node1_async)(
        (
            base_types.ComputationEdge(
                is_future=False,
                priority=0,
                source=None,
                args=(
                    graph.make_computation_node(_node1),
                    graph.make_computation_node(_node2),
                ),
                destination=graph.make_computation_node(_merger),
                key="args",
            ),
        )
    ) == frozenset(
        (
            base_types.ComputationEdge(
                is_future=False,
                priority=0,
                source=None,
                args=(
                    graph.make_computation_node(_node1_async),
                    graph.make_computation_node(_node2),
                ),
                destination=graph.make_computation_node(_merger),
                key="args",
            ),
        )
    )


def test_replace_source_with_graph():
    # The replacement may itself be a graph; its sink is wired where the old
    # source node used to feed.
    a = graph.make_source()

    assert frozenset(
        graph.replace_source(
            _node1, composers.compose_left_unary(_node1_async, _next_int)
        )(
            base_types.merge_graphs(
                composers.compose_source_unary(_node1, a),
                composers.compose_left_unary(_node1, _node2),
            )
        )
    ) == base_types.merge_graphs(
        composers.compose_source_unary(_node1, a),
        composers.compose_left_unary(_next_int, _node2),
        composers.compose_left_unary(_node1_async, _next_int),
    )


def test_replace_source_that_doesnt_exist():
    # Replacing an absent node is a no-op, not an error.
    a = graph.make_source()

    assert graph.replace_source(
        lambda x: x, composers.compose_left_unary(_node1_async, _next_int)
    )(
        base_types.merge_graphs(
            composers.compose_source_unary(_node1, a),
            composers.compose_left_unary(_node1, _node2),
        )
    ) == base_types.merge_graphs(
        composers.compose_source_unary(_node1, a),
        composers.compose_left_unary(_node1, _node2),
    )


def test_replace_destination():
    assert graph.replace_destination(_node1, _node1_async)(
        (
            base_types.ComputationEdge(
                is_future=False,
                priority=0,
                source=graph.make_computation_node(_node2),
                args=(),
                destination=graph.make_computation_node(_node1),
                key="arg1",
            ),
        )
    ) == base_types.merge_graphs(
        (
            base_types.ComputationEdge(
                is_future=False,
                priority=0,
                source=graph.make_computation_node(_node2),
                args=(),
                destination=graph.make_computation_node(_node1_async),
                key="arg1",
            ),
        )
    )


def test_replace_node():
    # replace_node rewrites the node on both source and destination sides.
    a = graph.make_source()

    assert graph.replace_node(_node1, _node1_async)(
        base_types.merge_graphs(
            composers.compose_source_unary(_node1, a),
            composers.compose_left_unary(_node1, _node2),
        )
    ) == base_types.merge_graphs(
        composers.compose_source_unary(_node1_async, a),
        composers.compose_left_unary(_node1_async, _node2),
    )


def test_ambig_edges_assertion_in_merge_graphs_active_only_when_env_var_is_active(
    monkeypatch,
):
    def a(x):
        pass

    # Without the debug env var, ambiguous merges are tolerated...
    monkeypatch.delenv(base_types.COMPUTATION_GRAPH_DEBUG_ENV_KEY, raising=False)
    base_types.merge_graphs(
        composers.compose_left_unary(lambda: 1, a),
        composers.compose_left_unary(lambda: 1, a),
    )

    # ...with it set, the same merge raises.
    monkeypatch.setenv(base_types.COMPUTATION_GRAPH_DEBUG_ENV_KEY, "1")
    with pytest.raises(
        Exception, match=r".*There are multiple edges with the same destination.*"
    ):
        base_types.merge_graphs(
            composers.compose_left_unary(lambda: 1, a),
            composers.compose_left_unary(lambda: 1, a),
        )


# Module-level fixture graph for the safe_replace_sources parametrized test:
# a -> c -> d(x), b -> d(y), with a future edge d -> b closing a cycle, and a
# terminal t observing a.
def a():
    pass


def b(x):
    pass


def c(x):
    pass


def d(x, y):
    pass


def kuky():
    pass


def kuku():
    pass


t = graph.make_terminal("t", lambda x: x)

g = base_types.merge_graphs(
    composers.compose_left(a, c),
    composers.compose_left(c, d, key="x"),
    composers.compose_left(b, d, key="y"),
    composers.compose_left_future(d, b, "x", "bla"),
    composers.compose_left(a, t),
)


@pytest.mark.parametrize(
    "to_replace,expected_edges_strs",
    [
        pytest.param(
            {a: kuky},
            {
                "kuky----x---->duplicate of c",
                "kuky----x---->t",
                "duplicate of c----x---->duplicate of d",
                "when_memory_unavailable----x---->duplicate of b",
                "duplicate of d....x....>duplicate of b",
                "duplicate of b----y---->duplicate of d",
            },
            id="replace source node",
        ),
        pytest.param(
            {c: kuky},
            {
                "a----x---->kuky",
                "a----x---->t",
                "kuky----x---->duplicate of d",
                "when_memory_unavailable----x---->duplicate of b",
                "duplicate of d....x....>duplicate of b",
                "duplicate of b----y---->duplicate of d",
            },
            id="replace node not in cycle",
        ),
        pytest.param(
            {b: kuky},
            {
                "a----x---->c",
                "a----x---->t",
                "c----x---->duplicate of d",
                "when_memory_unavailable----x---->kuky",
                "duplicate of d....x....>kuky",
                "kuky----y---->duplicate of d",
            },
            id="replace node in cycle",
        ),
        pytest.param(
            {a: kuku, b: kuky},
            {
                "duplicate of c----x---->duplicate of d",
                "duplicate of d....x....>kuky",
                "kuku----x---->duplicate of c",
                "kuku----x---->t",
                "kuky----y---->duplicate of d",
                "when_memory_unavailable----x---->kuky",
            },
            id="replace multiple nodes - duplicate reachables once",
        ),
    ],
)
def test_safe_replace_node(
    to_replace: Dict[base_types.CallableOrNode, base_types.CallableOrNodeOrGraph],
    # NOTE(review): annotation corrected from `str` — the parametrized values
    # are sets of edge strings.
    expected_edges_strs: set[str],
):
    assert expected_edges_strs == {
        str(e) for e in duplication.safe_replace_sources(to_replace, g)
    }
# --- (file: computation_graph/legacy.py) ---
import dataclasses
from typing import Any

import gamla

from computation_graph import base_types, composers, graph


@dataclasses.dataclass(frozen=True)
class ComputationResult:
    """Legacy node return value pairing an output with the node's next state."""

    result: Any
    state: Any


@gamla.curry
def handle_state(
    key: str, default: Any, g: base_types.GraphOrCallable
) -> base_types.GraphType:
    """Adapt a legacy ComputationResult-returning node `g` to the future-edge
    state model.

    `g`'s state is captured through a terminal and fed back into `g` on the
    next turn under `key` (starting from `default`); `g`'s `.result` is
    exposed as the composable output.
    """

    @graph.make_terminal("retrieve_state")
    def retrieve_state(x):
        # Terminal inputs arrive as a tuple; empty means g produced nothing
        # this turn, so there is no state to carry forward.
        if x == ():
            raise base_types.SkipComputationError
        x = x[0]
        assert isinstance(x, ComputationResult), x
        return x.state

    return base_types.merge_graphs(
        composers.make_compose_future(g, retrieve_state, key, default),
        composers.compose_unary(retrieve_state, g),
        composers.compose_unary(gamla.attrgetter("result"), g),
    )


# --- (file boundary: computation_graph/run.py) ---

import asyncio
import dataclasses
import functools
import inspect
import itertools
import logging
import os
import time
import typing
from typing import (
    Any,
    Awaitable,
    Callable,
    Dict,
    FrozenSet,
    Iterable,
    Mapping,
    Optional,
    Set,
    Tuple,
    Type,
)

import gamla
import immutables
import termcolor
import toposort
import typeguard
from gamla.optimized import sync as opt_gamla

from computation_graph import base_types, composers, graph, signature

# Sentinel marking "no result produced" for a node.
CG_NO_RESULT = "CG_NO_RESULT"


class _DepNotFoundError(Exception):
    """Raised internally when a node's dependency produced no result."""

    pass


# Mapping of each computed node to its result for one run.
_NodeToResults = Dict[base_types.ComputationNode, base_types.Result]
# Concrete (args, kwargs) values for a single node invocation.
_ComputationInput = Tuple[Tuple[base_types.Result, ...], Dict[str, base_types.Result]]
# Debug hook invoked per node with its result.
_SingleNodeSideEffect = Callable[[base_types.ComputationNode, Any], None]
# Which nodes feed a node's positional args and keyword args.
_ComputationInputSpec = Tuple[
    Tuple[base_types.ComputationNode, ...], Dict[str, base_types.ComputationNode]
]
# Executes one node given the results (possibly awaitables) computed so far.
_NodeExecutor = Callable[
    [
        Mapping[
            base_types.ComputationNode,
            base_types.Result | Awaitable[base_types.Result],
        ],
        base_types.ComputationNode,
    ],
    base_types.Result,
]


def _transpose_graph(
    graph: Dict[base_types.ComputationNode, Set[base_types.ComputationNode]]
) -> Dict[base_types.ComputationNode, Set[base_types.ComputationNode]]:
    """Reverse every edge of an adjacency-set mapping (node -> predecessors
    becomes node -> successors and vice versa).

    NOTE(review): the parameter deliberately shadows the imported `graph`
    module inside this function body.
    """
    return opt_gamla.pipe(
        graph, dict.keys, opt_gamla.groupby_many(graph.get), opt_gamla.valmap(set)
    )


# Topologically order the graph's nodes into a flat tuple.
_toposort_nodes: Callable[
    [base_types.GraphType], Tuple[FrozenSet[base_types.ComputationNode], ...]
] = opt_gamla.compose_left(
    opt_gamla.groupby_many(base_types.edge_sources),
    opt_gamla.valmap(
        opt_gamla.compose_left(opt_gamla.map(base_types.edge_destination), set)
    ),
    _transpose_graph,
    toposort.toposort,
    # Make async functions come first in each layer so they'll start running before all the sync functions
    opt_gamla.maptuple(
        gamla.sort_by(lambda n: 0 if inspect.iscoroutinefunction(n.func) else 1)
    ),
    gamla.concat,
    tuple,
)


def _type_check(node: base_types.ComputationNode, result):
    """Debug-mode check that `result` matches the node function's declared
    return annotation; mismatches are logged, never raised."""
    return_typing = typing.get_type_hints(node.func).get("return", None)
    if return_typing:
        try:
            typeguard.check_type(str(node), result, return_typing)
        except TypeError as e:
            logging.error([node.func.__code__, e])


def _profile(node, time_started: float):
    """Warn (in red) when a node took longer than 0.1s to run."""
    elapsed = time.perf_counter() - time_started
    if elapsed <= 0.1:
        return
    logging.warning(
        termcolor.colored(
            f"function took {elapsed:.6f} seconds: {base_types.pretty_print_function_name(node.func)}",
            color="red",
        )
    )


# Split a (node, result) mapping's items by whether the result is a Future.
_group_by_is_future = opt_gamla.groupby(lambda k_v: asyncio.isfuture(k_v[1]))


# A graph is async iff any participating node implementation is a coroutine function.
_is_graph_async = opt_gamla.compose_left(
    opt_gamla.mapcat(lambda edge: (edge.source, *edge.args)),
    opt_gamla.remove(gamla.equals(None)),
    opt_gamla.map(base_types.node_implementation),
    gamla.anymap(asyncio.iscoroutinefunction),
)


def _future_edge_to_regular_edge_with_placeholder(
    source_to_placeholder: dict[base_types.ComputationNode, base_types.ComputationNode]
) -> Callable[[base_types.ComputationEdge], base_types.ComputationEdge]:
    """Return a transformer that rewrites a future edge into a regular edge
    whose source is the placeholder standing in for the previous-turn value."""

    def replace_source(edge):
        assert edge.source, "only supports singular edges for now"

        return dataclasses.replace(
            edge, is_future=False, source=source_to_placeholder[edge.source]
        )

    return replace_source


# Map each future-edge source node to a fresh placeholder source of the same name.
_map_future_source_to_placeholder = opt_gamla.compose_left(
    opt_gamla.map(
        opt_gamla.pair_right(
            opt_gamla.compose_left(
                gamla.attrgetter("name"), graph.make_source_with_name
            )
        )
    ),
    gamla.frozendict,
)
# Collect the set of nodes that act as sources of future edges.
_graph_to_future_sources = opt_gamla.compose_left(
    opt_gamla.filter(base_types.edge_is_future),
    # We assume that future edges cannot be with multiple sources
    opt_gamla.map(base_types.edge_source),
    frozenset,
)


def _merge_edges_pointing_to_terminals(g: base_types.GraphType) -> base_types.GraphType:
    """Replace multiple edges pointing to a terminal with one edge that has multiple args"""
    return gamla.compose_left(
        gamla.groupby(gamla.attrgetter("destination")),
        gamla.itemmap(
            gamla.star(
                lambda dest, edges_for_dest: (
                    dest,
                    base_types.merge_graphs(
                        composers.make_or(
                            opt_gamla.maptuple(base_types.edge_source)(edges_for_dest),
                            merge_fn=(aggregate := lambda args: args),
                        ),
                        composers.compose_left_unary(aggregate, dest),
                    )
                    if dest.is_terminal
                    else edges_for_dest,
                )
            )
        ),
        dict.values,
        gamla.concat,
        tuple,
    )(g)


def _to_callable_with_side_effect_for_single_and_multiple(
    single_node_side_effect: _SingleNodeSideEffect,
    all_nodes_side_effect: Callable,
    edges: base_types.GraphType,
    handled_exceptions: Tuple[Type[Exception], ...],
) -> Callable[[_NodeToResults, _NodeToResults], _NodeToResults]:
    """Compile `edges` into a (prev_results, sources_to_values) -> results
    callable.

    Pipeline: merge terminal edges, swap future edges for placeholder sources,
    validate the composition, toposort the nodes, and build per-node executors.
    The per-node side effect only runs when the debug env key is set; the
    all-nodes side effect runs on every invocation's final results.
    """
    edges = _merge_edges_pointing_to_terminals(edges)
    single_node_side_effect = (
        (lambda node, result: result)
        if os.getenv(base_types.COMPUTATION_GRAPH_DEBUG_ENV_KEY) is None
        else single_node_side_effect
    )

    future_sources = _graph_to_future_sources(edges)
    future_source_to_placeholder = _map_future_source_to_placeholder(future_sources)
    edges = gamla.pipe(
        edges,
        gamla.unique,
        opt_gamla.map(
            opt_gamla.when(
                base_types.edge_is_future,
                _future_edge_to_regular_edge_with_placeholder(
                    future_source_to_placeholder
                ),
            )
        ),
        tuple,
        gamla.side_effect(_assert_composition_is_valid),
        gamla.side_effect(base_types.assert_no_unwanted_ambiguity),
    )
    is_async = _is_graph_async(edges)
    placeholder_to_future_source = opt_gamla.pipe(
        future_source_to_placeholder, gamla.itemmap(lambda k_v: (k_v[1], k_v[0]))
    )
    get_node_executor = _make_get_node_executor(
        edges, handled_exceptions, single_node_side_effect
    )

    # Placeholders are inputs, not computed nodes — drop them from the schedule.
    topological_sorted_nodes = opt_gamla.pipe(
        edges,
        _toposort_nodes,
        gamla.remove(gamla.contains(placeholder_to_future_source)),
        opt_gamla.maptuple(opt_gamla.pair_right(get_node_executor)),
    )

    # Incoming turn values are keyed by future source; rekey them by placeholder.
    translate_source_to_placeholder = opt_gamla.compose_left(
        opt_gamla.keyfilter(gamla.contains(future_sources)),
        opt_gamla.keymap(future_source_to_placeholder.__getitem__),
    )
    all_node_side_effects_on_edges = gamla.side_effect(all_nodes_side_effect(edges))

    if is_async:

        async def final_runner(sources_to_values):
            inputs = translate_source_to_placeholder(sources_to_values)
            all_results = await _run_graph_async(
                inputs, handled_exceptions, topological_sorted_nodes
            )

            return all_node_side_effects_on_edges(all_results)

    else:

        def final_runner(sources_to_values):
            return all_node_side_effects_on_edges(
                _run_graph(
                    translate_source_to_placeholder(sources_to_values),
                    handled_exceptions,
                    topological_sorted_nodes,
                )
            )

    return (_async_graph_reducer if is_async else _graph_reducer)(final_runner)


# Positional-arg sources for a node: taken from the first edge carrying `args`.
_get_args_nodes: Callable[
    [Tuple[base_types.ComputationEdge, ...]], Tuple[base_types.ComputationNode, ...]
] = gamla.compose_left(
    opt_gamla.filter(base_types.edge_args), gamla.head, base_types.edge_args
)
# Keyword-arg sources: every incoming edge keyed by something other than "*args".
_get_kwargs_nodes = opt_gamla.compose_left(
    opt_gamla.filter(
        gamla.compose_left(base_types.edge_key, gamla.not_equals("*args"))
    ),
    gamla.map(gamla.juxt(base_types.edge_key, base_types.edge_source)),
    gamla.frozendict,
)


def _node_incoming_edges_to_input_spec(
    node_incoming_edges: Tuple[base_types.ComputationEdge],
) -> _ComputationInputSpec:
    """Derive a node's (args-nodes, kwargs-nodes) spec from its incoming edges.

    A **kwargs-style node takes the first edge's source as its single
    positional input; otherwise args/kwargs are split per the node signature.
    """
    if not len(node_incoming_edges):
        return (), {}
    first_incoming_edge = gamla.head(node_incoming_edges)
    node = base_types.edge_destination(first_incoming_edge)
    if node.signature.is_kwargs:
        return (base_types.edge_source(first_incoming_edge),), {}
    return (
        _get_args_nodes(node_incoming_edges) if node.signature.is_args else (),
        _get_kwargs_nodes(node_incoming_edges),
    )
inspect.isawaitable(v): 281 | return v 282 | f = asyncio.get_event_loop().create_future() 283 | f.set_result(v) 284 | return f 285 | 286 | 287 | def _make_get_node_executor( 288 | edges, handled_exceptions, single_node_side_effect: _SingleNodeSideEffect 289 | ): 290 | node_to_incoming_edges = functools.cache(graph.get_incoming_edges_for_node(edges)) 291 | node_to_computation_input_spec_options: Callable[ 292 | [base_types.ComputationNode], Tuple[_ComputationInputSpec] 293 | ] = functools.cache( 294 | gamla.compose_left( 295 | node_to_incoming_edges, 296 | opt_gamla.groupby(base_types.edge_key), 297 | opt_gamla.valmap(gamla.sort_by(base_types.edge_priority)), 298 | dict.values, 299 | opt_gamla.star(itertools.product), 300 | opt_gamla.maptuple(_node_incoming_edges_to_input_spec), 301 | ) 302 | ) 303 | 304 | async def gather( 305 | args: Tuple[Awaitable, ...], kwargs: Mapping[str, Awaitable] 306 | ) -> tuple[tuple[Any, ...], dict[str, Any]]: 307 | if args or kwargs: 308 | try: 309 | gathered = await asyncio.gather(*args, *kwargs.values()) 310 | return gathered[: len(args)], dict( 311 | zip(kwargs.keys(), gathered[len(args) :]) 312 | ) 313 | except ( 314 | _DepNotFoundError, 315 | base_types.SkipComputationError, 316 | *handled_exceptions, 317 | ): 318 | # We delete the references to the upstream tasks to avoid circular reference (task->exception->traceback->task) and improve memory performance 319 | del args, kwargs 320 | raise _DepNotFoundError() from None 321 | return (), {} 322 | 323 | def node_to_input_sync( 324 | accumulated_results: Mapping[base_types.ComputationNode, base_types.Result], 325 | input_options: Iterable[_ComputationInputSpec], 326 | ) -> Optional[_ComputationInput]: 327 | for input_spec in input_options: 328 | args, kwargs = input_spec 329 | try: 330 | return tuple(accumulated_results[arg] for arg in args), { 331 | k: accumulated_results[v] for k, v in kwargs.items() 332 | } 333 | except KeyError: 334 | ... 
335 | return None 336 | 337 | async def node_to_input_async( 338 | accumulated_results: Mapping[ 339 | base_types.ComputationNode, base_types.Result | Awaitable[base_types.Result] 340 | ], 341 | input_options: Iterable[_ComputationInputSpec], 342 | ) -> Optional[_ComputationInput]: 343 | for input_spec in input_options: 344 | args_spec, kwargs_spec = input_spec 345 | if all( 346 | accumulated_results.get(arg, CG_NO_RESULT) is not CG_NO_RESULT 347 | for arg in args_spec 348 | ) and all( 349 | accumulated_results.get(kwarg, CG_NO_RESULT) is not CG_NO_RESULT 350 | for kwarg in kwargs_spec.values() 351 | ): 352 | try: 353 | return await gather( 354 | tuple(_to_awaitable(accumulated_results[a]) for a in args_spec), 355 | { 356 | k: _to_awaitable(accumulated_results[v]) 357 | for k, v in kwargs_spec.items() 358 | }, 359 | ) 360 | except ( 361 | _DepNotFoundError, 362 | base_types.SkipComputationError, 363 | *handled_exceptions, 364 | ): 365 | ... 366 | return None 367 | 368 | @opt_gamla.after(asyncio.create_task) 369 | async def await_deps_and_apply( 370 | accumulated_results: Mapping[ 371 | base_types.ComputationNode, base_types.Result | Awaitable[base_types.Result] 372 | ], 373 | node: base_types.ComputationNode, 374 | ) -> base_types.Result: 375 | args_kwargs = await node_to_input_async( 376 | accumulated_results, node_to_computation_input_spec_options(node) 377 | ) 378 | # We delete the references to the upstream tasks to avoid circular reference (task->exception->traceback->task) and improve memory performance 379 | del accumulated_results 380 | if args_kwargs is None: 381 | raise _DepNotFoundError() 382 | 383 | args, kwargs = args_kwargs 384 | before = time.perf_counter() 385 | result = node.func(*args, **kwargs) 386 | single_node_side_effect(node, result) 387 | if inspect.isawaitable(result): 388 | raise Exception( 389 | f"{node} returned an awaitable result but is not an async function" 390 | ) 391 | _profile(node, before) 392 | return result 393 | 394 | 
@opt_gamla.after(asyncio.create_task) 395 | async def await_deps_and_await( 396 | accumulated_results: Mapping[ 397 | base_types.ComputationNode, base_types.Result | Awaitable[base_types.Result] 398 | ], 399 | node: base_types.ComputationNode, 400 | ) -> base_types.Result: 401 | args_kwargs = await node_to_input_async( 402 | accumulated_results, node_to_computation_input_spec_options(node) 403 | ) 404 | # We delete the references to the upstream tasks to avoid circular reference (task->exception->traceback->task) and improve memory performance 405 | del accumulated_results 406 | if args_kwargs is None: 407 | raise _DepNotFoundError() 408 | 409 | args, kwargs = args_kwargs 410 | before = time.perf_counter() 411 | result = await node.func(*args, **kwargs) 412 | single_node_side_effect(node, result) 413 | _profile(node, before) 414 | return result 415 | 416 | @opt_gamla.after(asyncio.create_task) 417 | async def get_deps_and_await( 418 | accumulated_results: Mapping[base_types.ComputationNode, base_types.Result], 419 | node: base_types.ComputationNode, 420 | ) -> base_types.Result: 421 | args_kwargs = node_to_input_sync( 422 | accumulated_results, node_to_computation_input_spec_options(node) 423 | ) 424 | # We delete the references to the upstream tasks to avoid circular reference (task->exception->traceback->task) and improve memory performance 425 | del accumulated_results 426 | if args_kwargs is None: 427 | raise _DepNotFoundError() 428 | 429 | args, kwargs = args_kwargs 430 | before = time.perf_counter() 431 | result = await node.func(*args, **kwargs) 432 | single_node_side_effect(node, result) 433 | _profile(node, before) 434 | return result 435 | 436 | def get_deps_and_apply( 437 | accumulated_results: Mapping[base_types.ComputationNode, base_types.Result], 438 | node: base_types.ComputationNode, 439 | ) -> base_types.Result: 440 | args_kwargs = node_to_input_sync( 441 | accumulated_results, node_to_computation_input_spec_options(node) 442 | ) 443 | # We delete 
the references to the upstream tasks to avoid circular reference (task->exception->traceback->task) and improve memory performance 444 | del accumulated_results 445 | if args_kwargs is None: 446 | raise _DepNotFoundError() 447 | 448 | args, kwargs = args_kwargs 449 | before = time.perf_counter() 450 | result = node.func(*args, **kwargs) 451 | single_node_side_effect(node, result) 452 | if inspect.isawaitable(result): 453 | raise Exception( 454 | f"{node} returned an awaitable result but is not an async function" 455 | ) 456 | _profile(node, before) 457 | return result 458 | 459 | all_nodes = graph.get_all_nodes(edges) 460 | async_nodes = {n for n in all_nodes if asyncio.iscoroutinefunction(n.func)} 461 | sync = all_nodes - async_nodes 462 | tf = graph.traverse_forward(edges) 463 | downstream_from_async = set(gamla.graph_traverse_many(async_nodes, tf)) 464 | 465 | async_and_downstream = async_nodes & downstream_from_async 466 | async_not_downstream = async_nodes - downstream_from_async 467 | sync_and_downstream = sync & downstream_from_async 468 | sync_not_downstream = sync - downstream_from_async 469 | 470 | def get_executor(node: base_types.ComputationNode) -> _NodeExecutor: 471 | if node in async_and_downstream: 472 | return await_deps_and_await 473 | if node in async_not_downstream: 474 | return get_deps_and_await 475 | if node in sync_and_downstream: 476 | return await_deps_and_apply 477 | if node in sync_not_downstream: 478 | # This is fully sync so it only uses sync results from the mapping, its typing says the whole mapping is sync. 
479 | return get_deps_and_apply # type: ignore 480 | raise Exception("no executor found") 481 | 482 | return get_executor 483 | 484 | 485 | async def _run_graph_async(inputs, handled_exceptions, topological_sorted_nodes): 486 | node_to_task_or_result = inputs.copy() 487 | unhandled_exception = None 488 | try: 489 | for node_executor in topological_sorted_nodes: 490 | try: 491 | node_to_task_or_result[node_executor[0]] = node_executor[1]( 492 | node_to_task_or_result, node_executor[0] 493 | ) 494 | except ( 495 | _DepNotFoundError, 496 | base_types.SkipComputationError, 497 | *handled_exceptions, 498 | ): 499 | pass 500 | except Exception as exc: 501 | unhandled_exception = exc 502 | finally: 503 | results_by_is_async = _group_by_is_future(node_to_task_or_result.items()) 504 | async_results = tuple(zip(*results_by_is_async.get(True, ()))) 505 | sync_results = dict(results_by_is_async.get(False, ())) 506 | 507 | all_results = sync_results 508 | if async_results: 509 | for (node, node_result) in zip( 510 | async_results[0], 511 | await asyncio.gather(*async_results[1], return_exceptions=True), 512 | ): 513 | task_e = node_to_task_or_result[node].exception() 514 | if not task_e: 515 | all_results[node] = node_result 516 | elif not unhandled_exception and not isinstance( 517 | task_e, 518 | ( 519 | _DepNotFoundError, 520 | base_types.SkipComputationError, 521 | *handled_exceptions, 522 | ), 523 | ): 524 | unhandled_exception = task_e 525 | if unhandled_exception: 526 | # this trick avoids cyclic reference and garbage collection issues 527 | try: 528 | raise unhandled_exception from unhandled_exception 529 | except Exception as e: 530 | del node_to_task_or_result 531 | del unhandled_exception 532 | raise e from e 533 | return all_results 534 | 535 | 536 | def _run_graph( 537 | inputs: dict, 538 | handled_exceptions, 539 | topological_sorted_nodes: tuple[tuple[base_types.ComputationNode, _NodeExecutor]], 540 | ) -> _NodeToResults: 541 | accumulated_results = inputs.copy() 
542 | for node_executor in topological_sorted_nodes: 543 | try: 544 | accumulated_results[node_executor[0]] = node_executor[1]( 545 | accumulated_results, node_executor[0] 546 | ) 547 | except ( 548 | _DepNotFoundError, 549 | base_types.SkipComputationError, 550 | *handled_exceptions, 551 | ): 552 | pass 553 | return accumulated_results 554 | 555 | 556 | def _graph_reducer(graph_callable): 557 | def reducer(prev: _NodeToResults, sources: _NodeToResults) -> _NodeToResults: 558 | return {**prev, **(graph_callable({**prev, **sources}))} 559 | 560 | return reducer 561 | 562 | 563 | def _async_graph_reducer(graph_callable): 564 | async def reducer(prev: _NodeToResults, sources: _NodeToResults) -> _NodeToResults: 565 | return {**prev, **(await graph_callable({**prev, **sources}))} 566 | 567 | return reducer 568 | 569 | 570 | to_callable_with_side_effect = gamla.curry( 571 | _to_callable_with_side_effect_for_single_and_multiple 572 | )(_type_check) 573 | 574 | # Use the second line if you want to see the winning path in the computation graph (a little slower). 
575 | to_callable = to_callable_with_side_effect(gamla.just(gamla.just(None))) 576 | # to_callable = to_callable_with_side_effect(graphviz.computation_trace('utterance_computation.dot')) 577 | 578 | 579 | def _node_is_properly_composed( 580 | node_to_incoming_edges: base_types.GraphType, 581 | ) -> Callable[[base_types.ComputationNode], bool]: 582 | return gamla.compose_left( 583 | graph.unbound_signature(node_to_incoming_edges), 584 | signature.parameters, 585 | gamla.len_equals(0), 586 | ) 587 | 588 | 589 | def _assert_composition_is_valid(g: base_types.GraphType): 590 | return opt_gamla.pipe( 591 | g, 592 | graph.get_all_nodes, 593 | opt_gamla.remove( 594 | _node_is_properly_composed(graph.get_incoming_edges_for_node(g)) 595 | ), 596 | opt_gamla.map(gamla.wrap_str("{0} at {0.func.__code__}")), 597 | tuple, 598 | gamla.assert_that_with_message( 599 | gamla.wrap_str("Bad composition for: {}"), gamla.len_equals(0) 600 | ), 601 | ) 602 | 603 | 604 | def to_callable_strict( 605 | g: base_types.GraphType, 606 | ) -> Callable[[_NodeToResults, _NodeToResults], _NodeToResults]: 607 | return gamla.compose( 608 | gamla.star(to_callable(g, frozenset())), gamla.map(immutables.Map), gamla.pack 609 | ) 610 | -------------------------------------------------------------------------------- /computation_graph/signature.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import inspect 3 | from types import MappingProxyType 4 | from typing import Callable, FrozenSet 5 | 6 | import gamla 7 | 8 | from computation_graph import base_types 9 | 10 | 11 | def is_supported(signature: base_types.NodeSignature) -> bool: 12 | return not signature.optional_kwargs and not ( 13 | signature.kwargs and signature.is_args 14 | ) 15 | 16 | 17 | def name(func: Callable) -> str: 18 | if isinstance(func, functools.partial): 19 | return func.func.__name__ 20 | return func.__name__ 21 | 22 | 23 | def parameter_is_star(parameter) -> bool: 24 | 
return parameter.kind == inspect.Parameter.VAR_POSITIONAL 25 | 26 | 27 | def parameter_is_double_star(parameter) -> bool: 28 | return parameter.kind == inspect.Parameter.VAR_KEYWORD 29 | 30 | 31 | def _is_default(parameter): 32 | return parameter.default != parameter.empty 33 | 34 | 35 | _parameter_name = gamla.attrgetter("name") 36 | 37 | _func_parameters = gamla.compose_left( 38 | inspect.signature, gamla.attrgetter("parameters"), MappingProxyType.values, tuple 39 | ) 40 | 41 | 42 | def from_callable(func: Callable) -> base_types.NodeSignature: 43 | function_parameters = _func_parameters(func) 44 | return base_types.NodeSignature( 45 | is_args=gamla.anymap(parameter_is_star)(function_parameters), 46 | is_kwargs=gamla.anymap(parameter_is_double_star)(function_parameters), 47 | kwargs=gamla.pipe( 48 | function_parameters, 49 | gamla.remove(gamla.anyjuxt(parameter_is_star, parameter_is_double_star)), 50 | gamla.map(_parameter_name), 51 | tuple, 52 | ), 53 | optional_kwargs=gamla.pipe( 54 | function_parameters, 55 | gamla.remove(gamla.anyjuxt(parameter_is_star, parameter_is_double_star)), 56 | gamla.filter(_is_default), 57 | gamla.map(_parameter_name), 58 | tuple, 59 | ), 60 | ) 61 | 62 | 63 | def parameters(signature: base_types.NodeSignature) -> FrozenSet[str]: 64 | return frozenset( 65 | { 66 | *signature.kwargs, 67 | *signature.optional_kwargs, 68 | *(("**kwargs",) if signature.is_kwargs else ()), 69 | } 70 | ) 71 | 72 | 73 | is_unary = gamla.compose_left(parameters, gamla.len_equals(1)) 74 | -------------------------------------------------------------------------------- /computation_graph/trace/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hyroai/computation-graph/091e1c3b8817ae7f0d4dfb12f1c84d40a655554a/computation_graph/trace/__init__.py -------------------------------------------------------------------------------- /computation_graph/trace/ascii.py: 
-------------------------------------------------------------------------------- 1 | import pprint 2 | from typing import Callable, Dict, Iterable, Tuple 3 | 4 | import gamla 5 | from gamla.optimized import sync as opt_gamla 6 | 7 | from computation_graph import base_types 8 | from computation_graph.trace import trace_utils 9 | 10 | _NodeTree = Tuple[base_types.ComputationNode, Tuple["_NodeTree", ...]] # type: ignore 11 | _NodeAndResultTree = Tuple[ # type: ignore 12 | base_types.ComputationNode, 13 | base_types.Result, 14 | Tuple["_NodeAndResultTree", ...], # type: ignore 15 | ] 16 | 17 | 18 | @gamla.curry 19 | def _process_node( 20 | node_to_result: Callable[[_NodeTree], base_types.Result], 21 | source_and_destination_to_edges: Callable[ 22 | [Tuple[base_types.ComputationNode, base_types.ComputationNode]], 23 | Iterable[base_types.ComputationEdge], 24 | ], 25 | node: _NodeTree, 26 | children: Iterable[_NodeAndResultTree], 27 | ) -> _NodeAndResultTree: 28 | return ( 29 | node[0], 30 | node_to_result(node), 31 | gamla.pipe( 32 | children, 33 | gamla.map( 34 | gamla.juxt( 35 | gamla.compose_left( 36 | gamla.head, 37 | gamla.pair_right(gamla.just(node[0])), 38 | source_and_destination_to_edges, 39 | # In theory there can be >1 connections between two nodes. 
40 | gamla.map(base_types.edge_key), 41 | frozenset, 42 | ), 43 | gamla.identity, 44 | ) 45 | ), 46 | tuple, 47 | ), 48 | ) 49 | 50 | 51 | _should_render = gamla.compose_left(str, gamla.len_smaller(1000)) 52 | 53 | 54 | @gamla.curry 55 | def _skip_uninsteresting_nodes(node_to_result, node, children) -> _NodeTree: 56 | """To make the debug log more readable, we try to reduce uninteresting steps by some heuristics.""" 57 | result = node_to_result(node) 58 | children = tuple(children) 59 | if len(children) == 1: 60 | first_child = gamla.head(children) 61 | if not _should_render(result): 62 | return first_child 63 | if result == node_to_result(children[0][0]): 64 | return first_child 65 | return node, children 66 | 67 | 68 | _index_by_destination = gamla.compose_left( 69 | gamla.groupby(base_types.edge_destination), gamla.dict_to_getter_with_default(()) 70 | ) 71 | 72 | 73 | def _edge_to_node_pairs(edge: base_types.ComputationEdge): 74 | if edge.source: 75 | return [(edge.source, edge.destination)] 76 | return gamla.pipe((edge.args, edge.destination), gamla.explode(0)) 77 | 78 | 79 | _index_by_source_and_destination = gamla.compose_left( 80 | gamla.groupby_many(_edge_to_node_pairs), gamla.dict_to_getter_with_default(()) 81 | ) 82 | 83 | _sources = gamla.compose(frozenset, gamla.mapcat(base_types.edge_sources)) 84 | 85 | 86 | @gamla.curry 87 | def _trace_single_output( 88 | source_and_destination_to_edges, destination_to_edges, node_to_result: Dict 89 | ) -> Callable: 90 | return gamla.compose_left( 91 | gamla.tree_reduce( 92 | gamla.compose_left( 93 | destination_to_edges, 94 | gamla.filter( 95 | trace_utils.is_edge_participating(gamla.contains(node_to_result)) 96 | ), 97 | gamla.mapcat(base_types.edge_sources), 98 | ), 99 | _skip_uninsteresting_nodes(node_to_result.__getitem__), 100 | ), 101 | gamla.tree_reduce( 102 | gamla.nth(1), 103 | _process_node( 104 | gamla.compose_left(gamla.head, node_to_result.__getitem__), 105 | source_and_destination_to_edges, 106 | ), 107 
| ), 108 | ) 109 | 110 | 111 | def _sinks_for_trace(non_future_edges): 112 | return gamla.pipe( 113 | non_future_edges, 114 | gamla.map(base_types.edge_destination), 115 | gamla.remove(gamla.contains(_sources(non_future_edges))), 116 | frozenset, 117 | ) 118 | 119 | 120 | @gamla.before(gamla.compose(tuple, gamla.remove(base_types.edge_is_future))) 121 | def computation_trace(g: base_types.GraphType) -> Callable: 122 | return gamla.compose_left( 123 | _trace_single_output( 124 | _index_by_source_and_destination(g), _index_by_destination(g) 125 | ), 126 | opt_gamla.maptuple, 127 | gamla.apply(_sinks_for_trace(g)), 128 | pprint.pprint, 129 | ) 130 | -------------------------------------------------------------------------------- /computation_graph/trace/graphviz.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Iterable, Tuple 2 | from xml.sax import saxutils 3 | 4 | import gamla 5 | import pygraphviz as pgv 6 | 7 | from computation_graph import base_types 8 | from computation_graph.trace import trace_utils 9 | 10 | 11 | def _get_node_shape(node: base_types.ComputationNode): 12 | if node.name == "first": 13 | return "triangle" 14 | if node.is_terminal: 15 | return "doublecircle" 16 | return "ellipse" 17 | 18 | 19 | def _add_computation_node(pgv_graph: pgv.AGraph, node: base_types.ComputationNode): 20 | node_id = hash(node) 21 | pgv_graph.add_node( 22 | node_id, label=saxutils.quoteattr(str(node))[1:-1], shape=_get_node_shape(node) 23 | ) 24 | 25 | 26 | def _handle_edge(pgv_graph, edge): 27 | _add_computation_node(pgv_graph, edge.destination) 28 | if edge.source: 29 | _add_computation_node(pgv_graph, edge.source) 30 | pgv_graph.add_edge( 31 | hash(edge.source), 32 | hash(edge.destination), 33 | label=trace_utils.get_edge_label(edge), 34 | style="dashed" if edge.is_future else "", 35 | ) 36 | else: 37 | for source in edge.args: 38 | _add_computation_node(pgv_graph, source) 39 | pgv_graph.add_edge( 40 | 
hash(source), 41 | hash(edge.destination), 42 | label=trace_utils.get_edge_label(edge), 43 | style="dashed" if edge.is_future else "", 44 | ) 45 | 46 | 47 | def computation_graph_to_graphviz(edges: base_types.GraphType) -> pgv.AGraph: 48 | pgv_graph = pgv.AGraph(directed=True) 49 | for edge in edges: 50 | _handle_edge(pgv_graph, edge) 51 | return pgv_graph 52 | 53 | 54 | def _do_add_edge(result_graph: pgv.AGraph) -> Callable[[pgv.Edge], pgv.AGraph]: 55 | return lambda edge: gamla.pipe( 56 | result_graph, gamla.side_effect(lambda g: g.add_edge(edge, **edge.attr)) 57 | ) 58 | 59 | 60 | def _do_add_node(result_graph: pgv.AGraph) -> Callable[[pgv.Node], pgv.AGraph]: 61 | return lambda node: gamla.pipe( 62 | result_graph, gamla.side_effect(lambda g: g.add_node(node, **node.attr)) 63 | ) 64 | 65 | 66 | union_graphviz: Callable[[Iterable[pgv.AGraph]], pgv.AGraph] = gamla.compose_left( 67 | # copy to avoid side effects that influence caller 68 | gamla.juxt(gamla.compose_left(gamla.head, pgv.AGraph.copy), gamla.drop(1)), 69 | gamla.star( 70 | gamla.reduce( 71 | lambda result_graph, another_graph: gamla.pipe( 72 | another_graph, 73 | # following does side effects on result_graph, that's why we return just(result_graph) 74 | # we assume no parallelization in gamla.juxt 75 | gamla.side_effect( 76 | gamla.juxt( 77 | gamla.compose_left( 78 | pgv.AGraph.nodes, 79 | gamla.map(_do_add_node(result_graph)), 80 | tuple, 81 | ), 82 | gamla.compose_left( 83 | pgv.AGraph.edges, 84 | gamla.map(_do_add_edge(result_graph)), 85 | tuple, 86 | ), 87 | ) 88 | ), 89 | gamla.just(result_graph), 90 | ) 91 | ) 92 | ), 93 | ) 94 | 95 | 96 | def _save_as_png(filename: str) -> Callable[[pgv.AGraph], pgv.AGraph]: 97 | return gamla.side_effect(lambda pgv_graph: pgv_graph.write(filename)) 98 | 99 | 100 | def computation_trace_to_graphviz( 101 | computation_trace: Iterable[Tuple[base_types.ComputationNode, base_types.Result]] 102 | ) -> pgv.AGraph: 103 | pgv_graph = pgv.AGraph() 104 | for node, result in 
computation_trace: 105 | _add_computation_node(pgv_graph, node) 106 | graphviz_node = pgv_graph.get_node(hash(node)) 107 | graphviz_node.attr["color"] = "red" 108 | graphviz_node.attr["result"] = str(result)[:200] 109 | 110 | return pgv_graph 111 | 112 | 113 | visualize_graph = gamla.compose_left( 114 | computation_graph_to_graphviz, _save_as_png("bot_computation_graph.dot") 115 | ) 116 | 117 | 118 | @gamla.curry 119 | def computation_trace( 120 | filename: str, graph_instance: base_types.GraphType, node_to_results 121 | ): 122 | gviz = union_graphviz( 123 | [ 124 | computation_graph_to_graphviz(graph_instance), 125 | computation_trace_to_graphviz( 126 | gamla.pipe(node_to_results, dict.items, frozenset) 127 | ), 128 | ] 129 | ) 130 | gviz.layout(prog="dot") 131 | gviz.write(filename) 132 | -------------------------------------------------------------------------------- /computation_graph/trace/graphviz_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | 4 | import pygraphviz as pgv 5 | 6 | from computation_graph import base_types, composers, run 7 | from computation_graph.trace import graphviz 8 | 9 | 10 | def test_computation_trace(tmp_path: pathlib.Path): 11 | def node1(x): 12 | return f"node1({x})" 13 | 14 | def node2(): 15 | return "node2" 16 | 17 | def raises(): 18 | raise base_types.SkipComputationError 19 | 20 | filename = "visualize.dot" 21 | f = run.to_callable_with_side_effect( 22 | graphviz.computation_trace(filename), 23 | composers.make_first(raises, composers.compose_unary(node1, node2)), 24 | frozenset(), 25 | ) 26 | cwd = os.getcwd() 27 | os.chdir(tmp_path) 28 | f({}, {}) 29 | assert (tmp_path / filename).exists() 30 | g = pgv.AGraph() 31 | g.read(str(tmp_path / filename)) 32 | assert g.get_node(hash(node1)).attr["result"] == "node1(node2)" 33 | assert g.get_node(hash(node1)).attr["color"] == "red" 34 | assert g.get_node(hash(node2)).attr["result"] == "node2" 35 | assert 
g.get_node(hash(node2)).attr["color"] == "red" 36 | assert not g.get_node(hash(raises)).attr["color"] 37 | os.chdir(cwd) 38 | -------------------------------------------------------------------------------- /computation_graph/trace/mermaid.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import gamla 4 | 5 | from computation_graph import base_types 6 | from computation_graph.trace import trace_utils 7 | 8 | 9 | def _clean_for_mermaid_name(obj): 10 | return str(obj).replace('"', "") 11 | 12 | 13 | @gamla.star 14 | def _render_mermaid_node(node, result): 15 | node_id = hash(node) 16 | pretty_name = " ".join(map(_clean_for_mermaid_name, [node, result]))[:100] 17 | return f'{node_id}("{pretty_name}")' 18 | 19 | 20 | def _render_mermaid_edge(edge) -> str: 21 | if edge.source: 22 | return f"{hash(edge.source)} --{trace_utils.get_edge_label(edge)}--> {hash(edge.destination)}" 23 | return gamla.pipe( 24 | edge.args, 25 | gamla.map( 26 | lambda source: f"{hash(source)} --{trace_utils.get_edge_label(edge)}--> {hash(edge.destination)}" 27 | ), 28 | "\n".join, 29 | ) 30 | 31 | 32 | def mermaid_computation_trace(graph_instance: base_types.GraphType): 33 | return gamla.compose_left( 34 | dict.items, 35 | frozenset, 36 | lambda trace: [ 37 | "", # This is to avoid the indent of the logging details. 
38 | "graph TD", 39 | gamla.pipe(trace, gamla.map(_render_mermaid_node), "\n".join), 40 | gamla.pipe( 41 | graph_instance, 42 | gamla.filter( 43 | trace_utils.is_edge_participating( 44 | gamla.contains(frozenset(map(gamla.head, trace))) 45 | ) 46 | ), 47 | gamla.map(_render_mermaid_edge), 48 | "\n".join, 49 | ), 50 | ], 51 | "\n".join, 52 | logging.info, 53 | ) 54 | -------------------------------------------------------------------------------- /computation_graph/trace/trace_utils.py: -------------------------------------------------------------------------------- 1 | import gamla 2 | 3 | from computation_graph import base_types 4 | 5 | 6 | def is_edge_participating(in_trace_nodes): 7 | def is_edge_participating(edge): 8 | return in_trace_nodes(edge.destination) and gamla.anymap(in_trace_nodes)( 9 | [edge.source, *edge.args] 10 | ) 11 | 12 | return is_edge_participating 13 | 14 | 15 | def get_edge_label(edge: base_types.ComputationEdge) -> str: 16 | if edge.key in (None, "args"): 17 | return "" 18 | if edge.key == "first_input": 19 | return str(edge.priority) 20 | return edge.key or "" 21 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | skip-magic-trailing-comma = true 3 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | asyncio_mode = auto 3 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open("README.md", "r") as fh: 4 | _LONG_DESCRIPTION = fh.read() 5 | 6 | 7 | setuptools.setup( 8 | name="computation-graph", 9 | python_requires=">=3", 10 | version="64", 11 | long_description=_LONG_DESCRIPTION, 12 | 
long_description_content_type="text/markdown", 13 | packages=setuptools.find_namespace_packages(), 14 | install_requires=[ 15 | "gamla", 16 | "typeguard==2.13.3", 17 | "toposort", 18 | "immutables", 19 | "termcolor", 20 | ], 21 | extras_require={"test": ["pygraphviz", "pytest>=5.4.0"]}, 22 | ) 23 | --------------------------------------------------------------------------------