├── .coveragerc ├── .github └── workflows │ └── test.yml ├── .gitignore ├── LICENSE ├── MANIFEST.in ├── drive_flow ├── __init__.py ├── broker.py ├── core.py ├── dynamic.py ├── types.py └── utils.py ├── examples ├── 1_hello_world_in_order.py ├── 2_hello_world_parallele.py ├── 3_use_event_output.py ├── 4_endless_tick_timer.py ├── 5_retrigger_type.py └── 6_llm_agent_ReAct.py ├── readme.md ├── requirements-dev.txt ├── requirements.txt ├── setup.py └── tests ├── __init__.py ├── test_define.py ├── test_dynamic_run.py ├── test_run.py ├── test_types.py └── test_utils.py /.coveragerc: -------------------------------------------------------------------------------- 1 | [report] 2 | exclude_lines = 3 | # Have to re-enable the standard pragma 4 | pragma: no cover 5 | 6 | # Don't complain if tests don't hit defensive assertion code: 7 | raise NotImplementedError 8 | logger. 9 | 10 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: test 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | - dev 8 | paths-ignore: 9 | - '**/*.md' 10 | - '**/*.ipynb' 11 | - 'examples/**' 12 | pull_request: 13 | branches: 14 | - main 15 | - dev 16 | paths-ignore: 17 | - '**/*.md' 18 | - '**/*.ipynb' 19 | - 'examples/**' 20 | 21 | jobs: 22 | test: 23 | name: Tests on ${{ matrix.os }} for ${{ matrix.python-version }} 24 | strategy: 25 | matrix: 26 | python-version: [3.9, 3.11] 27 | os: [ubuntu-latest, windows-latest] 28 | runs-on: ${{ matrix.os }} 29 | timeout-minutes: 10 30 | steps: 31 | - uses: actions/checkout@v4 32 | - name: Set up Python ${{ matrix.python-version }} 33 | uses: actions/setup-python@v3 34 | with: 35 | python-version: ${{ matrix.python-version }} 36 | - name: Install dependencies 37 | run: | 38 | python -m pip install --upgrade pip 39 | pip install -r requirements.txt 40 | pip install -r requirements-dev.txt 41 | - name: Lint with 
flake8 42 | run: | 43 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 44 | - name: Build and Test 45 | run: | 46 | python -m pytest -o log_cli=true -o log_cli_level="INFO" --cov=drive_flow --cov-report=xml -v ./ 47 | - name: Check codecov file 48 | id: check_files 49 | uses: andstor/file-existence-action@v1 50 | with: 51 | files: './coverage.xml' 52 | - name: Upload coverage from test to Codecov 53 | uses: codecov/codecov-action@v2 54 | with: 55 | file: ./coverage.xml 56 | token: ${{ secrets.CODECOV_TOKEN }} 57 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.toptal.com/developers/gitignore/api/python 2 | # Edit at https://www.toptal.com/developers/gitignore?templates=python 3 | 4 | ### Python ### 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | cover/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | .pybuilder/ 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | # For a library or package, you might want to ignore these files since the code is 91 | # intended to run in multiple environments; otherwise, check them in: 92 | # .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # poetry 102 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 103 | # This is especially recommended for binary packages to ensure reproducibility, and is more 104 | # commonly ignored for libraries. 105 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 106 | #poetry.lock 107 | 108 | # pdm 109 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
110 | #pdm.lock 111 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 112 | # in version control. 113 | # https://pdm.fming.dev/#use-with-ide 114 | .pdm.toml 115 | 116 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 117 | __pypackages__/ 118 | 119 | # Celery stuff 120 | celerybeat-schedule 121 | celerybeat.pid 122 | 123 | # SageMath parsed files 124 | *.sage.py 125 | 126 | # Environments 127 | .env 128 | .venv 129 | env/ 130 | venv/ 131 | ENV/ 132 | env.bak/ 133 | venv.bak/ 134 | 135 | # Spyder project settings 136 | .spyderproject 137 | .spyproject 138 | 139 | # Rope project settings 140 | .ropeproject 141 | 142 | # mkdocs documentation 143 | /site 144 | 145 | # mypy 146 | .mypy_cache/ 147 | .dmypy.json 148 | dmypy.json 149 | 150 | # Pyre type checker 151 | .pyre/ 152 | 153 | # pytype static type analyzer 154 | .pytype/ 155 | 156 | # Cython debug symbols 157 | cython_debug/ 158 | 159 | # PyCharm 160 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 161 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 162 | # and can be added to the global gitignore or merged into this file. For a more nuclear 163 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
164 | #.idea/ 165 | 166 | ### Python Patch ### 167 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration 168 | poetry.toml 169 | 170 | # ruff 171 | .ruff_cache/ 172 | 173 | # LSP config files 174 | pyrightconfig.json 175 | 176 | # End of https://www.toptal.com/developers/gitignore/api/python 177 | 178 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Gustavo Ye 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include readme.md 2 | -------------------------------------------------------------------------------- /drive_flow/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import EventEngineCls 2 | from .types import EventInput, ReturnBehavior 3 | 4 | __version__ = "0.0.1" 5 | __author__ = "Jianbai Ye" 6 | __url__ = "https://github.com/memodb-io/drive-flow" 7 | 8 | default_drive = EventEngineCls() 9 | -------------------------------------------------------------------------------- /drive_flow/broker.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | from .types import BaseEvent, EventInput, Task, GroupEventReturns 3 | from .utils import generate_uuid 4 | 5 | 6 | class BaseBroker: 7 | async def append(self, event: BaseEvent, event_input: EventInput) -> Task: 8 | raise NotImplementedError() 9 | 10 | async def callback_after_run_done(self) -> tuple[BaseEvent, Any]: 11 | raise NotImplementedError() 12 | -------------------------------------------------------------------------------- /drive_flow/core.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import asyncio 3 | from typing import Callable, Optional, Union, Any, Tuple, Literal 4 | from .types import ( 5 | BaseEvent, 6 | EventFunction, 7 | EventGroup, 8 | EventInput, 9 | _SpecialEventReturn, 10 | ReturnBehavior, 11 | InvokeInterCache, 12 | ) 13 | from .broker import BaseBroker 14 | from .utils import logger, string_to_md5_hash, generate_uuid 15 | 16 | 17 | class EventEngineCls: 18 | def __init__(self, name="default", broker: Optional[BaseBroker] = None): 19 | self.name = name 20 | self.broker = broker or BaseBroker() 21 | self.__event_maps: dict[str, BaseEvent] 
= {} 22 | self.__max_group_size = 0 23 | 24 | def reset(self): 25 | self.__event_maps = {} 26 | 27 | def get_event_from_id(self, event_id: str) -> Optional[BaseEvent]: 28 | return self.__event_maps.get(event_id) 29 | 30 | def make_event(self, func: Union[EventFunction, BaseEvent]) -> BaseEvent: 31 | if isinstance(func, BaseEvent): 32 | self.__event_maps[func.id] = func 33 | return func 34 | assert inspect.iscoroutinefunction( 35 | func 36 | ), "Event function must be a coroutine function" 37 | event = BaseEvent(func) 38 | self.__event_maps[event.id] = event 39 | return event 40 | 41 | def listen_group( 42 | self, 43 | group_markers: list[BaseEvent], 44 | group_name: Optional[str] = None, 45 | retrigger_type: Literal["all", "any"] = "all", 46 | ) -> Callable[[BaseEvent], BaseEvent]: 47 | assert all( 48 | [isinstance(m, BaseEvent) for m in group_markers] 49 | ), "group_markers must be a list of BaseEvent" 50 | assert all( 51 | [m.id in self.__event_maps for m in group_markers] 52 | ), f"group_markers must be registered in the same event engine, current event engine is {self.name}" 53 | group_markers_in_dict = {event.id: event for event in group_markers} 54 | 55 | def decorator(func: BaseEvent) -> BaseEvent: 56 | if not isinstance(func, BaseEvent): 57 | func = self.make_event(func) 58 | assert ( 59 | func.id in self.__event_maps 60 | ), f"Event function must be registered in the same event engine, current event engine is {self.name}" 61 | this_group_name = group_name or f"{len(func.parent_groups)}" 62 | this_group_hash = string_to_md5_hash(":".join(group_markers_in_dict.keys())) 63 | new_group = EventGroup( 64 | this_group_name, 65 | this_group_hash, 66 | group_markers_in_dict, 67 | retrigger_type=retrigger_type, 68 | ) 69 | self.__max_group_size = max( 70 | self.__max_group_size, len(group_markers_in_dict) 71 | ) 72 | if new_group.hash() in func.parent_groups: 73 | logger.warning(f"Group {group_markers} already listened by {func}") 74 | return func 75 | 
func.parent_groups[new_group.hash()] = new_group 76 | return func 77 | 78 | return decorator 79 | 80 | def goto(self, group_markers: list[BaseEvent], *args): 81 | raise NotImplementedError() 82 | 83 | async def invoke_event( 84 | self, 85 | event: BaseEvent, 86 | event_input: Optional[EventInput] = None, 87 | global_ctx: Any = None, 88 | max_async_events: Optional[int] = None, 89 | ) -> dict[str, Any]: 90 | this_run_ctx: dict[str, InvokeInterCache] = {} 91 | queue: list[Tuple[str, EventInput]] = [(event.id, event_input)] 92 | 93 | async def run_event(current_event_id: str, current_event_input: Any): 94 | current_event = self.get_event_from_id(current_event_id) 95 | assert current_event is not None, f"Event {current_event_id} not found" 96 | result = await current_event.solo_run(current_event_input, global_ctx) 97 | this_run_ctx[current_event.id] = { 98 | "result": result, 99 | "already_sent_to_event_group": set(), 100 | } 101 | if isinstance(result, _SpecialEventReturn): 102 | if result.behavior == ReturnBehavior.GOTO: 103 | group_markers, any_return = result.returns 104 | for group_marker in group_markers: 105 | this_group_returns = {current_event.id: any_return} 106 | build_input_goto = EventInput( 107 | group_name="$goto", 108 | results=this_group_returns, 109 | behavior=ReturnBehavior.GOTO, 110 | ) 111 | queue.append((group_marker.id, build_input_goto)) 112 | elif result.behavior == ReturnBehavior.ABORT: 113 | return 114 | else: 115 | # dispath to events who listen 116 | for cand_event in self.__event_maps.values(): 117 | cand_event_parents = cand_event.parent_groups 118 | for group_hash, group in cand_event_parents.items(): 119 | if_current_event_trigger = current_event.id in group.events 120 | if_ctx_cover = all( 121 | [event_id in this_run_ctx for event_id in group.events] 122 | ) 123 | event_group_id = f"{cand_event.id}:{group_hash}" 124 | if if_current_event_trigger and if_ctx_cover: 125 | if ( 126 | any( 127 | [ 128 | event_group_id 129 | in 
this_run_ctx[event_id][ 130 | "already_sent_to_event_group" 131 | ] 132 | for event_id in group.events 133 | ] 134 | ) 135 | and group.retrigger_type == "all" 136 | ): 137 | # some events already dispatched to this event and group, skip 138 | logger.debug(f"Skip {cand_event} for {current_event}") 139 | continue 140 | this_group_returns = { 141 | event_id: this_run_ctx[event_id]["result"] 142 | for event_id in group.events 143 | } 144 | for event_id in group.events: 145 | this_run_ctx[event_id][ 146 | "already_sent_to_event_group" 147 | ].add(event_group_id) 148 | build_input = EventInput( 149 | group_name=group.name, results=this_group_returns 150 | ) 151 | queue.append((cand_event.id, build_input)) 152 | 153 | tasks = set() 154 | try: 155 | while len(queue) or len(tasks): 156 | this_batch_events = ( 157 | queue[:max_async_events] if max_async_events else queue 158 | ) 159 | queue = queue[max_async_events:] if max_async_events else [] 160 | new_tasks = { 161 | asyncio.create_task(run_event(*run_event_input)) 162 | for run_event_input in this_batch_events 163 | } 164 | tasks.update(new_tasks) 165 | done, tasks = await asyncio.wait( 166 | tasks, return_when=asyncio.FIRST_COMPLETED 167 | ) 168 | for task in done: 169 | await task # Handle any exceptions 170 | except asyncio.CancelledError: 171 | for task in tasks: 172 | task.cancel() 173 | await asyncio.gather(*tasks, return_exceptions=True) 174 | raise 175 | return this_run_ctx 176 | -------------------------------------------------------------------------------- /drive_flow/dynamic.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | from .types import ( 3 | BaseEvent, 4 | _SpecialEventReturn, 5 | ReturnBehavior, 6 | ) 7 | 8 | 9 | def goto_events( 10 | group_markers: list[BaseEvent], any_return: Any = None 11 | ) -> _SpecialEventReturn: 12 | return _SpecialEventReturn( 13 | behavior=ReturnBehavior.GOTO, returns=(group_markers, any_return) 14 | ) 15 | 16 
| 17 | def abort_this(): 18 | return _SpecialEventReturn(behavior=ReturnBehavior.ABORT, returns=None) 19 | -------------------------------------------------------------------------------- /drive_flow/types.py: -------------------------------------------------------------------------------- 1 | from copy import copy 2 | from enum import Enum 3 | from dataclasses import dataclass, field 4 | from datetime import datetime 5 | from typing import Any, Awaitable, Optional, Union, Callable, TypedDict, Literal 6 | 7 | from .utils import ( 8 | string_to_md5_hash, 9 | generate_uuid, 10 | function_or_method_to_string, 11 | function_or_method_to_repr, 12 | ) 13 | 14 | 15 | class ReturnBehavior(Enum): 16 | DISPATCH = "dispatch" 17 | GOTO = "goto" 18 | ABORT = "abort" 19 | INPUT = "input" 20 | 21 | 22 | class TaskStatus(Enum): 23 | RUNNING = "running" 24 | SUCCESS = "success" 25 | FAILURE = "failure" 26 | PENDING = "pending" 27 | 28 | 29 | class InvokeInterCache(TypedDict): 30 | result: Any 31 | already_sent_to_event_group: set[str] 32 | 33 | 34 | GroupEventReturns = dict[str, Any] 35 | 36 | 37 | @dataclass 38 | class EventGroupInput: 39 | group_name: str 40 | results: GroupEventReturns 41 | behavior: ReturnBehavior = ReturnBehavior.DISPATCH 42 | 43 | 44 | @dataclass 45 | class EventInput(EventGroupInput): 46 | task_id: str = field(default_factory=generate_uuid) 47 | 48 | @classmethod 49 | def from_input(cls: "EventInput", input_data: dict[str, Any]) -> "EventInput": 50 | return cls( 51 | group_name="user_input", results=input_data, behavior=ReturnBehavior.INPUT 52 | ) 53 | 54 | 55 | @dataclass 56 | class _SpecialEventReturn: 57 | behavior: ReturnBehavior 58 | returns: Any 59 | 60 | def __post_init__(self): 61 | if not isinstance(self.behavior, ReturnBehavior): 62 | raise TypeError( 63 | f"behavior must be a ReturnBehavior, not {type(self.behavior)}" 64 | ) 65 | 66 | 67 | # (group_event_results, global ctx set by user) -> result 68 | EventFunction = Callable[ 69 | 
[Optional[EventInput], Optional[Any]], Awaitable[Union[Any, _SpecialEventReturn]] 70 | ] 71 | 72 | 73 | @dataclass 74 | class EventGroup: 75 | name: str 76 | events_hash: str 77 | events: dict[str, "BaseEvent"] 78 | retrigger_type: Literal["all", "any"] = "all" 79 | 80 | def hash(self) -> str: 81 | return self.events_hash 82 | 83 | 84 | class BaseEvent: 85 | parent_groups: dict[str, EventGroup] 86 | func_inst: EventFunction 87 | id: str 88 | repr_name: str 89 | 90 | def __init__( 91 | self, 92 | func_inst: EventFunction, 93 | parent_groups: Optional[dict[str, EventGroup]] = None, 94 | ): 95 | self.parent_groups = parent_groups or {} 96 | self.func_inst = func_inst 97 | self.id = string_to_md5_hash(function_or_method_to_string(self.func_inst)) 98 | self.repr_name = function_or_method_to_repr(self.func_inst) 99 | self.meta = {"func_body": function_or_method_to_string(self.func_inst)} 100 | 101 | def debug_string(self, exclude_events: Optional[set[str]] = None) -> str: 102 | exclude_events = exclude_events or set([self.id]) 103 | parents_str = format_parents(self.parent_groups, exclude_events=exclude_events) 104 | return f"{self.repr_name}\n{parents_str}" 105 | 106 | def __repr__(self) -> str: 107 | return f"Node(source={self.repr_name})" 108 | 109 | async def solo_run( 110 | self, event_input: EventInput, global_ctx: Any = None 111 | ) -> Awaitable[Any]: 112 | return await self.func_inst(event_input, global_ctx) 113 | 114 | 115 | @dataclass 116 | class Task: 117 | task_id: str 118 | status: TaskStatus = TaskStatus.PENDING 119 | created_at: datetime = field(default_factory=datetime.now) 120 | upated_at: datetime = field(default_factory=datetime.now) 121 | 122 | 123 | def format_parents(parents: dict[str, EventGroup], exclude_events: set[str], indent=""): 124 | # Below code is ugly 125 | # But it works and only for debug display 126 | result = [] 127 | for i, parent_group in enumerate(parents.values()): 128 | is_last_group = i == len(parents) - 1 129 | group_prefix = 
"└─ " if is_last_group else "├─ " 130 | result.append(indent + group_prefix + f"<{parent_group.name}>") 131 | for j, parent in enumerate(parent_group.events.values()): 132 | root_events = copy(exclude_events) 133 | is_last = j == len(parent_group.events) - 1 134 | child_indent = indent + (" " if is_last_group else "│ ") 135 | inter_indent = " " if is_last else "│ " 136 | prefix = "└─ " if is_last else "├─ " 137 | if parent.id in root_events: 138 | result.append(f"{child_indent}{prefix}{parent.repr_name} ") 139 | continue 140 | root_events.add(parent.id) 141 | parent_debug = parent.debug_string(exclude_events=root_events).split("\n") 142 | parent_debug = [p for p in parent_debug if p.strip()] 143 | result.append(f"{child_indent}{prefix}{parent.repr_name}") 144 | for line in parent_debug[1:]: 145 | result.append(f"{child_indent}{inter_indent}{line}") 146 | return "\n".join(result) 147 | -------------------------------------------------------------------------------- /drive_flow/utils.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | import logging 3 | import asyncio 4 | import inspect 5 | import hashlib 6 | from typing import Callable 7 | 8 | logger = logging.getLogger("drive-flow") 9 | 10 | 11 | def generate_uuid() -> str: 12 | return str(uuid.uuid4()) 13 | 14 | 15 | def function_or_method_to_repr(func_or_method: Callable) -> str: 16 | is_method = inspect.ismethod(func_or_method) 17 | is_function = inspect.isfunction(func_or_method) 18 | if not is_method and not is_function: 19 | raise ValueError("Input must be a function or method") 20 | module = func_or_method.__module__ 21 | name = func_or_method.__name__ 22 | line_number = inspect.getsourcelines(func_or_method)[1] 23 | 24 | if is_method: 25 | class_name = func_or_method.__self__.__class__.__name__ 26 | return f"{module}.l_{line_number}.{class_name}.{name}".strip() 27 | else: 28 | return f"{module}.l_{line_number}.{name}".strip() 29 | 30 | 31 | def 
function_or_method_to_string(func_or_method: Callable) -> str: 32 | is_method = inspect.ismethod(func_or_method) 33 | is_function = inspect.isfunction(func_or_method) 34 | if not is_method and not is_function: 35 | raise ValueError("Input must be a function or method") 36 | module = func_or_method.__module__ 37 | source = inspect.getsource(func_or_method) 38 | line_number = inspect.getsourcelines(func_or_method)[1] 39 | 40 | if is_method: 41 | class_name = func_or_method.__self__.__class__.__name__ 42 | return f"{module}.l_{line_number}.{class_name}\n{source}".strip() 43 | else: 44 | return f"{module}.l_{line_number}\n{source}".strip() 45 | 46 | 47 | def string_to_md5_hash(string: str) -> str: 48 | return hashlib.md5(string.encode()).hexdigest() 49 | -------------------------------------------------------------------------------- /examples/1_hello_world_in_order.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from drive_flow import EventInput, default_drive 3 | 4 | 5 | @default_drive.make_event 6 | async def hello(event: EventInput, global_ctx): 7 | print("hello") 8 | 9 | 10 | @default_drive.listen_group([hello]) 11 | async def world(event: EventInput, global_ctx): 12 | print("world") 13 | 14 | 15 | asyncio.run(default_drive.invoke_event(hello)) 16 | -------------------------------------------------------------------------------- /examples/2_hello_world_parallele.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from datetime import datetime 3 | from drive_flow import EventInput, default_drive 4 | 5 | 6 | @default_drive.make_event 7 | async def start(event: EventInput, global_ctx): 8 | print("start") 9 | 10 | 11 | @default_drive.listen_group([start]) 12 | async def hello(event: EventInput, global_ctx): 13 | print(datetime.now(), "hello") 14 | await asyncio.sleep(0.2) 15 | print(datetime.now(), "hello done") 16 | 17 | 18 | 
@default_drive.listen_group([start]) 19 | async def world(event: EventInput, global_ctx): 20 | print(datetime.now(), "world") 21 | await asyncio.sleep(0.2) 22 | print(datetime.now(), "world done") 23 | 24 | 25 | asyncio.run(default_drive.invoke_event(start)) 26 | -------------------------------------------------------------------------------- /examples/3_use_event_output.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from datetime import datetime 3 | from drive_flow import EventInput, default_drive 4 | 5 | 6 | @default_drive.make_event 7 | async def start(event: EventInput, global_ctx): 8 | print("start") 9 | 10 | 11 | @default_drive.listen_group([start]) 12 | async def hello(event: EventInput, global_ctx): 13 | return 1 14 | 15 | 16 | @default_drive.listen_group([start]) 17 | async def world(event: EventInput, global_ctx): 18 | return 2 19 | 20 | 21 | @default_drive.listen_group([hello, world]) 22 | async def adding(event: EventInput, global_ctx): 23 | results = event.results 24 | print("adding", hello, world) 25 | return results[hello.id] + results[world.id] 26 | 27 | 28 | results = asyncio.run(default_drive.invoke_event(start)) 29 | assert results[adding.id] == 3 30 | -------------------------------------------------------------------------------- /examples/4_endless_tick_timer.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from datetime import datetime 3 | from drive_flow import default_drive, EventInput 4 | from drive_flow.dynamic import goto_events 5 | 6 | 7 | @default_drive.make_event 8 | async def tick(event: EventInput, global_ctx): 9 | await asyncio.sleep(1) 10 | return "tick" 11 | 12 | 13 | @default_drive.listen_group([tick]) 14 | async def tok(event: EventInput, global_ctx): 15 | print(datetime.now(), f"{event.results[tick.id]}, then tok") 16 | return goto_events([tick]) 17 | 18 | 19 | asyncio.run(default_drive.invoke_event(tick)) 20 
| -------------------------------------------------------------------------------- /examples/5_retrigger_type.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from drive_flow import default_drive, EventInput 3 | from drive_flow.dynamic import goto_events 4 | 5 | 6 | @default_drive.make_event 7 | async def on_start(event: EventInput, global_ctx): 8 | print("---------New_turn---------") 9 | 10 | 11 | @default_drive.listen_group([on_start]) 12 | async def a(event: EventInput, global_ctx): 13 | await asyncio.sleep(0.1) 14 | print("a") 15 | 16 | 17 | @default_drive.listen_group([on_start]) 18 | async def b(event: EventInput, global_ctx): 19 | await asyncio.sleep(0.5) 20 | print("b") 21 | 22 | 23 | @default_drive.listen_group([a, b], retrigger_type="any") 24 | async def c(event: EventInput, global_ctx): 25 | print("C is triggered") 26 | 27 | 28 | # default retrigger_type is 'all' 29 | @default_drive.listen_group([a, b], retrigger_type="all") 30 | async def d(event: EventInput, global_ctx): 31 | print("D is triggered") 32 | return goto_events([on_start]) # re-loop the workflow 33 | 34 | 35 | if __name__ == "__main__": 36 | asyncio.run(default_drive.invoke_event(on_start)) 37 | 38 | # For the first turn, the print will be: 39 | # ---------New_turn--------- 40 | # a 41 | # b 42 | # C is triggered 43 | # D is triggered 44 | 45 | # But for the rest of the turns, the print will be: 46 | # ---------New_turn--------- 47 | # a 48 | # C is triggered 49 | # b 50 | # C is triggered 51 | # D is triggered 52 | 53 | # Because the retrigger_type of d is 'all', it will be triggered only when all the events in the group (a, b) are updated. 54 | # The retrigger_type of c is 'any'. So when a is updated, it will trigger c, and when b is updated, it will trigger c again. 
55 | -------------------------------------------------------------------------------- /examples/6_llm_agent_ReAct.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | from openai import AsyncOpenAI 4 | from openai.types.chat import ChatCompletionMessageToolCall 5 | from drive_flow import default_drive, EventInput, ReturnBehavior 6 | from drive_flow.dynamic import goto_events, abort_this 7 | 8 | openai_client = AsyncOpenAI() 9 | use_model = "gpt-4o-mini" 10 | 11 | 12 | def multiply(a: int, b: int) -> int: 13 | """Multiply two integers and returns the result integer""" 14 | return a * b 15 | 16 | 17 | def add(a: int, b: int) -> int: 18 | """Add two integers and returns the result integer""" 19 | return a + b 20 | 21 | 22 | function_describe = [ 23 | { 24 | "type": "function", 25 | "function": { 26 | "name": "multiply", 27 | "description": "Multiply two integers and returns the result integer", 28 | "parameters": { 29 | "type": "object", 30 | "properties": { 31 | "a": { 32 | "type": "integer", 33 | "description": "The first integer to multiply.", 34 | }, 35 | "b": { 36 | "type": "integer", 37 | "description": "The second integer to multiply.", 38 | }, 39 | }, 40 | "required": ["a", "b"], 41 | }, 42 | }, 43 | }, 44 | { 45 | "type": "function", 46 | "function": { 47 | "name": "add", 48 | "description": "Add two integers and returns the result integer", 49 | "parameters": { 50 | "type": "object", 51 | "properties": { 52 | "a": { 53 | "type": "integer", 54 | "description": "The first integer to add.", 55 | }, 56 | "b": { 57 | "type": "integer", 58 | "description": "The second integer to add.", 59 | }, 60 | }, 61 | "required": ["a", "b"], 62 | }, 63 | }, 64 | }, 65 | ] 66 | 67 | 68 | @default_drive.make_event 69 | async def plan(event: EventInput, global_ctx): 70 | print("Planning...") 71 | if event.behavior == ReturnBehavior.INPUT: 72 | query = event.results["query"] 73 | messages = [ 74 | { 75 | 
"role": "system", 76 | "content": "You are a assistant. Use the following functions: multiply, add to compute the result of a calculation. Compute the result step by step", 77 | }, 78 | { 79 | "role": "user", 80 | "content": query, 81 | }, 82 | ] 83 | global_ctx["messages"] = messages 84 | messages = global_ctx["messages"] 85 | response = await openai_client.chat.completions.create( 86 | messages=messages, 87 | model=use_model, 88 | tools=function_describe, 89 | ) 90 | if response.choices[0].finish_reason == "tool_calls": 91 | return response.choices[0].message 92 | else: 93 | global_ctx["answer"] = response.choices[0].message.content 94 | return abort_this() 95 | 96 | 97 | @default_drive.listen_group([plan]) 98 | async def action(event: EventInput, global_ctx): 99 | func_calls: list[ChatCompletionMessageToolCall] = event.results[plan.id].tool_calls 100 | print( 101 | "Executing", 102 | [c.function.name for c in func_calls], 103 | "with arguments", 104 | [json.loads(c.function.arguments) for c in func_calls], 105 | ) 106 | results = [] 107 | for func_c in func_calls: 108 | if func_c.function.name == "multiply": 109 | result = multiply(**json.loads(func_c.function.arguments)) 110 | elif func_c.function.name == "add": 111 | result = add(**json.loads(func_c.function.arguments)) 112 | else: 113 | raise ValueError(f"Unknown function {func_c.function.name}") 114 | results.append(result) 115 | return event.results[plan.id], func_calls, results 116 | 117 | 118 | @default_drive.listen_group([action]) 119 | async def observate(event: EventInput, global_ctx): 120 | func_calls: list[ChatCompletionMessageToolCall] 121 | tool_call_response, func_calls, func_results = event.results[action.id] 122 | print("Observing", [c.function.name for c in func_calls]) 123 | messages = global_ctx["messages"] 124 | messages.append(tool_call_response) 125 | messages.extend( 126 | [ 127 | {"role": "tool", "content": json.dumps({"result": r}), "tool_call_id": c.id} 128 | for c, r in 
zip(func_calls, func_results) 129 | ] 130 | ) 131 | return goto_events([plan]) 132 | 133 | 134 | if __name__ == "__main__": 135 | question = "3+3*2+20*4" 136 | storage_results = {} 137 | print(observate.debug_string()) 138 | asyncio.run( 139 | default_drive.invoke_event( 140 | plan, 141 | event_input=EventInput.from_input({"query": question}), 142 | global_ctx=storage_results, 143 | ) 144 | ) 145 | 146 | if "answer" not in storage_results: 147 | print(f"Failed to get answer {question}") 148 | exit(1) 149 | print(storage_results["answer"]) 150 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 |
2 |

