├── tests ├── __init__.py └── test_flow.py ├── water ├── exceptions.py ├── __init__.py ├── config.py ├── types.py ├── task.py ├── context.py ├── flow.py ├── execution_engine.py └── server.py ├── cookbook ├── playground.py ├── loop_flow.py ├── branched_flow.py ├── parallel_flow.py ├── sequential_flow.py └── agno_flow.py ├── pyproject.toml ├── README.md ├── .gitignore └── LICENSE /tests/__init__.py: -------------------------------------------------------------------------------- 1 | """Tests for the Water framework.""" -------------------------------------------------------------------------------- /water/exceptions.py: -------------------------------------------------------------------------------- 1 | class WaterError(Exception): 2 | """Base exception for all Water framework errors.""" 3 | pass -------------------------------------------------------------------------------- /water/__init__.py: -------------------------------------------------------------------------------- 1 | from .flow import Flow 2 | from .task import create_task, Task 3 | from .server import FlowServer 4 | 5 | __all__ = ["Flow", "create_task", "Task", "FlowServer"] -------------------------------------------------------------------------------- /water/config.py: -------------------------------------------------------------------------------- 1 | """Configuration constants for the Water framework.""" 2 | 3 | class Config: 4 | """ 5 | Global configuration settings for the Water framework. 6 | 7 | Contains default values for execution parameters like loop iterations 8 | and timeout settings that can be referenced throughout the framework. 9 | """ 10 | 11 | # Loop settings 12 | DEFAULT_MAX_ITERATIONS: int = 100 13 | 14 | # Execution settings 15 | DEFAULT_TIMEOUT_SECONDS: int = 300 -------------------------------------------------------------------------------- /cookbook/playground.py: -------------------------------------------------------------------------------- 1 | from water import FlowServer 2 | from branched_flow import notification_flow 3 | from loop_flow import retry_flow 4 | from parallel_flow import send_notification_flow 5 | from sequential_flow import registration_flow 6 | 7 | # Create server with flows 8 | app = FlowServer(flows=[notification_flow, retry_flow, send_notification_flow, registration_flow]).get_app() 9 | 10 | if __name__ == "__main__": 11 | import uvicorn 12 | uvicorn.run("playground:app", host="0.0.0.0", port=8000, reload=True) -------------------------------------------------------------------------------- /water/types.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Dict, List, Union 2 | from typing_extensions import TypedDict 3 | 4 | # Forward declaration for ExecutionContext 5 | from typing import TYPE_CHECKING 6 | if TYPE_CHECKING: 7 | from water.context import ExecutionContext 8 | 9 | # Type aliases 10 | InputData = Dict[str, Any] 11 | OutputData = Dict[str, Any] 12 | ConditionFunction = Callable[[InputData], bool] 13 | 14 | # Updated task execution function signature to include context 15 | TaskExecuteFunction = Callable[[Dict[str, InputData], 'ExecutionContext'], OutputData] 16 | 17 | # TypedDict definitions for node structures 18 | class SequentialNode(TypedDict): 19 | type: str 20 | task: Any # Will be Task when we import it 21 | 22 | class ParallelNode(TypedDict): 23 | type: str 24 | tasks: List[Any] # Will be List[Task] 25 | 26 | class BranchCondition(TypedDict): 27 | condition: ConditionFunction 28 | task: Any 
# Will be Task 29 | 30 | class BranchNode(TypedDict): 31 | type: str 32 | branches: List[BranchCondition] 33 | 34 | class LoopNode(TypedDict): 35 | type: str 36 | condition: ConditionFunction 37 | task: Any # Will be Task 38 | max_iterations: int 39 | 40 | # Union type for all node types 41 | ExecutionNode = Union[SequentialNode, ParallelNode, BranchNode, LoopNode] 42 | ExecutionGraph = List[ExecutionNode] -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "water-ai" 7 | version = "0.1.0" 8 | description = "A multi-agent orchestration framework that works with any agent framework" 9 | authors = [ 10 | {name = "Manthan Gupta", email = "manthangupta109@gmail.com"} 11 | ] 12 | keywords = [ 13 | "water-ai", 14 | "multi-agent", 15 | "orchestration", 16 | "llm", 17 | "large-language-model", 18 | "framework", 19 | "agents", 20 | "ai-agents", 21 | "workflow", 22 | "pipeline", 23 | "langchain", 24 | "crewai", 25 | "agno", 26 | "autogen", 27 | "distributed", 28 | "async", 29 | "coordination", 30 | "agent-framework", 31 | "multi-agent-systems" 32 | ] 33 | classifiers = [ 34 | "Development Status :: 4 - Beta", 35 | "Intended Audience :: Developers", 36 | "Intended Audience :: Science/Research", 37 | "License :: OSI Approved :: Apache Software License", 38 | "Operating System :: OS Independent", 39 | "Programming Language :: Python :: 3", 40 | "Programming Language :: Python :: 3.8", 41 | "Programming Language :: Python :: 3.9", 42 | "Programming Language :: Python :: 3.10", 43 | "Programming Language :: Python :: 3.11", 44 | "Programming Language :: Python :: 3.12", 45 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 46 | "Topic :: Software Development :: Libraries :: Python Modules", 47 | "Topic :: System :: Distributed Computing", 48 | "Framework :: FastAPI", 49 | "Framework :: AsyncIO", 50 | ] 51 | readme = "README.md" 52 | requires-python = ">=3.8" 53 | dependencies = [ 54 | "pydantic>=2.0.0", 55 | "fastapi>=0.104.0", 56 | "uvicorn[standard]>=0.24.0", 57 | ] 58 | 59 | [project.optional-dependencies] 60 | dev = [ 61 | "pytest>=7.0.0", 62 | "pytest-asyncio>=0.21.0", 63 | ] 64 | 65 | [project.urls] 66 | Homepage = "https://github.com/manthanguptaa/water" 67 | Repository = "https://github.com/manthanguptaa/water" 68 | 69 | [tool.setuptools.packages.find] 70 | where = ["."] 71 | include = ["water*"] 72 | exclude = ["tests*", "cookbook*"] 73 | 74 | [tool.pytest.ini_options] 75 | asyncio_mode = "auto" 76 | testpaths = ["tests"] -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Water 2 | **A multi-agent orchestration framework that works with any agent framework.** 3 | 4 | [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) 5 | 6 | ## Overview 7 | Water is a production-ready orchestration framework that enables developers to build complex multi-agent systems without being locked into a specific agent framework. Whether you are using LangChain, CrewAI, Agno, or custom agents, Water provides the orchestration layer to coordinate and scale your multi-agent workflows. 
8 | 9 | ### Key Features 10 | 11 | - **Framework Agnostic** - Integrate any agent framework or custom implementation 12 | - **Flexible Workflows** - Orchestrate complex multi-agent interactions with simple Python 13 | - **Playground** - Run a lightweight FastAPI server with all the flows you define 14 | 15 | ## Quick Start 16 | ### Installation 17 | 18 | ```bash 19 | pip install water-ai 20 | ``` 21 | 22 | ### Basic Usage 23 | 24 | ```python 25 | import asyncio 26 | from water import Flow, create_task 27 | from pydantic import BaseModel 28 | 29 | class NumberInput(BaseModel): 30 | value: int 31 | 32 | class NumberOutput(BaseModel): 33 | result: int 34 | 35 | def add_five(params, context): 36 | return {"result": params["input_data"]["value"] + 5} 37 | 38 | math_task = create_task( 39 | id="math_task", 40 | description="Math task", 41 | input_schema=NumberInput, 42 | output_schema=NumberOutput, 43 | execute=add_five 44 | ) 45 | 46 | flow = Flow(id="my_flow", description="My flow").then(math_task).register() 47 | 48 | async def main(): 49 | result = await flow.run({"value": 10}) 50 | print(result) 51 | 52 | if __name__ == "__main__": 53 | asyncio.run(main()) 54 | ``` 55 | 56 | ## Contributing 57 | 58 | We welcome contributions from the community! 59 | 60 | - Bug reports and feature requests 61 | - Code contributions 62 | - Documentation improvements 63 | - Testing and quality assurance 64 | 65 | ## Roadmap 66 | 67 | - Storage layer to store flow sessions and task runs 68 | - Human-in-the-loop support 69 | - Retry mechanism for individual tasks 70 | 71 | ## License 72 | 73 | Water is licensed under the Apache License 2.0. See the [LICENSE](LICENSE) file for details. 74 | -------------------------------------------------------------------------------- /cookbook/loop_flow.py: -------------------------------------------------------------------------------- 1 | """ 2 | Loop Flow Example: Notification Retry with Backoff 3 | 4 | This example demonstrates a loop flow that retries a flaky notification 5 | service until it succeeds or the maximum number of attempts is reached. Shows how .loop() continues 6 | execution based on conditions and tracks state between iterations. 
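The loop keeps re-running the task while the condition returns True: here the
condition reads the should_retry flag, which attempt_notification sets to False
once the send succeeds or the attempt count reaches max_attempts.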
7 | """ 8 | 9 | from water import Flow, create_task 10 | from pydantic import BaseModel 11 | from typing import Dict, Any 12 | import asyncio 13 | import random 14 | import time 15 | 16 | # Data schemas 17 | class RetryState(BaseModel): 18 | user_id: str 19 | message: str 20 | max_attempts: int = 3 21 | attempt: int = 0 22 | success: bool = False 23 | error: str = "" 24 | should_retry: bool = True 25 | 26 | # Retry task (simulates flaky service) 27 | def attempt_notification(params: Dict[str, Any], context) -> Dict[str, Any]: 28 | """Attempt to send notification (simulates a 40% failure rate).""" 29 | data = params["input_data"] 30 | 31 | # Increment attempt counter 32 | current_attempt = data.get("attempt", 0) + 1 33 | 34 | # Add backoff delay for retries 35 | if current_attempt > 1: 36 | time.sleep(0.5 * current_attempt) 37 | 38 | # Simulate flaky service (60% success rate) 39 | success = random.random() < 0.6 40 | 41 | result = { 42 | "user_id": data["user_id"], 43 | "message": data["message"], 44 | "max_attempts": data["max_attempts"], 45 | "attempt": current_attempt, 46 | "success": success, 47 | "error": "" if success else f"Network timeout on attempt {current_attempt}", 48 | "should_retry": not success and current_attempt < data["max_attempts"] 49 | } 50 | 51 | return result 52 | 53 | # Create task 54 | notification_task = create_task( 55 | id="attempt_notification", 56 | description="Attempt notification with retry", 57 | input_schema=RetryState, 58 | output_schema=RetryState, 59 | execute=attempt_notification 60 | ) 61 | 62 | # Loop flow with retry logic 63 | retry_flow = Flow(id="notification_retry", description="Notification retry flow") 64 | retry_flow.loop( 65 | task=notification_task, 66 | condition=lambda result: result.get("should_retry", False) 67 | ).register() 68 | 69 | async def main(): 70 | """Run the loop retry flow example.""" 71 | 72 | request = { 73 | "user_id": "user_001", 74 | "message": "Welcome to Water Framework!", 75 | "max_attempts": 3, 76 | "attempt": 0, 77 | "success": False, 78 | "error": "", 79 | "should_retry": True 80 | } 81 | 82 | try: 83 | result = await retry_flow.run(request) 84 | 85 | if result.get("success"): 86 | print(f"Notification sent successfully on attempt {result['attempt']}") 87 | else: 88 | print(f"Notification failed after {result['attempt']} attempts") 89 | 90 | except Exception as e: 91 | print(f"Error: {e}") 92 | 93 | if __name__ == "__main__": 94 | asyncio.run(main()) -------------------------------------------------------------------------------- /water/task.py: -------------------------------------------------------------------------------- 1 | from typing import Type, Callable, Optional, Dict 2 | from pydantic import BaseModel 3 | from water.exceptions import WaterError 4 | import uuid 5 | 6 | from water.types import InputData, OutputData 7 | 8 | # Import here to avoid circular imports 9 | from typing import TYPE_CHECKING 10 | if TYPE_CHECKING: 11 | from water.context import ExecutionContext 12 | 13 | class Task: 14 | """ 15 | A single executable unit within a Water flow. 16 | 17 | Tasks define input/output schemas using Pydantic models and contain 18 | an execute function that processes data. Tasks can be synchronous 19 | or asynchronous. 
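
    Example (an illustrative sketch; EchoInput and EchoOutput stand in for any
    Pydantic models you define):

        echo_task = create_task(
            id="echo",
            description="Echo the input value",
            input_schema=EchoInput,
            output_schema=EchoOutput,
            execute=lambda params, context: {"value": params["input_data"]["value"]}
        )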
20 | """ 21 | 22 | def __init__( 23 | self, 24 | input_schema: Type[BaseModel], 25 | output_schema: Type[BaseModel], 26 | execute: Callable[[Dict[str, InputData], 'ExecutionContext'], OutputData], 27 | id: Optional[str] = None, 28 | description: Optional[str] = None 29 | ) -> None: 30 | """ 31 | Initialize a new Task. 32 | 33 | Args: 34 | input_schema: Pydantic BaseModel class defining expected input structure 35 | output_schema: Pydantic BaseModel class defining output structure 36 | execute: Function that processes input data and returns output 37 | id: Unique identifier for the task. Auto-generated if not provided. 38 | description: Human-readable description of the task's purpose 39 | 40 | Raises: 41 | WaterError: If schemas are not Pydantic BaseModel classes or execute is not callable 42 | """ 43 | self.id: str = id if id else f"task_{uuid.uuid4().hex[:8]}" 44 | self.description: str = description if description else f"Task {self.id}" 45 | 46 | # Validate schemas are Pydantic BaseModel classes 47 | if not input_schema or not (isinstance(input_schema, type) and issubclass(input_schema, BaseModel)): 48 | raise WaterError("input_schema must be a Pydantic BaseModel class") 49 | if not output_schema or not (isinstance(output_schema, type) and issubclass(output_schema, BaseModel)): 50 | raise WaterError("output_schema must be a Pydantic BaseModel class") 51 | if not execute or not callable(execute): 52 | raise WaterError("Task must have a callable execute function") 53 | 54 | self.input_schema = input_schema 55 | self.output_schema = output_schema 56 | self.execute = execute 57 | 58 | 59 | def create_task( 60 | id: Optional[str] = None, 61 | description: Optional[str] = None, 62 | input_schema: Optional[Type[BaseModel]] = None, 63 | output_schema: Optional[Type[BaseModel]] = None, 64 | execute: Optional[Callable[[Dict[str, InputData], 'ExecutionContext'], OutputData]] = None 65 | ) -> Task: 66 | """ 67 | Factory function to create a Task instance. 68 | """ 69 | return Task( 70 | input_schema=input_schema, 71 | output_schema=output_schema, 72 | execute=execute, 73 | id=id, 74 | description=description, 75 | ) -------------------------------------------------------------------------------- /cookbook/branched_flow.py: -------------------------------------------------------------------------------- 1 | """ 2 | Branched Flow Example: Conditional Notification Flow 3 | 4 | This example demonstrates a branched flow where notifications are sent 5 | based on user preferences. Shows how .branch() executes only the first 6 | matching condition, even when multiple conditions would match. 
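For example, a user with both email_enabled and sms_enabled set to True is
notified only via email, because the email condition is evaluated first.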
7 | """ 8 | 9 | from water import Flow, create_task 10 | from pydantic import BaseModel 11 | from typing import Dict, Any 12 | import asyncio 13 | 14 | # Data schemas 15 | class UserPreferences(BaseModel): 16 | user_id: str 17 | email_enabled: bool 18 | sms_enabled: bool 19 | whatsapp_enabled: bool 20 | 21 | class NotificationSent(BaseModel): 22 | user_id: str 23 | channel: str 24 | sent: bool 25 | 26 | # Notification tasks 27 | def send_email(params: Dict[str, Any], context) -> Dict[str, Any]: 28 | """Send email notification.""" 29 | user = params["input_data"] 30 | return { 31 | "user_id": user["user_id"], 32 | "channel": "email", 33 | "sent": True 34 | } 35 | 36 | def send_sms(params: Dict[str, Any], context) -> Dict[str, Any]: 37 | """Send SMS notification.""" 38 | user = params["input_data"] 39 | return { 40 | "user_id": user["user_id"], 41 | "channel": "sms", 42 | "sent": True 43 | } 44 | 45 | def send_whatsapp(params: Dict[str, Any], context) -> Dict[str, Any]: 46 | """Send WhatsApp notification.""" 47 | user = params["input_data"] 48 | return { 49 | "user_id": user["user_id"], 50 | "channel": "whatsapp", 51 | "sent": True 52 | } 53 | 54 | # Create tasks 55 | email_task = create_task( 56 | id="email", 57 | description="Send email", 58 | input_schema=UserPreferences, 59 | output_schema=NotificationSent, 60 | execute=send_email 61 | ) 62 | 63 | sms_task = create_task( 64 | id="sms", 65 | description="Send SMS", 66 | input_schema=UserPreferences, 67 | output_schema=NotificationSent, 68 | execute=send_sms 69 | ) 70 | 71 | whatsapp_task = create_task( 72 | id="whatsapp", 73 | description="Send WhatsApp", 74 | input_schema=UserPreferences, 75 | output_schema=NotificationSent, 76 | execute=send_whatsapp 77 | ) 78 | 79 | # Branched notification flow 80 | notification_flow = Flow(id="conditional_notifications", description="Conditional notification flow") 81 | notification_flow.branch([ 82 | (lambda data: data.get("email_enabled", False), email_task), 83 | (lambda data: data.get("sms_enabled", False), sms_task), 84 | (lambda data: data.get("whatsapp_enabled", False), whatsapp_task) 85 | ]).register() 86 | 87 | async def main(): 88 | """Run the branched notification flow example.""" 89 | 90 | user = { 91 | "user_id": "user_001", 92 | "email_enabled": True, 93 | "sms_enabled": True, 94 | "whatsapp_enabled": False 95 | } 96 | 97 | print(f"User preferences: email={user['email_enabled']}, sms={user['sms_enabled']}, whatsapp={user['whatsapp_enabled']}") 98 | 99 | try: 100 | result = await notification_flow.run(user) 101 | print(f"Notification sent via: {result['channel']}") 102 | print(f"First matching condition executed (email has priority)") 103 | 104 | except Exception as e: 105 | print(f"ERROR - {e}") 106 | 107 | if __name__ == "__main__": 108 | asyncio.run(main()) -------------------------------------------------------------------------------- /cookbook/parallel_flow.py: -------------------------------------------------------------------------------- 1 | """ 2 | Parallel Flow Example: Send Notification Flow 3 | 4 | This example demonstrates a parallel flow where multiple welcome notifications 5 | are sent simultaneously to a newly registered user. Shows how .parallel() executes 6 | independent tasks concurrently for improved performance. 
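The parallel step resolves to a dictionary keyed by task ID, e.g.
{"email": {...}, "sms": {...}, "whatsapp": {...}}, which the summary task then
receives as its input_data.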
7 | """ 8 | 9 | from water import Flow, create_task 10 | from pydantic import BaseModel 11 | from typing import Dict, Any 12 | import asyncio 13 | import uuid 14 | 15 | # Data schemas 16 | class UserData(BaseModel): 17 | user_id: str 18 | email: str 19 | phone: str 20 | first_name: str 21 | 22 | class NotificationResult(BaseModel): 23 | channel: str 24 | user_id: str 25 | message_id: str 26 | sent: bool 27 | delivery_time: float 28 | 29 | class ParallelResults(BaseModel): 30 | email: NotificationResult 31 | sms: NotificationResult 32 | whatsapp: NotificationResult 33 | 34 | class Summary(BaseModel): 35 | user_id: str 36 | notifications_sent: int 37 | total_time: float 38 | success: bool 39 | 40 | # Parallel Task 1: Send Email 41 | def send_email(params: Dict[str, Any], context) -> Dict[str, Any]: 42 | """Send welcome email to the user.""" 43 | user = params["input_data"] 44 | 45 | return { 46 | "channel": "email", 47 | "user_id": user["user_id"], 48 | "message_id": f"email_{uuid.uuid4().hex[:8]}", 49 | "sent": True, 50 | "delivery_time": 1.2 51 | } 52 | 53 | # Parallel Task 2: Send SMS 54 | def send_sms(params: Dict[str, Any], context) -> Dict[str, Any]: 55 | """Send welcome SMS to the user.""" 56 | user = params["input_data"] 57 | 58 | return { 59 | "channel": "sms", 60 | "user_id": user["user_id"], 61 | "message_id": f"sms_{uuid.uuid4().hex[:8]}", 62 | "sent": True, 63 | "delivery_time": 0.8 64 | } 65 | 66 | # Parallel Task 3: Send WhatsApp 67 | def send_whatsapp(params: Dict[str, Any], context) -> Dict[str, Any]: 68 | """Send welcome WhatsApp message to the user.""" 69 | user = params["input_data"] 70 | 71 | return { 72 | "channel": "whatsapp", 73 | "user_id": user["user_id"], 74 | "message_id": f"whatsapp_{uuid.uuid4().hex[:8]}", 75 | "sent": True, 76 | "delivery_time": 1.5 77 | } 78 | 79 | # Aggregation Task: Summarize Results 80 | def summarize_results(params: Dict[str, Any], context) -> Dict[str, Any]: 81 | """Aggregate results from all parallel notification tasks.""" 82 | results = params["input_data"] 83 | 84 | # Count successful notifications 85 | sent_count = sum(1 for result in results.values() if result["sent"]) 86 | 87 | # Get maximum time (since they ran in parallel) 88 | max_time = max(result["delivery_time"] for result in results.values()) 89 | 90 | return { 91 | "user_id": list(results.values())[0]["user_id"], 92 | "notifications_sent": sent_count, 93 | "total_time": max_time, 94 | "success": sent_count == 3 95 | } 96 | 97 | # Create tasks 98 | email_task = create_task( 99 | id="email", 100 | description="Send email notification", 101 | input_schema=UserData, 102 | output_schema=NotificationResult, 103 | execute=send_email 104 | ) 105 | 106 | sms_task = create_task( 107 | id="sms", 108 | description="Send SMS notification", 109 | input_schema=UserData, 110 | output_schema=NotificationResult, 111 | execute=send_sms 112 | ) 113 | 114 | whatsapp_task = create_task( 115 | id="whatsapp", 116 | description="Send WhatsApp notification", 117 | input_schema=UserData, 118 | output_schema=NotificationResult, 119 | execute=send_whatsapp 120 | ) 121 | 122 | summary_task = create_task( 123 | id="summary", 124 | description="Summarize notification results", 125 | input_schema=ParallelResults, 126 | output_schema=Summary, 127 | execute=summarize_results 128 | ) 129 | 130 | # Parallel send notification flow 131 | send_notification_flow = Flow(id="send_notifications", description="Parallel send notification flow") 132 | send_notification_flow.parallel([ 133 | email_task, 134 | sms_task, 
135 | whatsapp_task 136 | ]).then(summary_task).register() 137 | 138 | async def main(): 139 | """Run the send notification flow example.""" 140 | 141 | user = { 142 | "user_id": "user_abc123", 143 | "email": "manthan.gupta@water.ai", 144 | "phone": "+1234567890", 145 | "first_name": "Manthan" 146 | } 147 | 148 | try: 149 | result = await send_notification_flow.run(user) 150 | print(result) 151 | print("flow completed successfully!") 152 | except Exception as e: 153 | print(f"ERROR - {e}") 154 | 155 | if __name__ == "__main__": 156 | asyncio.run(main()) -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .envrc 133 | .venv 134 | env/ 135 | venv/ 136 | ENV/ 137 | env.bak/ 138 | venv.bak/ 139 | 140 | # Spyder project settings 141 | .spyderproject 142 | .spyproject 143 | 144 | # Rope project settings 145 | .ropeproject 146 | 147 | # mkdocs documentation 148 | /site 149 | 150 | # mypy 151 | .mypy_cache/ 152 | .dmypy.json 153 | dmypy.json 154 | 155 | # Pyre type checker 156 | .pyre/ 157 | 158 | # pytype static type analyzer 159 | .pytype/ 160 | 161 | # Cython debug symbols 162 | cython_debug/ 163 | 164 | # PyCharm 165 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 166 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 167 | # and can be added to the global gitignore or merged into this file. For a more nuclear 168 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 169 | #.idea/ 170 | 171 | # Abstra 172 | # Abstra is an AI-powered process automation framework. 173 | # Ignore directories containing user credentials, local state, and settings. 174 | # Learn more at https://abstra.io/docs 175 | .abstra/ 176 | 177 | # Visual Studio Code 178 | # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore 179 | # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore 180 | # and can be added to the global gitignore or merged into this file. However, if you prefer, 181 | # you could uncomment the following to ignore the enitre vscode folder 182 | # .vscode/ 183 | 184 | # Ruff stuff: 185 | .ruff_cache/ 186 | 187 | # PyPI configuration file 188 | .pypirc 189 | 190 | # Cursor 191 | # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to 192 | # exclude from AI features like autocomplete and code analysis. Recommended for sensitive data 193 | # refer to https://docs.cursor.com/context/ignore-files 194 | .cursorignore 195 | .cursorindexingignore -------------------------------------------------------------------------------- /cookbook/sequential_flow.py: -------------------------------------------------------------------------------- 1 | """ 2 | Sequential Flow Example: User Registration Pipeline 3 | 4 | This example demonstrates a sequential flow where each step must complete 5 | before the next can begin. Shows how .then() creates clear dependencies and 6 | how context preserves data across the pipeline. 
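Besides the output of the immediately preceding step, later steps can read any
earlier result from the execution context, e.g. context.get_task_output("validate").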
7 | """ 8 | 9 | from water import Flow, create_task 10 | from pydantic import BaseModel 11 | from typing import Dict, Any 12 | import asyncio 13 | import uuid 14 | 15 | # Data schemas 16 | class UserRequest(BaseModel): 17 | email: str 18 | password: str 19 | first_name: str 20 | 21 | class ValidationResult(BaseModel): 22 | email: str 23 | first_name: str 24 | is_valid: bool 25 | errors: list 26 | 27 | class AccountResult(BaseModel): 28 | email: str 29 | first_name: str 30 | user_id: str 31 | account_created: bool 32 | 33 | class ProfileResult(BaseModel): 34 | user_id: str 35 | profile_id: str 36 | profile_created: bool 37 | 38 | class RegistrationSummary(BaseModel): 39 | user_id: str 40 | profile_id: str 41 | email: str 42 | registration_complete: bool 43 | total_time: float 44 | 45 | # Step 1: Validate Input 46 | def validate_input(params: Dict[str, Any], context) -> Dict[str, Any]: 47 | """Validate user registration data.""" 48 | request = params["input_data"] 49 | errors = [] 50 | 51 | if "@" not in request["email"]: 52 | errors.append("Invalid email") 53 | if len(request["password"]) < 6: 54 | errors.append("Password too short") 55 | if not request["first_name"]: 56 | errors.append("Name required") 57 | 58 | return { 59 | "email": request["email"], 60 | "first_name": request["first_name"], 61 | "is_valid": len(errors) == 0, 62 | "errors": errors 63 | } 64 | 65 | # Step 2: Create Account 66 | def create_account(params: Dict[str, Any], context) -> Dict[str, Any]: 67 | """Create user account - depends on successful validation.""" 68 | current_data = params["input_data"] 69 | 70 | if not current_data["is_valid"]: 71 | return { 72 | "email": current_data["email"], 73 | "first_name": current_data["first_name"], 74 | "user_id": "", 75 | "account_created": False 76 | } 77 | 78 | # Generate user ID 79 | user_id = f"user_{uuid.uuid4().hex[:8]}" 80 | 81 | return { 82 | "email": current_data["email"], 83 | "first_name": current_data["first_name"], 84 | "user_id": user_id, 85 | "account_created": True 86 | } 87 | 88 | # Step 3: Setup Profile 89 | def setup_profile(params: Dict[str, Any], context) -> Dict[str, Any]: 90 | """Setup user profile - depends on account creation.""" 91 | current_data = params["input_data"] 92 | 93 | # Access account creation results from context 94 | account = context.get_task_output("account") 95 | 96 | if not current_data["account_created"]: 97 | return { 98 | "user_id": current_data["user_id"], 99 | "profile_id": "", 100 | "profile_created": False 101 | } 102 | 103 | # Generate profile ID 104 | profile_id = f"profile_{uuid.uuid4().hex[:8]}" 105 | 106 | return { 107 | "user_id": current_data["user_id"], 108 | "profile_id": profile_id, 109 | "profile_created": True 110 | } 111 | 112 | # Step 4: Complete Registration 113 | def complete_registration(params: Dict[str, Any], context) -> Dict[str, Any]: 114 | """Complete registration - depends on all previous steps.""" 115 | current_data = params["input_data"] 116 | 117 | # Access all previous step results from context 118 | validation = context.get_task_output("validate") 119 | 120 | return { 121 | "user_id": current_data["user_id"], 122 | "profile_id": current_data["profile_id"], 123 | "email": validation["email"], 124 | "registration_complete": current_data["profile_created"] 125 | } 126 | 127 | # Create tasks 128 | validate_task = create_task( 129 | id="validate", 130 | description="Validate registration input", 131 | input_schema=UserRequest, 132 | output_schema=ValidationResult, 133 | execute=validate_input 134 | ) 135 | 
136 | account_task = create_task( 137 | id="account", 138 | description="Create user account", 139 | input_schema=ValidationResult, 140 | output_schema=AccountResult, 141 | execute=create_account 142 | ) 143 | 144 | profile_task = create_task( 145 | id="profile", 146 | description="Setup user profile", 147 | input_schema=AccountResult, 148 | output_schema=ProfileResult, 149 | execute=setup_profile 150 | ) 151 | 152 | complete_task = create_task( 153 | id="complete", 154 | description="Complete registration", 155 | input_schema=ProfileResult, 156 | output_schema=RegistrationSummary, 157 | execute=complete_registration 158 | ) 159 | 160 | # Sequential registration flow - each step depends on the previous 161 | registration_flow = Flow(id="user_registration", description="Sequential user registration pipeline") 162 | registration_flow.then(validate_task)\ 163 | .then(account_task)\ 164 | .then(profile_task)\ 165 | .then(complete_task)\ 166 | .register() 167 | 168 | async def main(): 169 | """Run the sequential registration example.""" 170 | 171 | user_request = { 172 | "email": "manthan.gupta@water.ai", 173 | "password": "SecurePass123", 174 | "first_name": "Manthan" 175 | } 176 | 177 | try: 178 | result = await registration_flow.run(user_request) 179 | 180 | if result["registration_complete"]: 181 | print(result) 182 | print("flow completed successfully!") 183 | else: 184 | print("FAILED - Registration incomplete") 185 | except Exception as e: 186 | print(f"ERROR - {e}") 187 | 188 | if __name__ == "__main__": 189 | asyncio.run(main()) -------------------------------------------------------------------------------- /water/context.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional, List 2 | from datetime import datetime 3 | import uuid 4 | 5 | from water.types import OutputData 6 | 7 | class ExecutionContext: 8 | """ 9 | Execution context passed to every task containing metadata and execution state. 10 | 11 | The context provides access to flow metadata, execution timing, task outputs, 12 | and execution history. It enables tasks to access data from previous steps 13 | and maintain state throughout the flow execution. 14 | """ 15 | 16 | def __init__( 17 | self, 18 | flow_id: str, 19 | execution_id: Optional[str] = None, 20 | task_id: Optional[str] = None, 21 | step_number: int = 0, 22 | attempt_number: int = 1, 23 | flow_metadata: Optional[Dict[str, Any]] = None 24 | ) -> None: 25 | """ 26 | Initialize execution context. 27 | 28 | Args: 29 | flow_id: Unique identifier of the executing flow 30 | execution_id: Unique identifier for this execution instance 31 | task_id: Current task identifier 32 | step_number: Current step number in the execution 33 | attempt_number: Attempt number for retry scenarios 34 | flow_metadata: Metadata associated with the flow 35 | """ 36 | self.flow_id = flow_id 37 | self.execution_id = execution_id or f"exec_{uuid.uuid4().hex[:8]}" 38 | self.task_id = task_id 39 | self.step_number = step_number 40 | self.attempt_number = attempt_number 41 | self.flow_metadata = flow_metadata or {} 42 | 43 | # Timing information 44 | self.execution_start_time = datetime.utcnow() 45 | self.step_start_time = datetime.utcnow() 46 | 47 | # Task outputs history 48 | self._task_outputs: Dict[str, OutputData] = {} 49 | self._step_history: List[Dict[str, Any]] = [] 50 | 51 | def add_task_output(self, task_id: str, output: OutputData) -> None: 52 | """ 53 | Record the output of a completed task. 
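        Also appends a step record (step number, task ID, output, timestamp,
        and attempt number) to the internal step history.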
54 | 55 | Args: 56 | task_id: Identifier of the completed task 57 | output: Output data from the task 58 | """ 59 | self._task_outputs[task_id] = output 60 | 61 | step_info = { 62 | "step_number": self.step_number, 63 | "task_id": task_id, 64 | "output": output, 65 | "timestamp": datetime.utcnow().isoformat(), 66 | "attempt_number": self.attempt_number 67 | } 68 | self._step_history.append(step_info) 69 | 70 | def get_task_output(self, task_id: str) -> Optional[OutputData]: 71 | """ 72 | Get the output from a previously executed task. 73 | 74 | Args: 75 | task_id: Identifier of the task whose output to retrieve 76 | 77 | Returns: 78 | Task output data, or None if task hasn't executed 79 | """ 80 | return self._task_outputs.get(task_id) 81 | 82 | def get_all_task_outputs(self) -> Dict[str, OutputData]: 83 | """ 84 | Get all task outputs from this execution. 85 | 86 | Returns: 87 | Dictionary mapping task IDs to their output data 88 | """ 89 | return self._task_outputs.copy() 90 | 91 | def get_step_history(self) -> List[Dict[str, Any]]: 92 | """ 93 | Get the complete step execution history. 94 | 95 | Returns: 96 | List of step execution records with timestamps and outputs 97 | """ 98 | return self._step_history.copy() 99 | 100 | def create_child_context( 101 | self, 102 | task_id: str, 103 | step_number: Optional[int] = None, 104 | attempt_number: int = 1 105 | ) -> 'ExecutionContext': 106 | """ 107 | Create a new context for a child task execution. 108 | 109 | Inherits the current context state while updating task-specific fields. 110 | 111 | Args: 112 | task_id: Identifier for the child task 113 | step_number: Step number for the child execution 114 | attempt_number: Attempt number for retry scenarios 115 | 116 | Returns: 117 | New ExecutionContext instance for the child task 118 | """ 119 | child_context = ExecutionContext( 120 | flow_id=self.flow_id, 121 | execution_id=self.execution_id, 122 | task_id=task_id, 123 | step_number=step_number or (self.step_number + 1), 124 | attempt_number=attempt_number, 125 | flow_metadata=self.flow_metadata 126 | ) 127 | 128 | # Copy task outputs and history to child 129 | child_context._task_outputs = self._task_outputs.copy() 130 | child_context._step_history = self._step_history.copy() 131 | child_context.execution_start_time = self.execution_start_time 132 | 133 | return child_context 134 | 135 | def to_dict(self) -> Dict[str, Any]: 136 | """ 137 | Convert context to dictionary for serialization. 
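        Task outputs and step history are included as-is, so the result is only
        fully serializable if the individual task outputs are.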
138 | 139 | Returns: 140 | Dictionary representation of the execution context 141 | """ 142 | return { 143 | "flow_id": self.flow_id, 144 | "execution_id": self.execution_id, 145 | "task_id": self.task_id, 146 | "step_number": self.step_number, 147 | "attempt_number": self.attempt_number, 148 | "flow_metadata": self.flow_metadata, 149 | "execution_start_time": self.execution_start_time.isoformat(), 150 | "step_start_time": self.step_start_time.isoformat(), 151 | "task_outputs": self._task_outputs, 152 | "step_history": self._step_history 153 | } 154 | 155 | def __repr__(self) -> str: 156 | """String representation of the execution context.""" 157 | return ( 158 | f"ExecutionContext(flow_id='{self.flow_id}', " 159 | f"execution_id='{self.execution_id}', " 160 | f"task_id='{self.task_id}', " 161 | f"step={self.step_number}, " 162 | f"attempt={self.attempt_number})" 163 | ) -------------------------------------------------------------------------------- /water/flow.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List, Optional, Tuple, Dict 2 | import inspect 3 | import uuid 4 | 5 | from water.execution_engine import ExecutionEngine, NodeType 6 | from water.types import ( 7 | InputData, 8 | OutputData, 9 | ConditionFunction, 10 | ExecutionNode 11 | ) 12 | 13 | class Flow: 14 | """ 15 | A workflow orchestrator that allows building and executing complex data processing pipelines. 16 | 17 | Flows support sequential execution, parallel processing, conditional branching, and loops. 18 | All flows must be registered before execution. 19 | """ 20 | 21 | def __init__( 22 | self, 23 | id: Optional[str] = None, 24 | description: Optional[str] = None 25 | ) -> None: 26 | """ 27 | Initialize a new Flow. 28 | 29 | Args: 30 | id: Unique identifier for the flow. Auto-generated if not provided. 31 | description: Human-readable description of the flow's purpose. 32 | """ 33 | self.id: str = id if id else f"flow_{uuid.uuid4().hex[:8]}" 34 | self.description: str = description if description else f"Flow {self.id}" 35 | self._tasks: List[ExecutionNode] = [] 36 | self._registered: bool = False 37 | self.metadata: Dict[str, Any] = {} 38 | 39 | def _validate_registration_state(self) -> None: 40 | """Ensure flow is not registered when adding tasks.""" 41 | if self._registered: 42 | raise RuntimeError("Cannot add tasks after registration") 43 | 44 | def _validate_task(self, task: Any) -> None: 45 | """Validate that a task is not None.""" 46 | if task is None: 47 | raise ValueError("Task cannot be None") 48 | 49 | def _validate_condition(self, condition: ConditionFunction) -> None: 50 | """Validate that a condition function is not async.""" 51 | if inspect.iscoroutinefunction(condition): 52 | raise ValueError("Branch conditions cannot be async functions") 53 | 54 | def _validate_loop_condition(self, condition: ConditionFunction) -> None: 55 | """Validate that a loop condition function is not async.""" 56 | if inspect.iscoroutinefunction(condition): 57 | raise ValueError("Loop conditions cannot be async functions") 58 | 59 | def set_metadata(self, key: str, value: Any) -> 'Flow': 60 | """ 61 | Set metadata for this flow. 62 | 63 | Args: 64 | key: The metadata key 65 | value: The metadata value 66 | 67 | Returns: 68 | Self for method chaining 69 | """ 70 | self.metadata[key] = value 71 | return self 72 | 73 | def then(self, task: Any) -> 'Flow': 74 | """ 75 | Add a task to execute sequentially. 
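        The task receives the output of the previous step as its input_data.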
76 | 77 | Args: 78 | task: The task to execute 79 | 80 | Returns: 81 | Self for method chaining 82 | 83 | Raises: 84 | RuntimeError: If flow is already registered 85 | ValueError: If task is None 86 | """ 87 | self._validate_registration_state() 88 | self._validate_task(task) 89 | 90 | node: ExecutionNode = {"type": NodeType.SEQUENTIAL.value, "task": task} 91 | self._tasks.append(node) 92 | return self 93 | 94 | def parallel(self, tasks: List[Any]) -> 'Flow': 95 | """ 96 | Add tasks to execute in parallel. 97 | 98 | Args: 99 | tasks: List of tasks to execute concurrently 100 | 101 | Returns: 102 | Self for method chaining 103 | 104 | Raises: 105 | RuntimeError: If flow is already registered 106 | ValueError: If task list is empty or contains None values 107 | """ 108 | self._validate_registration_state() 109 | if not tasks: 110 | raise ValueError("Parallel task list cannot be empty") 111 | 112 | for task in tasks: 113 | self._validate_task(task) 114 | 115 | node: ExecutionNode = { 116 | "type": NodeType.PARALLEL.value, 117 | "tasks": list(tasks) 118 | } 119 | self._tasks.append(node) 120 | return self 121 | 122 | def branch(self, branches: List[Tuple[ConditionFunction, Any]]) -> 'Flow': 123 | """ 124 | Add conditional branching logic. 125 | 126 | Executes the first task whose condition returns True. 127 | If no conditions match, data passes through unchanged. 128 | 129 | Args: 130 | branches: List of (condition_function, task) tuples 131 | 132 | Returns: 133 | Self for method chaining 134 | 135 | Raises: 136 | RuntimeError: If flow is already registered 137 | ValueError: If branch list is empty, task is None, or condition is async 138 | """ 139 | self._validate_registration_state() 140 | if not branches: 141 | raise ValueError("Branch list cannot be empty") 142 | 143 | for condition, task in branches: 144 | self._validate_task(task) 145 | self._validate_condition(condition) 146 | 147 | node: ExecutionNode = { 148 | "type": NodeType.BRANCH.value, 149 | "branches": [{"condition": cond, "task": task} for cond, task in branches] 150 | } 151 | self._tasks.append(node) 152 | return self 153 | 154 | def loop( 155 | self, 156 | condition: ConditionFunction, 157 | task: Any, 158 | max_iterations: int = 100 159 | ) -> 'Flow': 160 | """ 161 | Execute a task repeatedly while a condition is true. 162 | 163 | Args: 164 | condition: Function that returns True to continue looping 165 | task: Task to execute on each iteration 166 | max_iterations: Maximum number of iterations to prevent infinite loops 167 | 168 | Returns: 169 | Self for method chaining 170 | 171 | Raises: 172 | RuntimeError: If flow is already registered 173 | ValueError: If task is None or condition is async 174 | """ 175 | self._validate_registration_state() 176 | self._validate_task(task) 177 | self._validate_loop_condition(condition) 178 | 179 | node: ExecutionNode = { 180 | "type": NodeType.LOOP.value, 181 | "condition": condition, 182 | "task": task, 183 | "max_iterations": max_iterations 184 | } 185 | self._tasks.append(node) 186 | return self 187 | 188 | def register(self) -> 'Flow': 189 | """ 190 | Register the flow for execution. 191 | 192 | Must be called before running the flow. 193 | Once registered, no more tasks can be added. 
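        Attempting to add tasks after registration raises a RuntimeError.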
194 | 195 | Returns: 196 | Self for method chaining 197 | 198 | Raises: 199 | ValueError: If flow has no tasks 200 | """ 201 | if not self._tasks: 202 | raise ValueError("Flow must have at least one task") 203 | self._registered = True 204 | return self 205 | 206 | async def run(self, input_data: InputData) -> OutputData: 207 | """ 208 | Execute the flow with the provided input data. 209 | 210 | Args: 211 | input_data: Input data dictionary to process 212 | 213 | Returns: 214 | The final output data after all tasks complete 215 | 216 | Raises: 217 | RuntimeError: If flow is not registered 218 | """ 219 | if not self._registered: 220 | raise RuntimeError("Flow must be registered before running") 221 | 222 | return await ExecutionEngine.run( 223 | self._tasks, 224 | input_data, 225 | flow_id=self.id, 226 | flow_metadata=self.metadata 227 | ) 228 | -------------------------------------------------------------------------------- /water/execution_engine.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import asyncio 3 | import logging 4 | from enum import Enum 5 | from typing import Any, Dict, List, Optional 6 | from datetime import datetime 7 | 8 | from water.types import ( 9 | ExecutionGraph, 10 | ExecutionNode, 11 | InputData, 12 | OutputData, 13 | SequentialNode, 14 | ParallelNode, 15 | BranchNode, 16 | LoopNode 17 | ) 18 | from water.context import ExecutionContext 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | class NodeType(Enum): 23 | """Enumeration of supported execution node types.""" 24 | SEQUENTIAL = "sequential" 25 | PARALLEL = "parallel" 26 | BRANCH = "branch" 27 | LOOP = "loop" 28 | 29 | class ExecutionEngine: 30 | """ 31 | Core execution engine for Water flows. 32 | 33 | Orchestrates the execution of different node types including sequential tasks, 34 | parallel execution, conditional branching, and loops. 35 | """ 36 | 37 | @staticmethod 38 | async def run( 39 | execution_graph: ExecutionGraph, 40 | input_data: InputData, 41 | flow_id: str, 42 | flow_metadata: Optional[Dict[str, Any]] = None 43 | ) -> OutputData: 44 | """ 45 | Execute a complete flow execution graph. 46 | 47 | Args: 48 | execution_graph: List of execution nodes to process 49 | input_data: Initial input data 50 | flow_id: Unique identifier for the flow execution 51 | flow_metadata: Optional metadata for the flow 52 | 53 | Returns: 54 | Final output data after all nodes are executed 55 | """ 56 | context = ExecutionContext( 57 | flow_id=flow_id, 58 | flow_metadata=flow_metadata or {} 59 | ) 60 | 61 | data: OutputData = input_data 62 | 63 | for node in execution_graph: 64 | data = await ExecutionEngine._execute_node(node, data, context) 65 | 66 | return data 67 | 68 | @staticmethod 69 | async def _execute_node( 70 | node: ExecutionNode, 71 | data: InputData, 72 | context: ExecutionContext 73 | ) -> OutputData: 74 | """ 75 | Route execution to the appropriate node type handler. 
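        Dispatch is table-driven: node["type"] is parsed into a NodeType and
        mapped to the corresponding handler coroutine.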
76 | 77 | Args: 78 | node: The execution node to process 79 | data: Input data for the node 80 | context: Execution context 81 | 82 | Returns: 83 | Output data from the node execution 84 | 85 | Raises: 86 | ValueError: If node type is unknown or unhandled 87 | """ 88 | try: 89 | node_type = NodeType(node["type"]) 90 | except ValueError: 91 | raise ValueError(f"Unknown node type: {node['type']}") 92 | 93 | handlers = { 94 | NodeType.SEQUENTIAL: ExecutionEngine._execute_sequential, 95 | NodeType.PARALLEL: ExecutionEngine._execute_parallel, 96 | NodeType.BRANCH: ExecutionEngine._execute_branch, 97 | NodeType.LOOP: ExecutionEngine._execute_loop, 98 | } 99 | 100 | handler = handlers.get(node_type) 101 | if not handler: 102 | raise ValueError(f"Unhandled node type: {node_type}") 103 | 104 | return await handler(node, data, context) 105 | 106 | @staticmethod 107 | async def _execute_task(task: Any, data: InputData, context: ExecutionContext) -> OutputData: 108 | """ 109 | Execute a single task, handling both sync and async functions. 110 | 111 | Args: 112 | task: The task to execute 113 | data: Input data for the task 114 | context: Execution context 115 | 116 | Returns: 117 | Output data from the task execution 118 | """ 119 | params: Dict[str, InputData] = {"input_data": data} 120 | 121 | # Update context with current task info 122 | context.task_id = task.id 123 | context.step_start_time = datetime.utcnow() 124 | context.step_number += 1 125 | 126 | # Execute the task 127 | if inspect.iscoroutinefunction(task.execute): 128 | result = await task.execute(params, context) 129 | else: 130 | result = task.execute(params, context) 131 | 132 | # Store the task result in context for future tasks to access 133 | context.add_task_output(task.id, result) 134 | 135 | return result 136 | 137 | @staticmethod 138 | async def _execute_sequential( 139 | node: SequentialNode, 140 | data: InputData, 141 | context: ExecutionContext 142 | ) -> OutputData: 143 | """ 144 | Execute a single task sequentially. 145 | 146 | Args: 147 | node: Sequential execution node 148 | data: Input data 149 | context: Execution context 150 | 151 | Returns: 152 | Task execution result 153 | """ 154 | task = node["task"] 155 | return await ExecutionEngine._execute_task(task, data, context) 156 | 157 | @staticmethod 158 | async def _execute_parallel( 159 | node: ParallelNode, 160 | data: InputData, 161 | context: ExecutionContext 162 | ) -> OutputData: 163 | """ 164 | Execute multiple tasks in parallel. 
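        All tasks receive the same input data and run concurrently via
        asyncio.gather; the combined result dictionary is also stored in the
        context under the special "_parallel_results" key.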
165 | 166 | Args: 167 | node: Parallel execution node 168 | data: Input data (shared by all tasks) 169 | context: Execution context 170 | 171 | Returns: 172 | Dictionary mapping task IDs to their results 173 | """ 174 | tasks = node["tasks"] 175 | 176 | # Create task execution coroutines 177 | async def execute_single_task(task): 178 | return await ExecutionEngine._execute_task(task, data, context) 179 | 180 | coroutines = [execute_single_task(task) for task in tasks] 181 | 182 | # Execute all tasks in parallel 183 | results: List[OutputData] = await asyncio.gather(*coroutines) 184 | 185 | # Organize results by task ID 186 | parallel_results = {task.id: result for task, result in zip(tasks, results)} 187 | 188 | # Store individual parallel results in context (they were already stored by _execute_task) 189 | # but also store the combined parallel results as a special entry 190 | context.add_task_output("_parallel_results", parallel_results) 191 | 192 | return parallel_results 193 | 194 | @staticmethod 195 | async def _execute_branch( 196 | node: BranchNode, 197 | data: InputData, 198 | context: ExecutionContext 199 | ) -> OutputData: 200 | """ 201 | Execute conditional branching - run the first task whose condition matches. 202 | 203 | Args: 204 | node: Branch execution node 205 | data: Input data 206 | context: Execution context 207 | 208 | Returns: 209 | Result from the executed branch, or input data if no conditions match 210 | """ 211 | branches = node["branches"] 212 | 213 | for branch in branches: 214 | condition = branch["condition"] 215 | 216 | if condition(data): 217 | task = branch["task"] 218 | return await ExecutionEngine._execute_task(task, data, context) 219 | 220 | # If no condition matched, return data unchanged 221 | return data 222 | 223 | @staticmethod 224 | async def _execute_loop( 225 | node: LoopNode, 226 | data: InputData, 227 | context: ExecutionContext 228 | ) -> OutputData: 229 | """ 230 | Execute a task repeatedly while condition is true. 
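        The condition is evaluated before each iteration, and a warning is
        logged if the loop stops because max_iterations was reached.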
231 | 232 | Args: 233 | node: Loop execution node 234 | data: Initial input data 235 | context: Execution context 236 | 237 | Returns: 238 | Final data after loop completion 239 | """ 240 | condition = node["condition"] 241 | task = node["task"] 242 | max_iterations: int = node.get("max_iterations", 100) 243 | 244 | iteration_count: int = 0 245 | current_data: OutputData = data 246 | 247 | while iteration_count < max_iterations: 248 | if not condition(current_data): 249 | break 250 | 251 | current_data = await ExecutionEngine._execute_task(task, current_data, context) 252 | iteration_count += 1 253 | 254 | if iteration_count >= max_iterations: 255 | logger.warning(f"Loop reached maximum iterations ({max_iterations}) for flow {context.flow_id}") 256 | 257 | return current_data -------------------------------------------------------------------------------- /water/server.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Any, Optional, Type 2 | from fastapi import FastAPI, HTTPException 3 | from pydantic import BaseModel 4 | from datetime import datetime 5 | 6 | from water.flow import Flow 7 | 8 | 9 | class RunFlowRequest(BaseModel): 10 | """Request model for flow execution.""" 11 | input_data: Dict[str, Any] 12 | 13 | class RunFlowResponse(BaseModel): 14 | """Response model for flow execution results.""" 15 | flow_id: str 16 | status: str 17 | result: Dict[str, Any] 18 | execution_time_ms: float 19 | timestamp: datetime 20 | 21 | class TaskInfo(BaseModel): 22 | """Information about a task within a flow.""" 23 | id: str 24 | description: str 25 | type: str # "sequential", "parallel", "branch", "loop" 26 | input_schema: Optional[Dict[str, str]] = None 27 | output_schema: Optional[Dict[str, str]] = None 28 | 29 | class FlowSummary(BaseModel): 30 | """Summary information about a flow.""" 31 | id: str 32 | description: str 33 | tasks: List[TaskInfo] 34 | 35 | class FlowDetail(BaseModel): 36 | """Detailed information about a flow.""" 37 | id: str 38 | description: str 39 | metadata: Dict[str, Any] 40 | tasks: List[TaskInfo] 41 | 42 | class FlowsListResponse(BaseModel): 43 | """Response model for listing flows.""" 44 | flows: List[FlowSummary] 45 | 46 | class FlowServer: 47 | """ 48 | FastAPI server for hosting Water flows. 49 | 50 | Provides REST endpoints for discovering and executing flows. 51 | 52 | Example: 53 | flows = [flow1, flow2, flow3] 54 | app = FlowServer(flows=flows).get_app() 55 | 56 | if __name__ == "__main__": 57 | import uvicorn 58 | uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True) 59 | """ 60 | 61 | def __init__(self, flows: List[Flow]) -> None: 62 | """ 63 | Initialize FlowServer with a list of flows. 64 | 65 | Args: 66 | flows: List of registered Flow instances 67 | 68 | Raises: 69 | ValueError: If flows contain duplicates or unregistered flows 70 | """ 71 | self.flows: Dict[str, Flow] = {} 72 | for flow in flows: 73 | if flow.id in self.flows: 74 | raise ValueError(f"Duplicate flow ID: {flow.id}") 75 | if not flow._registered: 76 | raise ValueError(f"Flow {flow.id} must be registered before adding to server") 77 | self.flows[flow.id] = flow 78 | 79 | def _serialize_schema(self, schema_class: Type[BaseModel]) -> Optional[Dict[str, str]]: 80 | """ 81 | Convert Pydantic model to a simple field:type mapping. 
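        For example, a model with fields user_id: str and sent: bool maps to
        {"user_id": "string", "sent": "boolean"}.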
82 | 83 | Args: 84 | schema_class: Pydantic BaseModel class 85 | 86 | Returns: 87 | Dictionary mapping field names to simplified type strings, or None 88 | """ 89 | if not schema_class: 90 | return None 91 | 92 | try: 93 | schema_dict = {} 94 | for field_name, field_info in schema_class.model_fields.items(): 95 | # Simple type mapping - much simpler than before 96 | field_type = field_info.annotation 97 | type_name = getattr(field_type, '__name__', str(field_type)) 98 | 99 | # Basic type cleanup 100 | if 'int' in type_name.lower(): 101 | schema_dict[field_name] = "int" 102 | elif 'float' in type_name.lower(): 103 | schema_dict[field_name] = "float" 104 | elif 'str' in type_name.lower(): 105 | schema_dict[field_name] = "string" 106 | elif 'bool' in type_name.lower(): 107 | schema_dict[field_name] = "boolean" 108 | elif 'list' in type_name.lower(): 109 | schema_dict[field_name] = "array" 110 | elif 'dict' in type_name.lower(): 111 | schema_dict[field_name] = "object" 112 | else: 113 | schema_dict[field_name] = type_name 114 | 115 | return schema_dict 116 | 117 | except Exception: 118 | return {"error": "Could not parse schema"} 119 | 120 | def _extract_task_info(self, execution_nodes: List[Dict[str, Any]]) -> List[TaskInfo]: 121 | """ 122 | Extract task information from execution nodes. 123 | 124 | Args: 125 | execution_nodes: List of execution node dictionaries 126 | 127 | Returns: 128 | List of TaskInfo objects 129 | """ 130 | task_infos = [] 131 | 132 | for node in execution_nodes: 133 | node_type = node["type"] 134 | 135 | if node_type == "sequential": 136 | task = node["task"] 137 | task_infos.append(TaskInfo( 138 | id=task.id, 139 | description=task.description, 140 | type="sequential", 141 | input_schema=self._serialize_schema(task.input_schema), 142 | output_schema=self._serialize_schema(task.output_schema) 143 | )) 144 | 145 | elif node_type == "parallel": 146 | for task in node["tasks"]: 147 | task_infos.append(TaskInfo( 148 | id=task.id, 149 | description=task.description, 150 | type="parallel", 151 | input_schema=self._serialize_schema(task.input_schema), 152 | output_schema=self._serialize_schema(task.output_schema) 153 | )) 154 | 155 | elif node_type == "branch": 156 | for branch in node["branches"]: 157 | task = branch["task"] 158 | task_infos.append(TaskInfo( 159 | id=task.id, 160 | description=task.description, 161 | type="branch", 162 | input_schema=self._serialize_schema(task.input_schema), 163 | output_schema=self._serialize_schema(task.output_schema) 164 | )) 165 | 166 | elif node_type == "loop": 167 | task = node["task"] 168 | task_infos.append(TaskInfo( 169 | id=task.id, 170 | description=task.description, 171 | type="loop", 172 | input_schema=self._serialize_schema(task.input_schema), 173 | output_schema=self._serialize_schema(task.output_schema) 174 | )) 175 | 176 | return task_infos 177 | 178 | def get_app(self) -> FastAPI: 179 | """ 180 | Create and configure the FastAPI application. 
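        The app exposes GET /health, GET /flows, GET /flows/{flow_id}, and
        POST /flows/{flow_id}/run.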
181 | 182 | Returns: 183 | Configured FastAPI application instance 184 | """ 185 | app = FastAPI( 186 | title="Water Flows API", 187 | description="REST API for executing Water framework workflows", 188 | version="1.0.0" 189 | ) 190 | 191 | # Add CORS middleware for development 192 | from fastapi.middleware.cors import CORSMiddleware 193 | app.add_middleware( 194 | CORSMiddleware, 195 | allow_origins=["*"], 196 | allow_credentials=True, 197 | allow_methods=["*"], 198 | allow_headers=["*"], 199 | ) 200 | 201 | @app.get("/health") 202 | async def health_check(): 203 | """Health check endpoint.""" 204 | return { 205 | "status": "healthy", 206 | "flows_count": len(self.flows), 207 | "timestamp": datetime.utcnow().isoformat() 208 | } 209 | 210 | @app.get("/flows", response_model=FlowsListResponse) 211 | async def list_flows(): 212 | """Get list of all available flows.""" 213 | flows_summary = [] 214 | for flow in self.flows.values(): 215 | task_infos = self._extract_task_info(flow._tasks) 216 | flows_summary.append(FlowSummary( 217 | id=flow.id, 218 | description=flow.description, 219 | tasks=task_infos, 220 | )) 221 | 222 | return FlowsListResponse(flows=flows_summary) 223 | 224 | @app.get("/flows/{flow_id}", response_model=FlowDetail) 225 | async def get_flow_details(flow_id: str): 226 | """Get detailed information about a specific flow.""" 227 | if flow_id not in self.flows: 228 | raise HTTPException(status_code=404, detail=f"Flow '{flow_id}' not found") 229 | 230 | flow = self.flows[flow_id] 231 | task_infos = self._extract_task_info(flow._tasks) 232 | 233 | return FlowDetail( 234 | id=flow.id, 235 | description=flow.description, 236 | metadata=flow.metadata, 237 | tasks=task_infos, 238 | ) 239 | 240 | @app.post("/flows/{flow_id}/run", response_model=RunFlowResponse) 241 | async def run_flow(flow_id: str, request: RunFlowRequest): 242 | """Execute a specific flow with input data.""" 243 | if flow_id not in self.flows: 244 | raise HTTPException(status_code=404, detail=f"Flow '{flow_id}' not found") 245 | 246 | flow = self.flows[flow_id] 247 | 248 | try: 249 | start_time = datetime.utcnow() 250 | result = await flow.run(request.input_data) 251 | end_time = datetime.utcnow() 252 | 253 | execution_time_ms = round((end_time - start_time).total_seconds() * 1000, 4) 254 | 255 | return RunFlowResponse( 256 | flow_id=flow_id, 257 | status="success", 258 | result=result, 259 | execution_time_ms=execution_time_ms, 260 | timestamp=end_time 261 | ) 262 | 263 | except Exception as e: 264 | raise HTTPException( 265 | status_code=500, 266 | detail={ 267 | "error": str(e), 268 | "flow_id": flow_id 269 | } 270 | ) 271 | 272 | return app -------------------------------------------------------------------------------- /cookbook/agno_flow.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is a blog post generator flow using the Agno framework. 
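It chains three Agno agents (searcher, article scraper, writer) into a
sequential Water flow: search -> scrape -> write.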
3 | 4 | pip install agno water-ai 5 | """ 6 | 7 | import json 8 | from textwrap import dedent 9 | from typing import Dict, Optional, Any 10 | import asyncio 11 | 12 | from water import Flow, create_task 13 | from agno.agent import Agent 14 | from agno.models.openai import OpenAIChat 15 | from agno.tools.duckduckgo import DuckDuckGoTools 16 | from agno.tools.newspaper4k import Newspaper4kTools 17 | from agno.utils.log import logger 18 | from agno.workflow import RunResponse 19 | from pydantic import BaseModel, Field 20 | from rich.prompt import Prompt 21 | 22 | 23 | class NewsArticle(BaseModel): 24 | title: str = Field(..., description="Title of the article.") 25 | url: str = Field(..., description="Link to the article.") 26 | summary: Optional[str] = Field( 27 | ..., description="Summary of the article if available." 28 | ) 29 | 30 | 31 | class SearchResults(BaseModel): 32 | articles: list[NewsArticle] 33 | 34 | 35 | class ScrapedArticle(BaseModel): 36 | title: str = Field(..., description="Title of the article.") 37 | url: str = Field(..., description="Link to the article.") 38 | summary: Optional[str] = Field( 39 | ..., description="Summary of the article if available." 40 | ) 41 | content: Optional[str] = Field( 42 | ..., 43 | description="Full article content in markdown format. None if content is unavailable.", 44 | ) 45 | 46 | searcher: Agent = Agent( 47 | model=OpenAIChat(id="gpt-4o-mini"), 48 | tools=[DuckDuckGoTools()], 49 | description=dedent("""\ 50 | You are BlogResearch-X, an elite research assistant specializing in discovering 51 | high-quality sources for compelling blog content. Your expertise includes: 52 | 53 | - Finding authoritative and trending sources 54 | - Evaluating content credibility and relevance 55 | - Identifying diverse perspectives and expert opinions 56 | - Discovering unique angles and insights 57 | - Ensuring comprehensive topic coverage\ 58 | """), 59 | instructions=dedent("""\ 60 | 1. Search Strategy 🔍 61 | - Find 10-15 relevant sources and select the 3 best ones 62 | - Prioritize recent, authoritative content 63 | - Look for unique angles and expert insights 64 | 2. Source Evaluation 📊 65 | - Verify source credibility and expertise 66 | - Check publication dates for timeliness 67 | - Assess content depth and uniqueness 68 | 3. Diversity of Perspectives 🌐 69 | - Include different viewpoints 70 | - Gather both mainstream and expert opinions 71 | - Find supporting data and statistics\ 72 | """), 73 | response_model=SearchResults, 74 | ) 75 | 76 | # Content Scraper: Extracts and processes article content 77 | article_scraper: Agent = Agent( 78 | model=OpenAIChat(id="gpt-4o-mini"), 79 | tools=[Newspaper4kTools()], 80 | description=dedent("""\ 81 | You are ContentBot-X, a specialist in extracting and processing digital content 82 | for blog creation. Your expertise includes: 83 | 84 | - Efficient content extraction 85 | - Smart formatting and structuring 86 | - Key information identification 87 | - Quote and statistic preservation 88 | - Maintaining source attribution\ 89 | """), 90 | instructions=dedent("""\ 91 | 1. Content Extraction 📑 92 | - Extract content from the article 93 | - Preserve important quotes and statistics 94 | - Maintain proper attribution 95 | - Handle paywalls gracefully 96 | 2. Content Processing 🔄 97 | - Format text in clean markdown 98 | - Preserve key information 99 | - Structure content logically 100 | 3. 
Quality Control ✅ 101 | - Verify content relevance 102 | - Ensure accurate extraction 103 | - Maintain readability\ 104 | """), 105 | response_model=ScrapedArticle, 106 | ) 107 | 108 | # Content Writer Agent: Crafts engaging blog posts from research 109 | writer: Agent = Agent( 110 | model=OpenAIChat(id="gpt-4o"), 111 | description=dedent("""\ 112 | You are BlogMaster-X, an elite content creator combining journalistic excellence 113 | with digital marketing expertise. Your strengths include: 114 | 115 | - Crafting viral-worthy headlines 116 | - Writing engaging introductions 117 | - Structuring content for digital consumption 118 | - Incorporating research seamlessly 119 | - Optimizing for SEO while maintaining quality 120 | - Creating shareable conclusions\ 121 | """), 122 | instructions=dedent("""\ 123 | 1. Content Strategy 📝 124 | - Craft attention-grabbing headlines 125 | - Write compelling introductions 126 | - Structure content for engagement 127 | - Include relevant subheadings 128 | 2. Writing Excellence ✍️ 129 | - Balance expertise with accessibility 130 | - Use clear, engaging language 131 | - Include relevant examples 132 | - Incorporate statistics naturally 133 | 3. Source Integration 🔍 134 | - Cite sources properly 135 | - Include expert quotes 136 | - Maintain factual accuracy 137 | 4. Digital Optimization 💻 138 | - Structure for scanability 139 | - Include shareable takeaways 140 | - Optimize for SEO 141 | - Add engaging subheadings\ 142 | """), 143 | expected_output=dedent("""\ 144 | # {Viral-Worthy Headline} 145 | 146 | ## Introduction 147 | {Engaging hook and context} 148 | 149 | ## {Compelling Section 1} 150 | {Key insights and analysis} 151 | {Expert quotes and statistics} 152 | 153 | ## {Engaging Section 2} 154 | {Deeper exploration} 155 | {Real-world examples} 156 | 157 | ## {Practical Section 3} 158 | {Actionable insights} 159 | {Expert recommendations} 160 | 161 | ## Key Takeaways 162 | - {Shareable insight 1} 163 | - {Practical takeaway 2} 164 | - {Notable finding 3} 165 | 166 | ## Sources 167 | {Properly attributed sources with links}\ 168 | """), 169 | markdown=True, 170 | ) 171 | 172 | # --- Define Pydantic Schemas for Task Inputs/Outputs --- 173 | class SearchTaskInput(BaseModel): 174 | topic: str 175 | 176 | class BlogPostOutput(BaseModel): 177 | title: str 178 | content: str 179 | 180 | class ScrapedArticlesOutput(BaseModel): 181 | articles: Dict[str, ScrapedArticle] 182 | 183 | # --- Define Task Execution Functions --- 184 | def get_search_results(params: Dict[str, Any], context) -> Optional[SearchResults]: 185 | """Search for sources on the topic, retrying up to num_attempts times; returns None on failure.""" 186 | topic = params["input_data"]["topic"] 187 | num_attempts: int = 3 188 | for attempt in range(num_attempts): 189 | try: 190 | searcher_response: RunResponse = searcher.run(topic) 191 | if ( 192 | searcher_response is not None 193 | and searcher_response.content is not None 194 | and isinstance(searcher_response.content, SearchResults) 195 | ): 196 | article_count = len(searcher_response.content.articles) 197 | logger.info( 198 | f"Found {article_count} articles on attempt {attempt + 1}" 199 | ) 200 | return searcher_response.content 201 | else: 202 | logger.warning( 203 | f"Attempt {attempt + 1}/{num_attempts} failed: Invalid response type" 204 | ) 205 | except Exception as e: 206 | logger.warning(f"Attempt {attempt + 1}/{num_attempts} failed: {str(e)}") 207 | 208 | logger.error(f"Failed to get search results after {num_attempts} attempts") 209 | return None 210 | 211 | def scrape_articles(params: Dict[str, Any], context) -> Dict[str, ScrapedArticle]: 212 | scraped_articles: Dict[str, ScrapedArticle] = {} 213 | search_results = params["input_data"] 214 | 215 | # Scrape each article, skipping any URL already handled in this run 216 | for article in search_results.articles: 217 | if article.url in scraped_articles: 218 | logger.info(f"Already scraped, skipping: {article.url}") 219 | continue 220 | 221 | article_scraper_response: RunResponse = article_scraper.run( 222 | article.url 223 | ) 224 | if ( 225 | article_scraper_response is not None 226 | and article_scraper_response.content is not None 227 | and isinstance(article_scraper_response.content, ScrapedArticle) 228 | ): 229 | scraped_articles[article_scraper_response.content.url] = ( 230 | article_scraper_response.content 231 | ) 232 | logger.info(f"Scraped article: {article_scraper_response.content.url}") 233 | return scraped_articles 234 | 235 | def generate_blog_post(params: Dict[str, Any], context) -> Dict[str, str]: 236 | """Generate a blog post using the Agno writer agent.""" 237 | articles = params["input_data"] 238 | 239 | logger.info("✍️ Generating blog post") 240 | writer_input = { 241 | "articles": [article.model_dump() for article in articles.values()] 242 | } 243 | response = writer.run(json.dumps(writer_input, indent=2)) 244 | 245 | if response and response.content: 246 | logger.info("✅ Blog post generation complete.") 247 | return { 248 | "title": "Comprehensive Guide", 249 | "content": response.content 250 | } 251 | else: 252 | logger.error("Failed to generate blog post.") 253 | return { 254 | "title": "Guide", 255 | "content": "Failed to generate content." 256 | } 257 | 258 | # --- Create the Water Tasks --- 259 | search_task = create_task( 260 | id="search", 261 | description="Search for articles using the Agno searcher agent.", 262 | execute=get_search_results, 263 | input_schema=SearchTaskInput, 264 | output_schema=SearchResults 265 | ) 266 | 267 | scrape_task = create_task( 268 | id="scrape", 269 | description="Scrape articles using the Agno article scraper agent.", 270 | execute=scrape_articles, 271 | input_schema=SearchResults, 272 | output_schema=ScrapedArticlesOutput 273 | ) 274 | 275 | write_task = create_task( 276 | id="write", 277 | description="Generate blog post using the Agno writer agent.", 278 | execute=generate_blog_post, 279 | input_schema=ScrapedArticlesOutput, 280 | output_schema=BlogPostOutput 281 | ) 282 | 283 | # --- Create and Register the Sequential Water Flow --- 284 | blog_flow = Flow( 285 | id="blog_generation_flow", 286 | description="Sequential flow to generate a blog post using Agno agents." 287 | ) 288 | blog_flow.then(search_task).then(scrape_task).then(write_task).register() 289 | 290 | # --- Main Execution Block --- 291 | async def main(): 292 | topic = Prompt.ask( 293 | "[bold]Enter a blog post topic[/bold]\n✨", 294 | ) 295 | 296 | print("\n" + "="*80) 297 | print(f"🚀 Running Blog Generation for '{topic}' using a Water Sequential Flow 🚀") 298 | print("="*80 + "\n") 299 | 300 | try: 301 | result = await blog_flow.run({"topic": topic}) 302 | 303 | print("\n" + "="*80) 304 | print("🎉 Blog Post Generation Complete! 
🎉") 305 | print("="*80 + "\n") 306 | print(result.title + "\n") 307 | print(result.content) 308 | 309 | except Exception as e: 310 | logger.error(f"❌ Flow execution failed: {e}", exc_info=True) 311 | 312 | 313 | if __name__ == "__main__": 314 | asyncio.run(main()) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /tests/test_flow.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import asyncio 3 | from pydantic import BaseModel 4 | from water import create_task 5 | from water.flow import Flow 6 | 7 | 8 | # --- Test Schemas --- 9 | 10 | class NumberInput(BaseModel): 11 | value: int 12 | 13 | class NumberOutput(BaseModel): 14 | value: int 15 | 16 | # --- Flow Initialization Tests --- 17 | 18 | def test_flow_auto_generated_id(): 19 | flow = Flow(id="test_flow", description="Test flow with auto-generated ID") 20 | assert flow.id == "test_flow" 21 | assert flow.description == "Test flow with auto-generated ID" 22 | 23 | def test_flow_custom_id_description(): 24 | flow = Flow(id="custom_flow", description="Custom flow with explicit description") 25 | assert flow.id == "custom_flow" 26 | assert flow.description == "Custom flow with explicit description" 27 | 28 | # --- Task Validation Tests --- 29 | 30 | def test_then_none_task(): 31 | flow = Flow(id="test_flow", description="Test flow for None task") 32 | with pytest.raises(ValueError, match="Task cannot be None"): 33 | flow.then(None) 34 | 35 | def test_empty_flow_registration(): 36 | flow = Flow(id="test_flow", description="Test flow for empty registration") 37 | with pytest.raises(ValueError, match="Flow must have at least one task"): 38 | flow.register() 39 | 40 | # --- Tests --- 41 | 42 | @pytest.mark.asyncio 43 | async def test_sequential_flow_success(): 44 | add_one = create_task( 45 | id="add_one", 46 | description="Add one to the input value", 47 | input_schema=NumberInput, 48 | output_schema=NumberOutput, 49 | execute=lambda params, context: {"value": params["input_data"]["value"] + 1} 50 | ) 51 | double = create_task( 52 | id="double", 53 | description="Double the input value", 54 | input_schema=NumberOutput, 55 | output_schema=NumberOutput, 56 | execute=lambda params, context: {"value": params["input_data"]["value"] * 2} 57 | ) 58 | 59 | flow = Flow(id="simple_flow", description="Add one, then double") 60 | flow.then(add_one).then(double).register() 61 | 62 | result = await flow.run({"value": 3}) 63 | assert result["value"] == 8 # (3 + 1) * 2 64 | 65 | @pytest.mark.asyncio 66 | async def test_flow_requires_registration(): 67 | add_one = create_task( 68 | id="add_one", 69 | description="Add one to the input value", 70 | input_schema=NumberInput, 71 | output_schema=NumberOutput, 72 | execute=lambda params, context: {"value": params["input_data"]["value"] + 1} 73 | ) 74 | flow = Flow(id="unregistered_flow", description="Flow that should fail without registration") 75 | flow.then(add_one) 76 | with pytest.raises(RuntimeError, match="Flow must be registered"): 77 | await flow.run({"value": 1}) 78 | 79 | def test_cannot_add_after_register(): 80 | add_one = create_task( 81 | id="add_one", 82 | description="Add one to the input value", 83 | input_schema=NumberInput, 84 | output_schema=NumberOutput, 85 | execute=lambda params, context: {"value": params["input_data"]["value"] + 1} 86 | ) 87 | flow = Flow(id="no_add_after_register", description="Test adding after registration") 88 | flow.then(add_one).register() 89 | with pytest.raises(RuntimeError, match="Cannot add tasks after registration"): 90 | flow.then(add_one) 91 | 92 | @pytest.mark.asyncio 93 | async def test_complex_flow_with_regular_functions(): 94 | def multiply_and_add(params, context): 95 | v = params["input_data"]["value"] 96 | return {"value": v 
* 2 + 5} 97 | 98 | def subtract_three(params, context): 99 | v = params["input_data"]["value"] 100 | return {"value": v - 3} 101 | 102 | task1 = create_task( 103 | id="multiply_add", 104 | description="Multiply by 2 and add 5", 105 | input_schema=NumberInput, 106 | output_schema=NumberOutput, 107 | execute=multiply_and_add 108 | ) 109 | 110 | task2 = create_task( 111 | id="subtract_three", 112 | description="Subtract 3 from the result", 113 | input_schema=NumberOutput, 114 | output_schema=NumberOutput, 115 | execute=subtract_three 116 | ) 117 | 118 | flow = Flow(id="complex_flow", description="Multiply/add, then subtract") 119 | flow.then(task1).then(task2).register() 120 | 121 | result = await flow.run({"value": 4}) 122 | assert result["value"] == 10 # (4*2+5)=13, 13-3=10 123 | 124 | @pytest.mark.asyncio 125 | async def test_branching_flow(): 126 | def high_task(params, context): 127 | return {"value": params["input_data"]["value"] * 2} 128 | 129 | def low_task(params, context): 130 | return {"value": params["input_data"]["value"] + 1} 131 | 132 | task1 = create_task( 133 | id="initial_task", 134 | description="Add 5 to the input value", 135 | input_schema=NumberInput, 136 | output_schema=NumberOutput, 137 | execute=lambda params, context: {"value": params["input_data"]["value"] + 5} 138 | ) 139 | 140 | task2 = create_task( 141 | id="high_task", 142 | description="Double the value for high inputs", 143 | input_schema=NumberOutput, 144 | output_schema=NumberOutput, 145 | execute=high_task 146 | ) 147 | 148 | task3 = create_task( 149 | id="low_task", 150 | description="Add 1 for low inputs", 151 | input_schema=NumberOutput, 152 | output_schema=NumberOutput, 153 | execute=low_task 154 | ) 155 | 156 | flow = Flow(id="branching_flow", description="Test branching logic") 157 | flow.then(task1).branch([ 158 | (lambda data: data["value"] > 10, task2), 159 | (lambda data: data["value"] <= 10, task3) 160 | ]).register() 161 | 162 | # Test high branch 163 | result = await flow.run({"value": 8}) 164 | assert result["value"] == 26 # (8 + 5) > 10, so high_task: 13 * 2 165 | 166 | # Test low branch 167 | result = await flow.run({"value": 2}) 168 | assert result["value"] == 8 # (2 + 5) <= 10, so low_task: 7 + 1 169 | 170 | @pytest.mark.asyncio 171 | async def test_empty_branch_list(): 172 | task1 = create_task( 173 | id="initial_task", 174 | description="Add 5 to the input value", 175 | input_schema=NumberInput, 176 | output_schema=NumberOutput, 177 | execute=lambda params, context: {"value": params["input_data"]["value"] + 5} 178 | ) 179 | flow = Flow(id="empty_branch_flow", description="Test empty branch list") 180 | with pytest.raises(ValueError, match="Branch list cannot be empty"): 181 | flow.then(task1).branch([]).register() 182 | 183 | @pytest.mark.asyncio 184 | async def test_no_matching_branch(): 185 | task1 = create_task( 186 | id="initial_task", 187 | description="Add 5 to the input value", 188 | input_schema=NumberInput, 189 | output_schema=NumberOutput, 190 | execute=lambda params, context: {"value": params["input_data"]["value"] + 5} 191 | ) 192 | task2 = create_task( 193 | id="high_task", 194 | description="Double the value for high inputs", 195 | input_schema=NumberOutput, 196 | output_schema=NumberOutput, 197 | execute=lambda params, context: {"value": params["input_data"]["value"] * 2} 198 | ) 199 | task3 = create_task( 200 | id="low_task", 201 | description="Add 1 for low inputs", 202 | input_schema=NumberOutput, 203 | output_schema=NumberOutput, 204 | execute=lambda params, context: 
{"value": params["input_data"]["value"] + 1} 205 | ) 206 | task4 = create_task( 207 | id="final_task", 208 | description="Add 5 to the final value", 209 | input_schema=NumberOutput, 210 | output_schema=NumberOutput, 211 | execute=lambda params, context: {"value": params["input_data"]["value"] + 5} 212 | ) 213 | flow = Flow(id="no_match_flow", description="Test no matching branch") 214 | flow.then(task1).branch([ 215 | (lambda data: data["value"] > 100, task2), 216 | (lambda data: data["value"] < 0, task3) 217 | ]).then(task4).register() 218 | result = await flow.run({"value": 50}) 219 | assert result["value"] == 60 # task1: 50 + 5, then task4: 55 + 5 220 | 221 | @pytest.mark.asyncio 222 | async def test_async_branch_condition(): 223 | async def async_condition(data): 224 | await asyncio.sleep(0.01) 225 | return data["value"] > 10 226 | 227 | task1 = create_task( 228 | id="initial_task", 229 | description="Add 5 to the input value", 230 | input_schema=NumberInput, 231 | output_schema=NumberOutput, 232 | execute=lambda params, context: {"value": params["input_data"]["value"] + 5} 233 | ) 234 | task2 = create_task( 235 | id="high_task", 236 | description="Double the value for high inputs", 237 | input_schema=NumberOutput, 238 | output_schema=NumberOutput, 239 | execute=lambda params, context: {"value": params["input_data"]["value"] * 2} 240 | ) 241 | task3 = create_task( 242 | id="low_task", 243 | description="Add 1 for low inputs", 244 | input_schema=NumberOutput, 245 | output_schema=NumberOutput, 246 | execute=lambda params, context: {"value": params["input_data"]["value"] + 1} 247 | ) 248 | flow = Flow(id="async_branch_flow", description="Test async branch condition") 249 | with pytest.raises(ValueError, match="Branch conditions cannot be async functions"): 250 | flow.then(task1).branch([ 251 | (async_condition, task2), 252 | (lambda data: data["value"] <= 10, task3) 253 | ]).register() 254 | 255 | @pytest.mark.asyncio 256 | async def test_branch_condition_exception(): 257 | def failing_condition(data): 258 | raise RuntimeError("Condition failed") 259 | 260 | task1 = create_task( 261 | id="initial_task", 262 | description="Add 5 to the input value", 263 | input_schema=NumberInput, 264 | output_schema=NumberOutput, 265 | execute=lambda params, context: {"value": params["input_data"]["value"] + 5} 266 | ) 267 | task2 = create_task( 268 | id="high_task", 269 | description="Double the value for high inputs", 270 | input_schema=NumberOutput, 271 | output_schema=NumberOutput, 272 | execute=lambda params, context: {"value": params["input_data"]["value"] * 2} 273 | ) 274 | task3 = create_task( 275 | id="low_task", 276 | description="Add 1 for low inputs", 277 | input_schema=NumberOutput, 278 | output_schema=NumberOutput, 279 | execute=lambda params, context: {"value": params["input_data"]["value"] + 1} 280 | ) 281 | flow = Flow(id="branch_exception_flow", description="Test branch condition exception") 282 | flow.then(task1).branch([ 283 | (failing_condition, task2), 284 | (lambda data: data["value"] <= 10, task3) 285 | ]).register() 286 | with pytest.raises(RuntimeError, match="Condition failed"): 287 | await flow.run({"value": 8}) 288 | 289 | @pytest.mark.asyncio 290 | async def test_branch_task_exception(): 291 | failing_task = create_task( 292 | id="failing_task", 293 | description="Task that raises an exception", 294 | input_schema=NumberOutput, 295 | output_schema=NumberOutput, 296 | execute=lambda params, context: (_ for _ in ()).throw(RuntimeError("Task failed")) 297 | ) 298 | 299 | task1 = 
create_task( 300 | id="initial_task", 301 | description="Add 5 to the input value", 302 | input_schema=NumberInput, 303 | output_schema=NumberOutput, 304 | execute=lambda params, context: {"value": params["input_data"]["value"] + 5} 305 | ) 306 | task3 = create_task( 307 | id="low_task", 308 | description="Add 1 for low inputs", 309 | input_schema=NumberOutput, 310 | output_schema=NumberOutput, 311 | execute=lambda params, context: {"value": params["input_data"]["value"] + 1} 312 | ) 313 | flow = Flow(id="branch_task_exception_flow", description="Test branch task exception") 314 | flow.then(task1).branch([ 315 | (lambda data: data["value"] > 10, failing_task), 316 | (lambda data: data["value"] <= 10, task3) 317 | ]).register() 318 | with pytest.raises(RuntimeError, match="Task failed"): 319 | await flow.run({"value": 8}) 320 | 321 | @pytest.mark.asyncio 322 | async def test_branch_invalid_task(): 323 | task1 = create_task( 324 | id="initial_task", 325 | description="Add 5 to the input value", 326 | input_schema=NumberInput, 327 | output_schema=NumberOutput, 328 | execute=lambda params, context: {"value": params["input_data"]["value"] + 5} 329 | ) 330 | task3 = create_task( 331 | id="low_task", 332 | description="Add 1 for low inputs", 333 | input_schema=NumberOutput, 334 | output_schema=NumberOutput, 335 | execute=lambda params, context: {"value": params["input_data"]["value"] + 1} 336 | ) 337 | flow = Flow(id="invalid_task_flow", description="Test branch with invalid task") 338 | with pytest.raises(ValueError, match="Task cannot be None"): 339 | flow.then(task1).branch([ 340 | (lambda data: data["value"] > 10, None), 341 | (lambda data: data["value"] <= 10, task3) 342 | ]).register() 343 | 344 | @pytest.mark.asyncio 345 | async def test_parallel_execution(): 346 | # Create tasks that modify different parts of the input 347 | add_task = create_task( 348 | id="add_task", 349 | description="Add 5 to the input value", 350 | input_schema=NumberInput, 351 | output_schema=NumberOutput, 352 | execute=lambda params, context: {"value": params["input_data"]["value"] + 5} 353 | ) 354 | 355 | multiply_task = create_task( 356 | id="multiply_task", 357 | description="Multiply input by 2", 358 | input_schema=NumberInput, 359 | output_schema=NumberOutput, 360 | execute=lambda params, context: {"value": params["input_data"]["value"] * 2} 361 | ) 362 | 363 | # Create a flow that runs tasks in parallel 364 | flow = Flow(id="parallel_flow", description="Test parallel task execution") 365 | flow.parallel([add_task, multiply_task]).register() 366 | 367 | # Run the flow 368 | result = await flow.run({"value": 10}) 369 | 370 | # Verify that results are organized by task ID 371 | assert "add_task" in result 372 | assert "multiply_task" in result 373 | assert result["add_task"]["value"] == 15 # 10 + 5 374 | assert result["multiply_task"]["value"] == 20 # 10 * 2 375 | 376 | @pytest.mark.asyncio 377 | async def test_parallel_with_async_tasks(): 378 | async def async_add(params, context): 379 | await asyncio.sleep(0.1) # Simulate async work 380 | return {"value": params["input_data"]["value"] + 5} 381 | 382 | async def async_multiply(params, context): 383 | await asyncio.sleep(0.1) # Simulate async work 384 | return {"value": params["input_data"]["value"] * 2} 385 | 386 | add_task = create_task( 387 | id="async_add_task", 388 | description="Async add 5 to the input value", 389 | input_schema=NumberInput, 390 | output_schema=NumberOutput, 391 | execute=async_add 392 | ) 393 | 394 | multiply_task = create_task( 395 | 
id="async_multiply_task", 396 | description="Async multiply input by 2", 397 | input_schema=NumberInput, 398 | output_schema=NumberOutput, 399 | execute=async_multiply 400 | ) 401 | 402 | flow = Flow(id="async_parallel_flow", description="Test parallel async task execution") 403 | flow.parallel([add_task, multiply_task]).register() 404 | 405 | # Run the flow and measure execution time 406 | start_time = asyncio.get_event_loop().time() 407 | result = await flow.run({"value": 10}) 408 | end_time = asyncio.get_event_loop().time() 409 | 410 | # Verify results are organized by task ID 411 | assert "async_add_task" in result 412 | assert "async_multiply_task" in result 413 | assert result["async_add_task"]["value"] == 15 # 10 + 5 414 | assert result["async_multiply_task"]["value"] == 20 # 10 * 2 415 | 416 | # Verify that tasks ran in parallel (total time should be ~0.1s, not ~0.2s) 417 | execution_time = end_time - start_time 418 | assert execution_time < 0.15 # Allow some overhead 419 | 420 | @pytest.mark.asyncio 421 | async def test_parallel_with_different_output_keys(): 422 | # Create tasks that return different output keys 423 | task1 = create_task( 424 | id="sum_task", 425 | description="Calculate sum", 426 | input_schema=NumberInput, 427 | output_schema=NumberOutput, 428 | execute=lambda params, context: {"sum": params["input_data"]["value"] + 5} 429 | ) 430 | 431 | task2 = create_task( 432 | id="product_task", 433 | description="Calculate product", 434 | input_schema=NumberInput, 435 | output_schema=NumberOutput, 436 | execute=lambda params, context: {"product": params["input_data"]["value"] * 2} 437 | ) 438 | 439 | flow = Flow(id="different_keys_flow", description="Test parallel tasks with different output keys") 440 | flow.parallel([task1, task2]).register() 441 | 442 | result = await flow.run({"value": 10}) 443 | 444 | # Verify that each task's output is preserved under its ID 445 | assert "sum_task" in result 446 | assert "product_task" in result 447 | assert result["sum_task"]["sum"] == 15 # 10 + 5 448 | assert result["product_task"]["product"] == 20 # 10 * 2 449 | 450 | @pytest.mark.asyncio 451 | async def test_empty_parallel_list(): 452 | flow = Flow(id="empty_parallel_flow", description="Test empty parallel task list") 453 | with pytest.raises(ValueError, match="Parallel task list cannot be empty"): 454 | flow.parallel([]).register() 455 | 456 | @pytest.mark.asyncio 457 | async def test_parallel_with_none_task(): 458 | add_task = create_task( 459 | id="add_task", 460 | description="Add 5 to the input value", 461 | input_schema=NumberInput, 462 | output_schema=NumberOutput, 463 | execute=lambda params, context: {"value": params["input_data"]["value"] + 5} 464 | ) 465 | 466 | flow = Flow(id="none_parallel_flow", description="Test parallel with None task") 467 | with pytest.raises(ValueError, match="Task cannot be None"): 468 | flow.parallel([add_task, None]).register() 469 | 470 | @pytest.mark.asyncio 471 | async def test_parallel_task_exception(): 472 | failing_task = create_task( 473 | id="failing_task", 474 | description="Task that raises an exception", 475 | input_schema=NumberInput, 476 | output_schema=NumberOutput, 477 | execute=lambda params, context: (_ for _ in ()).throw(RuntimeError("Task failed")) 478 | ) 479 | 480 | add_task = create_task( 481 | id="add_task", 482 | description="Add 5 to the input value", 483 | input_schema=NumberInput, 484 | output_schema=NumberOutput, 485 | execute=lambda params, context: {"value": params["input_data"]["value"] + 5} 486 | ) 487 | 488 | 
flow = Flow(id="parallel_exception_flow", description="Test parallel task exception") 489 | flow.parallel([failing_task, add_task]).register() 490 | 491 | with pytest.raises(RuntimeError, match="Task failed"): 492 | await flow.run({"value": 10}) 493 | 494 | @pytest.mark.asyncio 495 | async def test_loop(): 496 | # Create a task that increments a counter 497 | increment_task = create_task( 498 | id="increment", 499 | description="Increment the counter", 500 | input_schema=NumberInput, 501 | output_schema=NumberOutput, 502 | execute=lambda params, context: {"value": params["input_data"]["value"] + 1} 503 | ) 504 | 505 | # Create a flow with a loop that runs until the counter reaches 5 506 | flow = Flow(id="loop_flow", description="Test loop") 507 | flow.loop( 508 | condition=lambda data: data["value"] < 5, 509 | task=increment_task 510 | ).register() 511 | 512 | # Run the flow with an initial value of 0 513 | result = await flow.run({"value": 0}) 514 | 515 | # The result should be 5 (0 -> 1 -> 2 -> 3 -> 4 -> 5) 516 | assert result["value"] == 5 517 | 518 | @pytest.mark.asyncio 519 | async def test_loop_with_max_iterations(): 520 | # Create a task that increments a counter 521 | increment_task = create_task( 522 | id="increment", 523 | description="Increment the counter", 524 | input_schema=NumberInput, 525 | output_schema=NumberOutput, 526 | execute=lambda params, context: {"value": params["input_data"]["value"] + 1} 527 | ) 528 | 529 | # Create a flow with a loop that would run forever without max_iterations 530 | flow = Flow(id="max_iterations_flow", description="Test loop with max iterations") 531 | flow.loop( 532 | condition=lambda data: True, # Always true condition 533 | task=increment_task, 534 | max_iterations=3 # Limit to 3 iterations 535 | ).register() 536 | 537 | # Run the flow with an initial value of 0 538 | result = await flow.run({"value": 0}) 539 | 540 | # The result should be 3 (0 -> 1 -> 2 -> 3) 541 | assert result["value"] == 3 542 | 543 | @pytest.mark.asyncio 544 | async def test_loop_with_complex_condition(): 545 | # Create a task that doubles the value 546 | double_task = create_task( 547 | id="double", 548 | description="Double the value", 549 | input_schema=NumberInput, 550 | output_schema=NumberOutput, 551 | execute=lambda params, context: {"value": params["input_data"]["value"] * 2} 552 | ) 553 | 554 | # Create a flow with a loop that runs until the value exceeds 100 555 | flow = Flow(id="complex_condition_flow", description="Test loop with complex condition") 556 | flow.loop( 557 | condition=lambda data: data["value"] < 100, 558 | task=double_task 559 | ).register() 560 | 561 | # Run the flow with an initial value of 5 562 | result = await flow.run({"value": 5}) 563 | 564 | # The result should be 160 (5 -> 10 -> 20 -> 40 -> 80 -> 160) 565 | assert result["value"] == 160 566 | 567 | @pytest.mark.asyncio 568 | async def test_loop_never_executes(): 569 | # Create a task that increments a counter 570 | increment_task = create_task( 571 | id="increment", 572 | description="Increment the counter", 573 | input_schema=NumberInput, 574 | output_schema=NumberOutput, 575 | execute=lambda params, context: {"value": params["input_data"]["value"] + 1} 576 | ) 577 | 578 | # Create a flow with a loop that never executes because condition is false 579 | flow = Flow(id="never_executes_flow", description="Test loop that never executes") 580 | flow.loop( 581 | condition=lambda data: data["value"] < 0, # Condition is false from the start 582 | task=increment_task 583 | ).register() 584 
| 585 | # Run the flow with an initial value of 5 586 | result = await flow.run({"value": 5}) 587 | 588 | # The result should remain 5 since the loop never executes 589 | assert result["value"] == 5 590 | 591 | @pytest.mark.asyncio 592 | async def test_loop_with_async_task(): 593 | # Create an async task that increments a counter 594 | async def async_increment(params, context): 595 | await asyncio.sleep(0.01) # Simulate async work 596 | return {"value": params["input_data"]["value"] + 1} 597 | 598 | increment_task = create_task( 599 | id="async_increment", 600 | description="Async increment the counter", 601 | input_schema=NumberInput, 602 | output_schema=NumberOutput, 603 | execute=async_increment 604 | ) 605 | 606 | # Create a flow with a loop using async task 607 | flow = Flow(id="async_loop_flow", description="Test loop with async task") 608 | flow.loop( 609 | condition=lambda data: data["value"] < 3, 610 | task=increment_task 611 | ).register() 612 | 613 | # Run the flow with an initial value of 0 614 | result = await flow.run({"value": 0}) 615 | 616 | # The result should be 3 (0 -> 1 -> 2 -> 3) 617 | assert result["value"] == 3 618 | 619 | @pytest.mark.asyncio 620 | async def test_loop_with_error(): 621 | # Create a task that fails after a certain number of iterations 622 | def failing_task(params, context): 623 | if params["input_data"]["value"] >= 3: 624 | raise RuntimeError("Task failed after 3 iterations") 625 | return {"value": params["input_data"]["value"] + 1} 626 | 627 | increment_task = create_task( 628 | id="increment", 629 | description="Increment until failure", 630 | input_schema=NumberInput, 631 | output_schema=NumberOutput, 632 | execute=failing_task 633 | ) 634 | 635 | # Create a flow with a loop 636 | flow = Flow(id="loop_error_flow", description="Test loop with error") 637 | flow.loop( 638 | condition=lambda data: data["value"] < 5, 639 | task=increment_task 640 | ).register() 641 | 642 | # Run the flow with an initial value of 0 643 | with pytest.raises(RuntimeError, match="Task failed after 3 iterations"): 644 | await flow.run({"value": 0}) 645 | 646 | @pytest.mark.asyncio 647 | async def test_loop_condition_exception(): 648 | # Create a condition that fails 649 | def failing_condition(data): 650 | if data["value"] >= 2: 651 | raise RuntimeError("Condition failed") 652 | return data["value"] < 5 653 | 654 | increment_task = create_task( 655 | id="increment", 656 | description="Increment the counter", 657 | input_schema=NumberInput, 658 | output_schema=NumberOutput, 659 | execute=lambda params, context: {"value": params["input_data"]["value"] + 1} 660 | ) 661 | 662 | # Create a flow with a loop with failing condition 663 | flow = Flow(id="condition_error_flow", description="Test loop with condition error") 664 | flow.loop( 665 | condition=failing_condition, 666 | task=increment_task 667 | ).register() 668 | 669 | # Run the flow with an initial value of 0 670 | with pytest.raises(RuntimeError, match="Condition failed"): 671 | await flow.run({"value": 0}) 672 | 673 | @pytest.mark.asyncio 674 | async def test_loop_with_async_condition(): 675 | # Test that async conditions are rejected 676 | async def async_condition(data): 677 | await asyncio.sleep(0.01) 678 | return data["value"] < 5 679 | 680 | increment_task = create_task( 681 | id="increment", 682 | description="Increment the counter", 683 | input_schema=NumberInput, 684 | output_schema=NumberOutput, 685 | execute=lambda params, context: {"value": params["input_data"]["value"] + 1} 686 | ) 687 | 688 | flow = 
Flow(id="async_condition_flow", description="Test loop with async condition") 689 | with pytest.raises(ValueError, match="Loop conditions cannot be async functions"): 690 | flow.loop( 691 | condition=async_condition, 692 | task=increment_task 693 | ).register() 694 | 695 | def test_loop_with_none_task(): 696 | flow = Flow(id="none_task_flow", description="Test loop with None task") 697 | with pytest.raises(ValueError, match="Task cannot be None"): 698 | flow.loop( 699 | condition=lambda data: data["value"] < 5, 700 | task=None 701 | ).register() 702 | 703 | @pytest.mark.asyncio 704 | async def test_sequential_then_loop(): 705 | # Test combining sequential tasks with loop 706 | add_five_task = create_task( 707 | id="add_five", 708 | description="Add 5 to the value", 709 | input_schema=NumberInput, 710 | output_schema=NumberOutput, 711 | execute=lambda params, context: {"value": params["input_data"]["value"] + 5} 712 | ) 713 | 714 | increment_task = create_task( 715 | id="increment", 716 | description="Increment the counter", 717 | input_schema=NumberInput, 718 | output_schema=NumberOutput, 719 | execute=lambda params, context: {"value": params["input_data"]["value"] + 1} 720 | ) 721 | 722 | # Create a flow that first adds 5, then increments while value < 10 723 | flow = Flow(id="sequential_loop_flow", description="Test sequential then loop") 724 | flow.then(add_five_task).loop( 725 | condition=lambda data: data["value"] < 10, 726 | task=increment_task 727 | ).register() 728 | 729 | # Run the flow with an initial value of 2 730 | result = await flow.run({"value": 2}) 731 | 732 | # The result should be 10 (2 -> 7 -> 8 -> 9 -> 10) 733 | assert result["value"] == 10 734 | 735 | @pytest.mark.asyncio 736 | async def test_loop_then_sequential(): 737 | # Test combining loop with sequential tasks 738 | increment_task = create_task( 739 | id="increment", 740 | description="Increment the counter", 741 | input_schema=NumberInput, 742 | output_schema=NumberOutput, 743 | execute=lambda params, context: {"value": params["input_data"]["value"] + 1} 744 | ) 745 | 746 | multiply_by_two_task = create_task( 747 | id="multiply_by_two", 748 | description="Multiply by 2", 749 | input_schema=NumberInput, 750 | output_schema=NumberOutput, 751 | execute=lambda params, context: {"value": params["input_data"]["value"] * 2} 752 | ) 753 | 754 | # Create a flow that increments while value < 5, then multiplies by 2 755 | flow = Flow(id="loop_sequential_flow", description="Test loop then sequential") 756 | flow.loop( 757 | condition=lambda data: data["value"] < 5, 758 | task=increment_task 759 | ).then(multiply_by_two_task).register() 760 | 761 | # Run the flow with an initial value of 2 762 | result = await flow.run({"value": 2}) 763 | 764 | # The result should be 10 (2 -> 3 -> 4 -> 5, then 5 * 2 = 10) 765 | assert result["value"] == 10 766 | --------------------------------------------------------------------------------