drive-flow

3 |

Build event-driven workflows with Python async functions

4 |

5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 |

15 |
16 | 17 | 18 | 🌬️ [Zero dependencies](./requirements.txt). No trouble, no loss. 19 | 20 | 🍰 With **intuitive decorators**, write your async workflow like a piece of cake. 21 | 22 | 🔄 Supports dynamic dispatch (`goto`, `abort`). Create a **looping or if-else workflow with ease**. 23 | 24 | 🔜 **Fully asynchronous**. Events are always triggered at the same time if they listen to the same group! 25 | 26 | 27 | 28 | 29 | 30 | ## Install 31 | 32 | **Install from PyPI** 33 | 34 | ```shell 35 | pip install drive-flow 36 | ``` 37 | 38 | **Install from source** 39 | 40 | ```shell 41 | # clone this repo first 42 | cd drive-flow 43 | pip install -e . 44 | ``` 45 | 46 | 47 | 48 | ## Quick Start 49 | 50 | A hello world example: 51 | 52 | ```python 53 | import asyncio 54 | from drive_flow import EventInput, default_drive 55 | 56 | 57 | @default_drive.make_event 58 | async def hello(event: EventInput, global_ctx): 59 | print("hello") 60 | 61 | @default_drive.listen_group([hello]) 62 | async def world(event: EventInput, global_ctx): 63 | print("world") 64 | 65 | # display the dependencies of 'world' event 66 | print(world.debug_string()) 67 | asyncio.run(default_drive.invoke_event(hello)) 68 | ``` 69 | 70 | In this example, the return of the `hello` event triggers the `world` event. 71 | 72 | > [!TIP] 73 | > 74 | > Hello world is not cool enough? Try building a [ReAct Agent Workflow](./examples/6_llm_agent_ReAct.py) with `drive-flow`. 75 | 76 | ### Break-down 77 | 78 | To make an event function, there are a few elements: 79 | 80 | * Input Signature: must be `(event: EventInput, global_ctx)`. `EventInput` carries the return values of the listened groups. `global_ctx` is set by you when invoking events; it can be anything and defaults to `None`. 81 | 82 | This [example](./examples/3_use_event_output.py) shows how to get return values from `EventInput`.
83 | * Make sure you decorate the function with `@default_drive.make_event` or `@default_drive.listen_group([EVENT,...])`. 84 | 85 | Then, run your workflow from any event: 86 | 87 | ```python 88 | await default_drive.invoke_event(EVENT, EVENT_INPUT, GLOBAL_CTX) 89 | ``` 90 | 91 | Check out [examples](./examples) for more detailed usage and features! 92 | 93 | ## Features 94 | 95 | ### Multi-Recv 96 | 97 | `drive_flow` allows an event to be triggered only when a whole group of events has been produced: 98 | 99 |
100 | code snippet 101 | 102 | ```python 103 | @default_drive.make_event 104 | async def start(event: EventInput, global_ctx): 105 | print("start") 106 | 107 | @default_drive.listen_group([start]) 108 | async def hello(event: EventInput, global_ctx): 109 | return 1 110 | 111 | 112 | @default_drive.listen_group([start]) 113 | async def world(event: EventInput, global_ctx): 114 | return 2 115 | 116 | 117 | @default_drive.listen_group([hello, world]) 118 | async def adding(event: EventInput, global_ctx): 119 | results = event.results 120 | print("adding", hello, world) 121 | return results[hello.id] + results[world.id] 122 | 123 | 124 | results = asyncio.run(default_drive.invoke_event(start)) 125 | assert results[adding.id] == 3 126 | ``` 127 | 128 | `adding` is triggered for the first time only once both `hello` and `world` are done. 129 |
130 | 131 | #### Re-trigger the event 132 | 133 | `drive_flow` supports different behaviors for multi-event retriggering: 134 | 135 | - `all`: retrigger this event only when all of the events it listens to have been updated. 136 | - `any`: retrigger this event whenever any one of the events it listens to is updated. 137 | 138 | Check out this [example](./examples/5_retrigger_type.py) for more details. 139 | 140 | ### Parallel 141 | 142 | `drive_flow` is well suited to workflows with many network I/O calls that can be awaited in parallel. If two events listen to the same group of events, they will be triggered at the same time: 143 | 144 |
145 | code snippet 146 | 147 | ```python 148 | @default_drive.make_event 149 | async def start(event: EventInput, global_ctx): 150 | print("start") 151 | 152 | @default_drive.listen_group([start]) 153 | async def hello(event: EventInput, global_ctx): 154 | print(datetime.now(), "hello") 155 | await asyncio.sleep(0.2) 156 | print(datetime.now(), "hello done") 157 | 158 | 159 | @default_drive.listen_group([start]) 160 | async def world(event: EventInput, global_ctx): 161 | print(datetime.now(), "world") 162 | await asyncio.sleep(0.2) 163 | print(datetime.now(), "world done") 164 | 165 | asyncio.run(default_drive.invoke_event(start)) 166 | ``` 167 | 168 |
169 | 170 | 171 | 172 | ### Dynamic 173 | 174 | `drive_flow` is dynamic. You can use `goto` and `abort` to change the workflow at runtime: 175 | 176 |
177 | code snippet for abort_this 178 | 179 | ```python 180 | from drive_flow.dynamic import abort_this 181 | 182 | @default_drive.make_event 183 | async def a(event: EventInput, global_ctx): 184 | return abort_this() 185 | # abort_this is not exiting the whole workflow, 186 | # only abort this event's return and not causing any other influence 187 | # `a` chooses to abort its return. So no more events in this invoking. 188 | # this invoking then will end 189 | @default_drive.listen_group([a]) 190 | async def b(event: EventInput, global_ctx): 191 | assert False, "should not be called" 192 | 193 | asyncio.run(default_drive.invoke_event(a)) 194 | ``` 195 | 196 |
197 | 198 |
199 | code snippet for goto 200 | 201 | ```python 202 | from drive_flow.types import ReturnBehavior 203 | from drive_flow.dynamic import goto_events, abort_this 204 | 205 | call_a_count = 0 206 | @default_drive.make_event 207 | async def a(event: EventInput, global_ctx): 208 | global call_a_count 209 | if call_a_count == 0: 210 | assert event is None 211 | elif call_a_count == 1: 212 | assert event.behavior == ReturnBehavior.GOTO 213 | assert event.results == {b.id: 2} 214 | return abort_this() 215 | call_a_count += 1 216 | return 1 217 | 218 | @default_drive.listen_group([a]) 219 | async def b(event: EventInput, global_ctx): 220 | return goto_events([a], 2) 221 | 222 | @default_drive.listen_group([b]) 223 | async def c(event: EventInput, global_ctx): 224 | assert False, "should not be called" 225 | 226 | asyncio.run(default_drive.invoke_event(a)) 227 | ``` 228 | 229 |
230 | 231 | 232 | 233 | ## TODO 234 | 235 | - [x] fix: streaming event executation 236 | - [x] fix: an event never receive the listened events' results twice (de-duplication), unless the group is totally updated for `retrigger_type='all'` 237 | - [x] Add ReAct workflow example 238 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | flake8 2 | pytest 3 | pytest-asyncio 4 | pytest-cov -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/memodb-io/drive-flow/3dc3dd63a35ae746b84f74fc31e6fad187018b01/requirements.txt -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open("readme.md", "r") as fh: 4 | long_description = fh.read() 5 | 6 | 7 | vars2find = ["__author__", "__version__", "__url__"] 8 | vars2readme = {} 9 | with open("./drive_flow/__init__.py") as f: 10 | for line in f.readlines(): 11 | for v in vars2find: 12 | if line.startswith(v): 13 | line = line.replace(" ", "").replace('"', "").replace("'", "").strip() 14 | vars2readme[v] = line.split("=")[1] 15 | 16 | deps = [] 17 | with open("./requirements.txt") as f: 18 | for line in f.readlines(): 19 | if not line.strip(): 20 | continue 21 | deps.append(line.strip()) 22 | 23 | setuptools.setup( 24 | name="drive-flow", 25 | url=vars2readme["__url__"], 26 | version=vars2readme["__version__"], 27 | author=vars2readme["__author__"], 28 | description="Build event-driven workflows with python async functions", 29 | long_description=long_description, 30 | long_description_content_type="text/markdown", 31 | packages=["drive_flow"], 32 | classifiers=[ 33 | "Programming 
Language :: Python :: 3", 34 | "License :: OSI Approved :: MIT License", 35 | "Operating System :: OS Independent", 36 | ], 37 | python_requires=">=3.9", 38 | install_requires=deps, 39 | ) 40 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/memodb-io/drive-flow/3dc3dd63a35ae746b84f74fc31e6fad187018b01/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_define.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from drive_flow import default_drive, EventInput 4 | from drive_flow.types import BaseEvent 5 | 6 | 7 | @pytest.mark.asyncio 8 | async def test_non_async_func(): 9 | with pytest.raises(AssertionError): 10 | 11 | @default_drive.make_event 12 | def a(event: EventInput, global_ctx): 13 | return 1 14 | 15 | 16 | @pytest.mark.asyncio 17 | async def test_non_async_listen_groups(): 18 | async def a(event: EventInput, global_ctx): 19 | return 1 20 | 21 | with pytest.raises(AssertionError): 22 | 23 | @default_drive.listen_group([a]) 24 | async def b(event: EventInput, global_ctx): 25 | return 1 26 | 27 | 28 | @pytest.mark.asyncio 29 | async def test_set_and_reset(): 30 | @default_drive.make_event 31 | async def a(event: EventInput, global_ctx): 32 | return 1 33 | 34 | @default_drive.listen_group([a]) 35 | async def b(event: EventInput, global_ctx): 36 | return 2 37 | 38 | default_drive.reset() 39 | 40 | with pytest.raises(AssertionError): 41 | 42 | @default_drive.listen_group([a]) 43 | async def b(event: EventInput, global_ctx): 44 | return 2 45 | 46 | 47 | @pytest.mark.asyncio 48 | async def test_duplicate_decorator(): 49 | @default_drive.make_event 50 | @default_drive.make_event 51 | async def a(event: EventInput, global_ctx): 52 | return 1 53 | 54 | assert isinstance(a, BaseEvent) 55 
| 56 | 57 | @pytest.mark.asyncio 58 | async def test_correct_get_id(): 59 | @default_drive.make_event 60 | async def a(event: EventInput, global_ctx): 61 | return 1 62 | 63 | assert default_drive.get_event_from_id(a.id) == a 64 | 65 | 66 | @pytest.mark.asyncio 67 | async def test_order(): 68 | @default_drive.make_event 69 | async def a(event: EventInput, global_ctx): 70 | return 1 71 | 72 | @default_drive.listen_group([a]) 73 | async def b(event: EventInput, global_ctx): 74 | return 2 75 | 76 | @default_drive.listen_group([b]) 77 | async def c(event: EventInput, global_ctx): 78 | return 3 79 | 80 | print(a.debug_string()) 81 | print(b.debug_string()) 82 | print(c.debug_string()) 83 | 84 | assert await a.solo_run(None) == 1 85 | assert await b.solo_run(None) == 2 86 | assert await c.solo_run(None) == 3 87 | 88 | 89 | @pytest.mark.asyncio 90 | async def test_multi_send(): 91 | @default_drive.make_event 92 | async def a(event: EventInput, global_ctx): 93 | return 1 94 | 95 | @default_drive.listen_group([a]) 96 | async def b(event: EventInput, global_ctx): 97 | return 2 98 | 99 | @default_drive.listen_group([a]) 100 | async def c(event: EventInput, global_ctx): 101 | return 3 102 | 103 | print(a.debug_string()) 104 | print(b.debug_string()) 105 | print(c.debug_string()) 106 | assert await a.solo_run(None) == 1 107 | assert await b.solo_run(None) == 2 108 | assert await c.solo_run(None) == 3 109 | 110 | 111 | @pytest.mark.asyncio 112 | async def test_multi_recv(): 113 | @default_drive.make_event 114 | async def a(event: EventInput, global_ctx): 115 | return 1 116 | 117 | @default_drive.listen_group([a]) 118 | async def a1(event: EventInput, global_ctx): 119 | return 1 120 | 121 | @default_drive.make_event 122 | async def b(event: EventInput, global_ctx): 123 | return 2 124 | 125 | @default_drive.listen_group([a1, b]) 126 | async def c(event: EventInput, global_ctx): 127 | return 3 128 | 129 | print(a.debug_string()) 130 | print(b.debug_string()) 131 | 
print(c.debug_string()) 132 | assert await a.solo_run(None) == 1 133 | assert await b.solo_run(None) == 2 134 | assert await c.solo_run(None) == 3 135 | 136 | 137 | @pytest.mark.asyncio 138 | async def test_multi_groups(): 139 | @default_drive.make_event 140 | async def a0(event: EventInput, global_ctx): 141 | return 0 142 | 143 | @default_drive.make_event 144 | async def a1(event: EventInput, global_ctx): 145 | return 0 146 | 147 | @default_drive.listen_group([a0, a1]) 148 | @default_drive.listen_group([a0, a1]) 149 | @default_drive.listen_group([a0, a1]) 150 | async def a(event: EventInput, global_ctx): 151 | return 1 152 | 153 | assert await a.solo_run(None) == 1 154 | 155 | 156 | @pytest.mark.asyncio 157 | async def test_loop(): 158 | @default_drive.make_event 159 | async def a(event: EventInput, global_ctx): 160 | return 1 161 | 162 | @default_drive.listen_group([a]) 163 | async def b(event: EventInput, global_ctx): 164 | return 2 165 | 166 | a = default_drive.listen_group([b])(a) 167 | 168 | @default_drive.listen_group([a, b]) 169 | async def c(event: EventInput, global_ctx): 170 | return 3 171 | 172 | print(a.debug_string()) 173 | print(b.debug_string()) 174 | print(c.debug_string()) 175 | 176 | assert await a.solo_run(None) == 1 177 | assert await b.solo_run(None) == 2 178 | assert await c.solo_run(None) == 3 179 | -------------------------------------------------------------------------------- /tests/test_dynamic_run.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from drive_flow import default_drive, EventInput 3 | from drive_flow.types import ReturnBehavior, _SpecialEventReturn 4 | from drive_flow.dynamic import goto_events, abort_this 5 | 6 | 7 | class DeliberateExcepion(Exception): 8 | pass 9 | 10 | 11 | def test_special_event_init(): 12 | with pytest.raises(TypeError): 13 | _SpecialEventReturn("fool", 1) 14 | 15 | 16 | @pytest.mark.asyncio 17 | async def test_abort(): 18 | 
@default_drive.make_event 19 | async def a(event: EventInput, global_ctx): 20 | assert global_ctx == {"test_ctx": 1} 21 | return abort_this() 22 | 23 | @default_drive.listen_group([a]) 24 | async def b(event: EventInput, global_ctx): 25 | assert False, "should not be called" 26 | 27 | result = await default_drive.invoke_event(a, None, {"test_ctx": 1}) 28 | print(result) 29 | 30 | 31 | @pytest.mark.asyncio 32 | async def test_goto(): 33 | call_a_count = 0 34 | 35 | @default_drive.make_event 36 | async def a(event: EventInput, global_ctx): 37 | nonlocal call_a_count 38 | if call_a_count == 0: 39 | assert event is None 40 | elif call_a_count == 1: 41 | assert event.behavior == ReturnBehavior.GOTO 42 | assert event.group_name == "$goto" 43 | assert event.results == {b.id: 2} 44 | return abort_this() 45 | else: 46 | raise ValueError("should not be called more than twice") 47 | call_a_count += 1 48 | return 1 49 | 50 | @default_drive.listen_group([a]) 51 | async def b(event: EventInput, global_ctx): 52 | return goto_events([a], 2) 53 | 54 | @default_drive.listen_group([b]) 55 | async def c(event: EventInput, global_ctx): 56 | assert False, "should not be called" 57 | 58 | result = await default_drive.invoke_event(a, None, {"test_ctx": 1}) 59 | assert call_a_count == 1 60 | print(result) 61 | -------------------------------------------------------------------------------- /tests/test_run.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import pytest 3 | from drive_flow import default_drive, EventInput 4 | from drive_flow.types import ReturnBehavior 5 | from drive_flow.dynamic import abort_this 6 | 7 | 8 | class DeliberateExcepion(Exception): 9 | pass 10 | 11 | 12 | @pytest.mark.asyncio 13 | async def test_simple_order_run(): 14 | @default_drive.make_event 15 | async def a(event: EventInput, global_ctx): 16 | assert global_ctx == {"test_ctx": 1} 17 | return 1 18 | 19 | @default_drive.listen_group([a]) 20 | async 
def b(event: EventInput, global_ctx): 21 | assert global_ctx == {"test_ctx": 1} 22 | assert event.group_name == "0" 23 | assert event.behavior == ReturnBehavior.DISPATCH 24 | assert event.results == {a.id: 1} 25 | return 2 26 | 27 | @default_drive.listen_group([b]) 28 | async def c(event: EventInput, global_ctx): 29 | assert global_ctx == {"test_ctx": 1} 30 | assert event.group_name == "0" 31 | assert event.behavior == ReturnBehavior.DISPATCH 32 | assert event.results == {b.id: 2} 33 | return 3 34 | 35 | result = await default_drive.invoke_event(a, None, {"test_ctx": 1}) 36 | print(result) 37 | 38 | 39 | @pytest.mark.asyncio 40 | async def test_multi_send(): 41 | @default_drive.make_event 42 | async def a(event: EventInput, global_ctx): 43 | return 1 44 | 45 | @default_drive.listen_group([a]) 46 | async def b(event: EventInput, global_ctx): 47 | assert event.group_name == "0" 48 | assert event.behavior == ReturnBehavior.DISPATCH 49 | assert event.results == {a.id: 1} 50 | return 2 51 | 52 | @default_drive.listen_group([a]) 53 | async def c(event: EventInput, global_ctx): 54 | assert event.group_name == "0" 55 | assert event.behavior == ReturnBehavior.DISPATCH 56 | assert event.results == {a.id: 1} 57 | return 3 58 | 59 | result = await default_drive.invoke_event(a, None, {"test_ctx": 1}) 60 | print(result) 61 | 62 | 63 | @pytest.mark.asyncio 64 | async def test_multi_recv(): 65 | @default_drive.make_event 66 | async def start(event: EventInput, global_ctx): 67 | return None 68 | 69 | @default_drive.listen_group([start]) 70 | async def a(event: EventInput, global_ctx): 71 | return 1 72 | 73 | @default_drive.listen_group([start]) 74 | async def b(event: EventInput, global_ctx): 75 | await asyncio.sleep(0.2) 76 | return 2 77 | 78 | @default_drive.listen_group([a, b]) 79 | async def c(event: EventInput, global_ctx): 80 | assert event.group_name == "0" 81 | assert event.behavior == ReturnBehavior.DISPATCH 82 | assert event.results == {a.id: 1, b.id: 2} 83 | return 3 84 
| 85 | result = await default_drive.invoke_event(start, None, {"test_ctx": 1}) 86 | print(result) 87 | 88 | 89 | @pytest.mark.asyncio 90 | async def test_multi_recv_cancel(): 91 | @default_drive.make_event 92 | async def start(event: EventInput, global_ctx): 93 | return None 94 | 95 | @default_drive.listen_group([start]) 96 | async def a(event: EventInput, global_ctx): 97 | raise asyncio.CancelledError() 98 | return 1 99 | 100 | @default_drive.listen_group([start]) 101 | async def b(event: EventInput, global_ctx): 102 | await asyncio.sleep(0.2) 103 | return 2 104 | 105 | @default_drive.listen_group([a, b]) 106 | async def c(event: EventInput, global_ctx): 107 | assert event.group_name == "0" 108 | assert event.behavior == ReturnBehavior.DISPATCH 109 | assert event.results == {a.id: 1, b.id: 2} 110 | return 3 111 | 112 | with pytest.raises(asyncio.CancelledError): 113 | result = await default_drive.invoke_event(start, None, {"test_ctx": 1}) 114 | 115 | 116 | @pytest.mark.asyncio 117 | async def test_multi_groups(): 118 | @default_drive.make_event 119 | async def a(event: EventInput, global_ctx): 120 | return 1 121 | 122 | @default_drive.listen_group([a]) 123 | async def b(event: EventInput, global_ctx): 124 | return 2 125 | 126 | call_c_count = 0 127 | 128 | @default_drive.listen_group([a]) 129 | @default_drive.listen_group([b, a]) 130 | async def c(event: EventInput, global_ctx): 131 | nonlocal call_c_count 132 | if call_c_count == 0: 133 | assert event.group_name == "1" 134 | assert event.behavior == ReturnBehavior.DISPATCH 135 | assert event.results == {a.id: 1} 136 | elif call_c_count == 1: 137 | assert event.group_name == "0" 138 | assert event.behavior == ReturnBehavior.DISPATCH 139 | assert event.results == {a.id: 1, b.id: 2} 140 | else: 141 | assert False, "c should only be called twice" 142 | call_c_count += 1 143 | return 3 144 | 145 | result = await default_drive.invoke_event(a, None, {"test_ctx": 1}) 146 | print(result) 147 | assert call_c_count == 2 148 


@pytest.mark.asyncio
async def test_loop():
    """A -> B -> A cycle: the second run of ``a`` sees ``b``'s result and raises.

    ``a`` is re-wrapped to also listen to ``b``, closing the loop.  The
    deliberate exception on the second call must propagate out of
    ``invoke_event`` so the loop terminates.
    """
    call_a_count = 0

    @default_drive.make_event
    async def a(event: EventInput, global_ctx):
        nonlocal call_a_count
        if call_a_count == 0:
            pass
        elif call_a_count == 1:
            # Second invocation arrives through the b -> a loop edge.
            assert event.group_name == "0"
            assert event.behavior == ReturnBehavior.DISPATCH
            assert event.results == {b.id: 2}
            raise DeliberateExcepion()
        call_a_count += 1
        return 1

    @default_drive.listen_group([a])
    async def b(event: EventInput, global_ctx):
        return 2

    # Close the cycle: a now also listens to b.
    a = default_drive.listen_group([b])(a)

    @default_drive.listen_group([a, b])
    async def c(event: EventInput, global_ctx):
        return 3

    with pytest.raises(DeliberateExcepion):
        await default_drive.invoke_event(a, None, {"test_ctx": 1})
    assert call_a_count == 1


@pytest.mark.asyncio
async def test_duplicate_events_not_send():
    """A self-looping event must not re-trigger a downstream group more than once.

    ``a`` listens to itself and runs three times. The third run aborts.
    ``c`` waits on ``[a, b]`` and is expected to fire exactly once.
    """
    call_a_count = 0

    @default_drive.make_event
    async def start(event: EventInput, global_ctx):
        pass

    @default_drive.listen_group([start])
    async def a(event: EventInput, global_ctx):
        nonlocal call_a_count
        if call_a_count <= 1:
            pass
        elif call_a_count == 2:
            return abort_this()
        call_a_count += 1
        return 1

    a = default_drive.listen_group([a])(a)  # self loop

    @default_drive.listen_group([start])
    async def b(event: EventInput, global_ctx):
        return 2

    call_c_count = 0

    @default_drive.listen_group([a, b])
    async def c(event: EventInput, global_ctx):
        nonlocal call_c_count
        assert call_c_count < 1, "c should only be called once"
        call_c_count += 1
        print("Call C")
        return 3

    r = await default_drive.invoke_event(start, None, {"test_ctx": 1})
    assert call_a_count == 2
    assert call_c_count == 1
    print({default_drive.get_event_from_id(k).repr_name: v for k, v in r.items()})


# -----------------------------------------------------------------------------
# /tests/test_types.py
# -----------------------------------------------------------------------------
from drive_flow.types import BaseEvent, EventGroup, EventInput, ReturnBehavior


def test_user_input():
    """``EventInput.from_input`` keeps the raw payload and tags it as INPUT."""
    fake_input = {"query": "Hello World"}
    a = EventInput.from_input(fake_input)
    assert a.results == fake_input
    assert a.behavior == ReturnBehavior.INPUT


def test_node_hash():
    """Event ids are derived from the wrapped function.

    The same function yields the same id; a different function yields a
    different id.
    """

    def mock_a():
        return 1

    def mock_b():
        return 2

    n1 = BaseEvent(mock_a)
    n2 = BaseEvent(mock_a)
    n3 = BaseEvent(mock_b)
    assert n1.id == n2.id
    assert n1.id != n3.id


def test_node_debug_print():
    """Smoke test: printing events that have (nested) parent groups must not raise."""

    def mock_a():
        return 1

    def mock_b():
        return 2

    n1 = BaseEvent(mock_a)
    g1 = EventGroup("1", "hash-xxxxx", {n1.id: n1})
    n2 = BaseEvent(mock_a, parent_groups={g1.hash(): g1})
    g2 = EventGroup("2", "hash-yyyy", {n1.id: n1, n2.id: n2})
    n3 = BaseEvent(mock_b, parent_groups={g1.hash(): g1, g2.hash(): g2})

    print(n1, n1.debug_string())
    print(n2, n2.debug_string())
    print(n3, n3.debug_string())


# -----------------------------------------------------------------------------
# /tests/test_utils.py
# -----------------------------------------------------------------------------
import pytest
from drive_flow import utils


def test_func_to_string():
    """``function_or_method_to_string`` returns a location header plus the dedented source.

    The header has the form ``<module>.l_<def line>`` for functions and
    ``<module>.l_<def line>.<Class>`` for bound methods.
    """

    def my_func(a, b: str, c: float = 3.14) -> bool:
        print(a, b, c)
        return False

    class Fool:
        def __call__(self, a, b, v):
            return a + b + v

    # Derive the line markers from the code objects instead of hard-coding
    # them, so edits above these definitions don't break the test.
    func_line = my_func.__code__.co_firstlineno
    method_line = Fool.__call__.__code__.co_firstlineno

    func_string = f"""tests.test_utils.l_{func_line}
def my_func(a, b: str, c: float = 3.14) -> bool:
    print(a, b, c)
    return False"""

    class_string = f"""tests.test_utils.l_{method_line}.Fool
def __call__(self, a, b, v):
    return a + b + v"""

    fool_inst = Fool()
    assert utils.function_or_method_to_string(my_func) == func_string
    assert utils.function_or_method_to_string(fool_inst.__call__) == class_string


def test_func_to_repr_string():
    """``function_or_method_to_repr`` returns ``<module>.l_<def line>.<qualname tail>``."""

    def my_func(a, b: str, c: float = 3.14) -> bool:
        print(a, b, c)
        return False

    class Fool:
        def __call__(self, a, b, v):
            return a + b + v

    # Same rationale as above: compute, don't hard-code, the def lines.
    func_line = my_func.__code__.co_firstlineno
    method_line = Fool.__call__.__code__.co_firstlineno

    func_string = f"tests.test_utils.l_{func_line}.my_func"
    class_string = f"tests.test_utils.l_{method_line}.Fool.__call__"

    fool_inst = Fool()
    assert utils.function_or_method_to_repr(my_func) == func_string
    assert utils.function_or_method_to_repr(fool_inst.__call__) == class_string


def test_any_to_repr_string():
    """Non-callable inputs are rejected with ``ValueError`` by both helpers."""
    with pytest.raises(ValueError):
        utils.function_or_method_to_repr(123)

    with pytest.raises(ValueError):
        utils.function_or_method_to_string(123)