├── .gitignore ├── .gitmodules ├── Ayo ├── __init__.py ├── app.py ├── configs │ ├── __init__.py │ ├── config.py │ └── model_config.py ├── dags │ ├── __init__.py │ ├── dag.py │ ├── node.py │ └── node_commons.py ├── engines │ ├── __init__.py │ ├── aggregator.py │ ├── base_engine.py │ ├── embedder.py │ ├── engine_types.py │ ├── llm.py │ ├── payload_transformers.py │ ├── reranker.py │ └── vector_db.py ├── logger.py ├── modules │ ├── base_module.py │ ├── embedding.py │ ├── indexing.py │ ├── llm_syhthesizing.py │ ├── mod_to_prim.py │ ├── prompt_template.py │ ├── query_expanding.py │ ├── reranking.py │ └── searching.py ├── opt_pass │ ├── base_pass.py │ ├── decoding_pipeling.py │ ├── pass_manager.py │ ├── prefilling_split.py │ ├── pruning_dependency.py │ ├── stage_decomposition.py │ ├── test_dag_llm_decoding_pipeling.pdf │ └── test_dag_prefilling_split.pdf ├── queries │ ├── __init__.py │ ├── query.py │ └── query_state.py ├── schedulers │ ├── engine_scheduler.py │ └── graph_scheduler.py ├── utils.py └── vis │ ├── test_dag_node_types.png │ └── vis_graph.py ├── LICENSE ├── README.md ├── examples ├── modules_to_primitives_embedding_ingestion_searching_reranking.py ├── modules_to_primitives_indexing_searching.py ├── optimized_dag_for_embedding_ingestion_searching_reranking_llm.png ├── optimized_embedding_ingestion_rewriting_searching_reranking_llm.py ├── optimized_embedding_ingestion_searching.py ├── optimized_embedding_ingestion_searching_reranking.py ├── optimized_embedding_ingestion_searching_reranking_llm.py ├── test_dag.py ├── test_embedding_service.py ├── test_multiple_llm_calls.py ├── test_reranking_service.py ├── unoptimized_dag_for_embedding_ingestion_search_reranking_llm.png ├── unoptimized_embedding_ingestion_rewriting_searching_reranking_llm.py ├── unoptimized_embedding_ingestion_searching.py ├── unoptimized_embedding_ingestion_searching_reranking.py └── unoptimized_embedding_ingestion_searching_reranking_llm.py ├── requirements.txt └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | */__pycache__ 3 | *.pyc 4 | *.pyo 5 | .DS_Store 6 | venv/ 7 | env/ 8 | .env 9 | .venv 10 | *.py[cod] 11 | *$py.class 12 | *.so 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | .pytest_cache/ 30 | 31 | */profiles 32 | */ray_worker -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "vllm"] 2 | path = vllm 3 | url = https://github.com/Txxx926/vllm 4 | branch = ayo 5 | -------------------------------------------------------------------------------- /Ayo/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /Ayo/configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NetX-lab/Ayo/61607bcb31ed98e1101c5d60d8573105bf60d87c/Ayo/configs/__init__.py -------------------------------------------------------------------------------- /Ayo/configs/config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, asdict 2 | from typing import Dict, Optional, List, Any 3 | 4 | 
@dataclass 5 | class EngineConfig: 6 | """configuration for execution engines""" 7 | name: str 8 | engine_type: str # e.g., "embedder", "llm", "vector_db" 9 | num_gpus: int = 0 10 | num_cpus: int = 1 11 | resources: Optional[Dict[str, int]] = None # e.g., {"GPU": 2} 12 | instances: int = 1 13 | model_config: Optional[Dict] = None 14 | latency_profile: Optional[Dict] = None 15 | 16 | def dict(self) -> Dict: 17 | """Convert config to dictionary""" 18 | return asdict(self) 19 | 20 | 21 | @dataclass 22 | class AppConfig: 23 | """application config""" 24 | engines: Dict[str, EngineConfig] 25 | optimization_passes: Optional[List[str]] = None 26 | workflow_template: Optional[Dict[str, Any]] = None -------------------------------------------------------------------------------- /Ayo/configs/model_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from enum import Enum 3 | from Ayo.dags.node_commons import NodeOps 4 | from Ayo.dags.node import Node 5 | from typing import Optional, Dict, List 6 | 7 | @dataclass 8 | class EmbeddingModelConfig: 9 | """embedding model config""" 10 | model_name: str 11 | dimension: int 12 | max_length: int = 512 13 | device: str = "cuda" 14 | batch_size: int = 1024 15 | vector_dim: int = 1024 16 | 17 | @dataclass 18 | class LLMConfig: 19 | """llm model config""" 20 | model_name: str 21 | temperature: float = 0.9 22 | top_p: float = 0.95 23 | device: str = "cuda" 24 | 25 | @dataclass 26 | class VectorDBConfig: 27 | """vector database config""" 28 | db_path: str 29 | dimension: int 30 | index_type: str = "HNSW" # e.g., "IVF", "HNSW" 31 | metric_type: str = "cosine" # e.g., "IP", "cosine" 32 | nprobe: int = 10 33 | 34 | 35 | class AggMode(str, Enum): 36 | """aggregator mode""" 37 | DUMMY = "dummy" 38 | MERGE = "merge" 39 | TOP_K = "top_k" 40 | 41 | 42 | def get_aggregator_config(node: Node, **kwargs) -> Dict: 43 | """get the aggregator config""" 44 | 45 | assert node.op_type == NodeOps.AGGREGATOR, f"node {node.name} is not an aggregator node" 46 | if node.parents[0].op_type == NodeOps.EMBEDDING: 47 | return {"agg_mode": AggMode.DUMMY} 48 | elif node.parents[0].op_type == NodeOps.VECTORDB_SEARCHING: 49 | return {"agg_mode": AggMode.MERGE} 50 | elif node.parents[0].op_type == NodeOps.RERANKING: 51 | 52 | agg_config = {"agg_mode": AggMode.TOP_K} 53 | agg_config.update( 54 | { 55 | "top_k": node.config.get("topk") or node.config.get("top_k") or node.config.get("k") or 5 56 | } 57 | ) 58 | return agg_config 59 | elif node.parents[0].op_type == NodeOps.VECTORDB_INGESTION: 60 | return {"agg_mode": AggMode.DUMMY} 61 | else: 62 | raise ValueError(f"Unsupported parent node op type: {node.parents[0].op_type}") 63 | 64 | 65 | def get_aggregator_config_for_parent_node(node: Node, **kwargs) -> Dict: 66 | if node.op_type == NodeOps.EMBEDDING: 67 | return {"agg_mode": AggMode.DUMMY} 68 | elif node.op_type == NodeOps.VECTORDB_SEARCHING: 69 | return {"agg_mode": AggMode.MERGE} 70 | elif node.op_type == NodeOps.RERANKING: 71 | return {"agg_mode": AggMode.TOP_K} 72 | elif node.op_type == NodeOps.VECTORDB_INGESTION: 73 | return {"agg_mode": AggMode.DUMMY} 74 | else: 75 | raise ValueError(f"Unsupported node op type: {node.op_type}") -------------------------------------------------------------------------------- /Ayo/dags/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NetX-lab/Ayo/61607bcb31ed98e1101c5d60d8573105bf60d87c/Ayo/dags/__init__.py
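For orientation, a minimal sketch of how the config classes in Ayo/configs/config.py compose; the engine name, resource counts, and model settings below are illustrative assumptions, not values shipped with the repo:

from Ayo.configs.config import AppConfig, EngineConfig

# hypothetical embedder engine description
embedder_cfg = EngineConfig(
    name="embedder-0",
    engine_type="embedder",
    num_gpus=1,
    model_config={"model_name": "BAAI/bge-large-en-v1.5"},
)
app_cfg = AppConfig(engines={"embedder-0": embedder_cfg})
print(embedder_cfg.dict())  # asdict() view of all fields

EngineConfig.dict() relies on dataclasses.asdict, so nested dicts such as model_config are copied recursively.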
-------------------------------------------------------------------------------- /Ayo/dags/node_commons.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from enum import Enum 4 | from dataclasses import dataclass 5 | from typing import Dict, Optional 6 | 7 | class NodeStatus(Enum): 8 | PENDING = "pending" 9 | RUNNING = "running" 10 | COMPLETED = "completed" 11 | FAILED = "failed" 12 | 13 | class NodeAnnotation(str, Enum): 14 | """Node annotations for optimization hints""" 15 | SPLITTABLE = "splittable" # Node output can be split 16 | BATCHABLE = "batchable" # Node can process batched inputs 17 | NONE = "none" # No special optimization 18 | 19 | class NodeOps(str, Enum): 20 | """Node operations""" 21 | INPUT = "input" 22 | OUTPUT = "output" 23 | EMBEDDING = "embedding" 24 | VECTORDB_INGESTION = "vectordb_ingestion" 25 | VECTORDB_SEARCHING = "vectordb_searching" 26 | RERANKING = "reranking" 27 | LLM_PREFILLING = "llm_prefilling" 28 | LLM_DECODING = "llm_decoding" 29 | LLM_PARTIAL_PREFILLING = "llm_partial_prefilling" 30 | LLM_FULL_PREFILLING = "llm_full_prefilling" 31 | LLM_PARTIAL_DECODING = "llm_partial_decoding" 32 | AGGREGATOR = "aggregator" 33 | 34 | 35 | 36 | class NodeType(Enum): 37 | """Node types in DAG""" 38 | INPUT = "input" # Input node that holds query inputs 39 | COMPUTE = "compute" # Computation node that performs operations 40 | OUTPUT = "output" # Output node that collects results 41 | 42 | 43 | @dataclass 44 | class NodeConfig: 45 | """Configuration for node execution""" 46 | batch_size: Optional[int] = None 47 | max_tokens: Optional[int] = None 48 | temperature: Optional[float] = None 49 | top_p: Optional[float] = None 50 | # Add other configuration parameters as needed 51 | 52 | @dataclass 53 | class NodeIOSchema: 54 | """Define the input and output schema for the node""" 55 | input_format: Dict[str, type] # input field name and type 56 | output_format: Dict[str, type] # output field name and type 57 | -------------------------------------------------------------------------------- /Ayo/engines/__init__.py: -------------------------------------------------------------------------------- 1 | from Ayo.engines.engine_types import EngineType, EngineRegistry, ENGINE_REGISTRY 2 | from Ayo.engines.base_engine import BaseEngine 3 | from Ayo.engines.embedder import EmbeddingEngine 4 | from Ayo.engines.vector_db import VectorDBEngine 5 | from Ayo.engines.reranker import RerankerEngine 6 | from Ayo.engines.llm import LLMEngine 7 | from Ayo.engines.aggregator import AggregateEngine 8 | 9 | -------------------------------------------------------------------------------- /Ayo/engines/aggregator.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional, Any, Tuple, Union 2 | import ray 3 | import asyncio 4 | import time 5 | import numpy as np 6 | from dataclasses import dataclass, field 7 | from collections import defaultdict 8 | 9 | # This engine would not really be used in the pipeline; 10 | # rather, we do the in-place aggregation on the engine scheduler side 11 | 12 | @dataclass 13 | class AggregateRequest: 14 | """Data class for aggregate requests""" 15 | request_id: str 16 | query_id: str # Group requests by query ID 17 | agg_mode: str # Aggregation mode: concat, merge_dicts, select_best, topk, custom 18 | data_sources: List[Any] # Data sources to aggregate 19 | callback_ref: Any = None # Ray ObjectRef for result callback 20 | timestamp: float = field(default_factory=time.time) # evaluated per request, not at class definition 21 | 22 |
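Before the actor itself, a plain-Python sketch of the "topk" semantics that _aggregate_data implements further below; the sample hits are made up:

def topk_merge(data_sources, k=3):
    # Same rule as the actor's "topk" branch: keep dict sources that
    # carry a 'score' field and return the k highest-scoring ones.
    valid = [s for s in data_sources if isinstance(s, dict) and "score" in s]
    return sorted(valid, key=lambda s: s["score"], reverse=True)[:k]

hits = [{"doc": "a", "score": 0.2}, {"doc": "b", "score": 0.9}, {"doc": "c", "score": 0.5}]
assert [h["doc"] for h in topk_merge(hits, k=2)] == ["b", "c"]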
@ray.remote 23 | class AggregateEngine: 24 | """Ray Actor, used to handle aggregate requests 25 | 26 | Features: 27 | - Asynchronous request processing 28 | - Supports multiple aggregation modes 29 | - Groups requests by query ID 30 | """ 31 | 32 | def __init__(self, 33 | max_batch_size: int = 32, 34 | max_queue_size: int = 1000, 35 | scheduler_ref: Optional[ray.actor.ActorHandle] = None, 36 | **kwargs): 37 | 38 | self.max_batch_size = max_batch_size 39 | self.max_queue_size = max_queue_size 40 | 41 | # Asynchronous queue 42 | self.request_queue = asyncio.Queue(maxsize=max_queue_size) 43 | self.batch_queue = asyncio.Queue(maxsize=max_queue_size) 44 | 45 | # Track requests by query ID 46 | self.query_requests: Dict[str, List[AggregateRequest]] = {} 47 | 48 | # Start processing tasks 49 | self.running = True 50 | self.tasks = [ 51 | asyncio.create_task(self._batch_requests()), 52 | asyncio.create_task(self._process_batches()) 53 | ] 54 | 55 | self.scheduler_ref = scheduler_ref 56 | 57 | async def submit_request(self, 58 | request_id: str, 59 | query_id: str, 60 | agg_mode: str, 61 | data_sources: List[Any]) -> None: 62 | """Submit new aggregate request""" 63 | request = AggregateRequest( 64 | request_id=request_id, 65 | query_id=query_id, 66 | agg_mode=agg_mode, 67 | data_sources=data_sources, 68 | callback_ref=None 69 | ) 70 | 71 | if self.request_queue.qsize() >= self.max_queue_size: 72 | raise RuntimeError("Request queue is full") 73 | 74 | await self.request_queue.put(request) 75 | 76 | if query_id not in self.query_requests: 77 | self.query_requests[query_id] = [] 78 | self.query_requests[query_id].append(request) 79 | 80 | async def _batch_requests(self): 81 | """Asynchronous task for batching requests""" 82 | while self.running: 83 | try: 84 | batch_requests = await self._get_next_batch() 85 | if batch_requests: 86 | await self.batch_queue.put(batch_requests) 87 | else: 88 | await asyncio.sleep(0.01) # Avoid busy waiting 89 | except Exception as e: 90 | print(f"Error in batch processing task: {e}") 91 | continue 92 | 93 | async def _process_batches(self): 94 | """Asynchronous task for processing batches""" 95 | while self.running: 96 | try: 97 | try: 98 | batch_requests = await asyncio.wait_for( 99 | self.batch_queue.get(), 100 | timeout=0.1 101 | ) 102 | except asyncio.TimeoutError: 103 | continue 104 | 105 | # Process each request 106 | for request in batch_requests: 107 | try: 108 | # Process data based on aggregation mode 109 | result = await self._aggregate_data( 110 | request.agg_mode, 111 | request.data_sources 112 | ) 113 | 114 | # Create ObjectRef for result 115 | result_ref = ray.put(result) 116 | 117 | # If scheduler is set, send result to scheduler 118 | if self.scheduler_ref is not None: 119 | await self.scheduler_ref.on_result.remote( 120 | request.request_id, 121 | request.query_id, 122 | result_ref 123 | ) 124 | 125 | # Clean up request records 126 | if request.query_id in self.query_requests: 127 | self.query_requests[request.query_id].remove(request) 128 | if not self.query_requests[request.query_id]: 129 | del self.query_requests[request.query_id] 130 | 131 | except Exception as e: 132 | import traceback 133 | traceback.print_exc() 134 | print(f"Error in processing single request: {e}") 135 | continue 136 | 137 | except Exception as e: 138 | import traceback 139 | traceback.print_exc() 140 | print(f"Error in processing batch: {e}") 141 | continue 142 | 143 | async def _get_next_batch(self) -> List[AggregateRequest]: 144 | """Get next batch of requests to process""" 
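# Batching policy of the loop below: keep pulling requests until
# max_batch_size is reached or the queue stays empty for ~10 ms;
# when a request's query_id has already been seen in this batch, all
# still-pending requests of that query are pulled in as well, so one
# query's aggregate requests tend to be processed together.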
145 | batch_requests = [] 146 | processed_queries = set() 147 | 148 | while len(batch_requests) < self.max_batch_size: 149 | try: 150 | request = await asyncio.wait_for( 151 | self.request_queue.get(), 152 | timeout=0.01 153 | ) 154 | 155 | if request.query_id in processed_queries: 156 | # Process requests from the same query 157 | pending_requests = self.query_requests[request.query_id] 158 | batch_requests.extend(pending_requests) 159 | else: 160 | batch_requests.append(request) 161 | processed_queries.add(request.query_id) 162 | 163 | except asyncio.TimeoutError: 164 | break 165 | 166 | return batch_requests 167 | 168 | async def _aggregate_data(self, agg_mode: str, data_sources: List[Any]) -> Any: 169 | """Aggregate data based on aggregation mode""" 170 | if agg_mode == "concat": 171 | # Simple list concatenation 172 | result = [] 173 | for source in data_sources: 174 | if isinstance(source, list): 175 | result.extend(source) 176 | else: 177 | result.append(source) 178 | return result 179 | 180 | elif agg_mode == "merge_dicts": 181 | # Merge multiple dictionaries 182 | result = {} 183 | for source in data_sources: 184 | if isinstance(source, dict): 185 | result.update(source) 186 | return result 187 | 188 | elif agg_mode == "select_best": 189 | # Select best result (assuming each source has a score field) 190 | if not data_sources: 191 | return None 192 | 193 | best_source = None 194 | best_score = float('-inf') 195 | 196 | for source in data_sources: 197 | if isinstance(source, dict) and 'score' in source: 198 | if source['score'] > best_score: 199 | best_score = source['score'] 200 | best_source = source 201 | 202 | return best_source 203 | 204 | elif agg_mode == "topk": 205 | # Select top k results with highest scores 206 | # datasource format: 207 | if not data_sources: 208 | return [] 209 | 210 | # Use n parameter from request as k value, default to 3 if not provided 211 | k = 3 212 | 213 | # Filter valid data sources ( must be dictionaries and contain score field) 214 | valid_sources = [ 215 | source for source in data_sources 216 | if isinstance(source, dict) and 'score' in source 217 | ] 218 | 219 | # Sort by score in descending order and return top k 220 | sorted_sources = sorted( 221 | valid_sources, 222 | key=lambda x: x['score'], 223 | reverse=True 224 | ) 225 | 226 | return sorted_sources[:k] 227 | 228 | elif agg_mode == "custom": 229 | # Custom aggregation function (needs function and data in data_sources) 230 | if len(data_sources) >= 2 and callable(data_sources[0]): 231 | custom_func = data_sources[0] 232 | data = data_sources[1:] 233 | return custom_func(*data) 234 | return None 235 | 236 | else: 237 | # Default return original data sources 238 | return data_sources 239 | 240 | async def shutdown(self): 241 | """Shutdown service""" 242 | self.running = False 243 | for task in self.tasks: 244 | task.cancel() 245 | try: 246 | await task 247 | except asyncio.CancelledError: 248 | pass -------------------------------------------------------------------------------- /Ayo/engines/base_engine.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional, Any 2 | import ray 3 | import asyncio 4 | import time 5 | from abc import ABC, abstractmethod 6 | from dataclasses import dataclass 7 | 8 | @dataclass 9 | class BaseRequest: 10 | """Base data class for all engine requests""" 11 | request_id: str 12 | query_id: str # Group requests from same query 13 | callback_ref: Any # Ray ObjectRef for result 14 | 
timestamp: float = time.time() 15 | 16 | 17 | class BaseEngine(ABC): 18 | """Base class for all Ray Actor engines 19 | 20 | Features: 21 | - Async request handling 22 | - Request queuing and batching 23 | - Request tracking by query_id 24 | """ 25 | 26 | def __init__(self, 27 | max_batch_size: int = 32, 28 | max_queue_size: int = 1000, 29 | scheduler_ref: Optional[ray.actor.ActorHandle] = None, 30 | **kwargs): 31 | 32 | self.max_batch_size = max_batch_size 33 | self.max_queue_size = max_queue_size 34 | self.scheduler_ref = scheduler_ref 35 | 36 | # Async queues 37 | self.request_queue = asyncio.Queue(maxsize=max_queue_size) 38 | self.batch_queue = asyncio.Queue(maxsize=max_queue_size) 39 | 40 | # Track requests by query_id 41 | self.query_requests: Dict[str, List[BaseRequest]] = {} 42 | 43 | # Create event loop 44 | self.loop = asyncio.get_event_loop() 45 | 46 | # Start processing tasks 47 | self.running = True 48 | self.tasks = [ 49 | self.loop.create_task(self._batch_requests()), 50 | self.loop.create_task(self._process_batches()) 51 | ] 52 | 53 | @abstractmethod 54 | def _load_model(self): 55 | """Load the model - must be implemented by subclasses""" 56 | pass 57 | 58 | @abstractmethod 59 | async def submit_request(self, request_id: str, query_id: str, **kwargs) -> ray.ObjectRef: 60 | """Submit a new request - must be implemented by subclasses""" 61 | pass 62 | 63 | @abstractmethod 64 | async def _get_next_batch(self): 65 | """Get next batch of requests - must be implemented by subclasses""" 66 | pass 67 | 68 | @abstractmethod 69 | async def _process_batch(self, batch_data): 70 | """Process a batch of requests - must be implemented by subclasses""" 71 | pass 72 | 73 | async def _batch_requests(self): 74 | """Async task for batching requests""" 75 | while self.running: 76 | try: 77 | batch_data = await self._get_next_batch() 78 | if batch_data: 79 | await self.batch_queue.put(batch_data) 80 | else: 81 | await asyncio.sleep(0.01) 82 | except Exception as e: 83 | print(f"Error in batching task: {e}") 84 | continue 85 | 86 | async def _process_batches(self): 87 | """Async task for processing batches""" 88 | while self.running: 89 | try: 90 | try: 91 | batch_data = await asyncio.wait_for( 92 | self.batch_queue.get(), 93 | timeout=0.1 94 | ) 95 | except asyncio.TimeoutError: 96 | continue 97 | 98 | await self._process_batch(batch_data) 99 | 100 | except Exception as e: 101 | print(f"Error in process loop: {e}") 102 | continue 103 | 104 | async def shutdown(self): 105 | """Shutdown the engine""" 106 | self.running = False 107 | for task in self.tasks: 108 | task.cancel() 109 | try: 110 | await task 111 | except asyncio.CancelledError: 112 | pass -------------------------------------------------------------------------------- /Ayo/engines/embedder.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional, Any, Tuple 2 | import ray 3 | import torch 4 | import asyncio 5 | import time 6 | from dataclasses import dataclass 7 | import numpy as np 8 | from collections import deque 9 | from Ayo.logger import get_logger, GLOBAL_INFO_LEVEL 10 | 11 | logger = get_logger(__name__, level=GLOBAL_INFO_LEVEL) 12 | 13 | @dataclass 14 | class EmbeddingRequest: 15 | """Data class for embedding requests""" 16 | request_id: str 17 | query_id: str # Group requests from same query 18 | texts: List[str] 19 | callback_ref: Any # Ray ObjectRef for result 20 | timestamp: float = time.time() 21 | 22 | 23 | @ray.remote(num_gpus=1) 24 | class 
EmbeddingEngine: 25 | """Ray Actor for serving embedding requests with async processing 26 | 27 | Features: 28 | - Async request handling 29 | - Batches requests for efficient processing 30 | - Groups requests from same query 31 | """ 32 | 33 | def __init__(self, 34 | model_name: str = "BAAI/bge-large-en-v1.5", 35 | max_batch_size: int = 512, 36 | max_queue_size: int = 1000, 37 | device: str = "cuda" if torch.cuda.is_available() else "cpu", 38 | scheduler_ref: Optional[ray.actor.ActorHandle] = None, 39 | **kwargs): 40 | 41 | print(f"CUDA is available: {torch.cuda.is_available()}") 42 | if torch.cuda.is_available(): 43 | print(f"Number of available GPUs: {torch.cuda.device_count()}") 44 | print(f"Current GPU device: {torch.cuda.current_device()}") 45 | print(f"GPU name: {torch.cuda.get_device_name()}") 46 | 47 | self.model_name = model_name 48 | self.max_batch_size = max_batch_size 49 | self.max_queue_size = max_queue_size 50 | self.device = device 51 | 52 | self.name = kwargs.get("name", None) 53 | 54 | # Initialize model 55 | self.model = self._load_model() 56 | 57 | # Async queues 58 | self.request_queue = asyncio.Queue(maxsize=max_queue_size) 59 | self.batch_queue = asyncio.Queue(maxsize=max_queue_size) 60 | 61 | # Track requests by query_id 62 | self.query_requests: Dict[str, List[EmbeddingRequest]] = {} 63 | 64 | # Create event loop 65 | self.loop = asyncio.get_event_loop() 66 | 67 | # Start processing tasks 68 | self.running = True 69 | self.tasks = [ 70 | self.loop.create_task(self._batch_requests()), 71 | self.loop.create_task(self._process_batches()) 72 | ] 73 | 74 | self.scheduler_ref = scheduler_ref 75 | 76 | 77 | def is_ready(self): 78 | """Check if the engine is ready""" 79 | return True 80 | 81 | def _load_model(self): 82 | """Load the embedding model""" 83 | from sentence_transformers import SentenceTransformer 84 | 85 | #self.model.encode(["hello","world"]) 86 | 87 | 88 | model=SentenceTransformer(model_name_or_path=self.model_name,device=self.device) 89 | model.half() 90 | model.eval() 91 | 92 | #warm up 93 | warm_up_embeddings=model.encode(["hello","world"]) 94 | logger.debug(f"warm_up_embeddings: {warm_up_embeddings}") 95 | 96 | logger.debug(f"Load and warm up embedding model:{self.model_name} successfully on {self.device}") 97 | 98 | return model 99 | 100 | async def submit_request(self, 101 | request_id: str, 102 | query_id: str, 103 | texts: List[str]) -> ray.ObjectRef: 104 | """Submit a new embedding request""" 105 | 106 | 107 | request = EmbeddingRequest( 108 | request_id=request_id, 109 | query_id=query_id, 110 | texts=texts, 111 | callback_ref=None 112 | ) 113 | 114 | if self.request_queue.qsize() >= self.max_queue_size: 115 | raise RuntimeError("Request queue is full") 116 | 117 | await self.request_queue.put(request) 118 | 119 | if query_id not in self.query_requests: 120 | self.query_requests[query_id] = [] 121 | 122 | self.query_requests[query_id].append(request) 123 | 124 | 125 | async def _batch_requests(self): 126 | """Async task for batching requests""" 127 | while self.running: 128 | try: 129 | batch_requests, batch_texts = await self._get_next_batch() 130 | if batch_requests: 131 | await self.batch_queue.put((batch_requests, batch_texts)) 132 | else: 133 | await asyncio.sleep(0.01) # Avoid busy waiting 134 | except Exception as e: 135 | print(f"Error in batching task: {e}") 136 | continue 137 | 138 | async def _process_batches(self): 139 | """Async task for processing batches""" 140 | while self.running: 141 | try: 142 | try: 143 | batch_requests, 
batch_texts = await asyncio.wait_for( 144 | self.batch_queue.get(), 145 | timeout=0.1 146 | ) 147 | except asyncio.TimeoutError: 148 | continue 149 | 150 | #print(f"Processing batch requests: {batch_requests}") 151 | 152 | try: 153 | embeddings = await self.loop.run_in_executor( 154 | None, 155 | self._compute_embeddings, 156 | batch_texts 157 | ) 158 | 159 | start_idx = 0 160 | for request in batch_requests: 161 | try: 162 | end_idx = start_idx + len(request.texts) 163 | request_embeddings = embeddings[start_idx:end_idx] 164 | 165 | # create ObjectRef for result 166 | print(f"request_embeddings: {request_embeddings.shape}") 167 | 168 | result_ref = ray.put(request_embeddings) 169 | 170 | # If scheduler is set, send result to scheduler 171 | 172 | if self.scheduler_ref is not None: 173 | await self.scheduler_ref.on_result.remote( 174 | request.request_id, 175 | request.query_id, 176 | result_ref 177 | ) 178 | else: 179 | # If no scheduler is set, use the original callback 180 | ray.get(ray.put(request_embeddings, _owner=request.callback_ref)) 181 | 182 | # clean up request records 183 | # if request.query_id in self.query_requests: 184 | # self.query_requests[request.query_id].remove(request) 185 | # if not self.query_requests[request.query_id]: 186 | # del self.query_requests[request.query_id] 187 | 188 | start_idx = end_idx 189 | except Exception as e: 190 | print(f"Error processing individual request: {e}") 191 | continue 192 | 193 | except Exception as e: 194 | print(f"Error computing embeddings: {e}") 195 | continue 196 | 197 | except Exception as e: 198 | print(f"Error in inference task: {e}") 199 | continue 200 | 201 | async def _get_next_batch(self) -> Tuple[List[EmbeddingRequest], List[str]]: 202 | """Get next batch of requests to process""" 203 | batch_requests = [] 204 | batch_texts = [] 205 | processed_queries = set() 206 | 207 | 208 | 209 | #while len(batch_texts) < self.max_batch_size: 210 | while len(batch_texts) ==0: 211 | try: 212 | request = await asyncio.wait_for( 213 | self.request_queue.get(), 214 | timeout=0.01 215 | ) 216 | 217 | if request.query_id in processed_queries: 218 | pending_requests = self.query_requests[request.query_id] 219 | for pending_req in pending_requests: 220 | if len(batch_texts) + len(pending_req.texts) <= self.max_batch_size: 221 | batch_requests.append(pending_req) 222 | batch_texts.extend(pending_req.texts) 223 | else: 224 | if len(batch_texts) + len(request.texts) <= self.max_batch_size: 225 | batch_requests.append(request) 226 | batch_texts.extend(request.texts) 227 | processed_queries.add(request.query_id) 228 | else: 229 | await self.request_queue.put(request) 230 | break 231 | 232 | except asyncio.TimeoutError: 233 | break 234 | 235 | return batch_requests, batch_texts 236 | 237 | def _compute_embeddings(self, texts: List[str]) -> np.ndarray: 238 | """Compute embeddings for a batch of texts""" 239 | #self.model.encode(["hello","world"]) 240 | with torch.no_grad(): 241 | 242 | assert isinstance(texts,list) or isinstance(texts,str) 243 | 244 | begin = time.time() 245 | 246 | batch_size = len(texts) if isinstance(texts,list) else 1 247 | embeddings = self.model.encode(texts,batch_size=batch_size,show_progress_bar=False) 248 | #embeddings = embeddings[:,:768] 249 | 250 | logger.debug(f"texts' type: {type(texts)}, len: {len(texts)}, embeddings shape: {embeddings.shape}") 251 | end = time.time() 252 | logger.debug(f"embedding time for {len(texts)} texts: {end - begin}") 253 | return embeddings 254 | 255 | async def shutdown(self): 256 | 
"""Shutdown the service""" 257 | self.running = False 258 | for task in self.tasks: 259 | task.cancel() 260 | try: 261 | await task 262 | except asyncio.CancelledError: 263 | pass 264 | 265 | 266 | 267 | 268 | -------------------------------------------------------------------------------- /Ayo/engines/engine_types.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from typing import Dict, Type, Optional, Any, Union 3 | from dataclasses import dataclass 4 | from Ayo.engines.base_engine import BaseEngine 5 | 6 | class EngineType(str, Enum): 7 | """Supported engine types""" 8 | INPUT = "input" 9 | OUTPUT = "output" 10 | EMBEDDER = "embedder" 11 | VECTOR_DB = "vector_db" 12 | RERANKER = "reranker" 13 | LLM = "llm" 14 | AGGREGATOR = "aggregator" 15 | DUMMY = "dummy" 16 | 17 | @classmethod 18 | def list(cls) -> list: 19 | """Get list of all engine types""" 20 | return list(cls) 21 | 22 | @classmethod 23 | def validate(cls, engine_type: str) -> bool: 24 | """Validate if engine type is supported""" 25 | return engine_type in cls.__members__.values() 26 | 27 | @dataclass 28 | class EngineSpec: 29 | """Engine specifications""" 30 | engine_class: Any # 使用Any代替具体类型 31 | default_config: Dict 32 | description: str 33 | 34 | class EngineRegistry: 35 | """Central registry for engine types and specifications""" 36 | 37 | def __init__(self): 38 | self._registry: Dict[str, EngineSpec] = {} 39 | self._register_default_engines() 40 | 41 | def _register_default_engines(self): 42 | """Register built-in engines""" 43 | # Lazy import to avoid circular imports 44 | from Ayo.engines.base_engine import BaseEngine 45 | from Ayo.engines.embedder import EmbeddingEngine 46 | from Ayo.engines.vector_db import VectorDBEngine 47 | from Ayo.engines.reranker import RerankerEngine 48 | from Ayo.engines.llm import LLMEngine 49 | from Ayo.engines.aggregator import AggregateEngine 50 | 51 | self.register( 52 | EngineType.INPUT, 53 | EngineSpec( 54 | engine_class=BaseEngine, 55 | default_config={}, 56 | description="The dummy engine for input node" 57 | ) 58 | ) 59 | 60 | self.register( 61 | EngineType.OUTPUT, 62 | EngineSpec( 63 | engine_class=BaseEngine, 64 | default_config={}, 65 | description="The dummy engine for output node" 66 | ) 67 | ) 68 | 69 | self.register( 70 | EngineType.EMBEDDER, 71 | EngineSpec( 72 | engine_class=EmbeddingEngine, 73 | default_config={ 74 | "model_name": "BAAI/bge-large-en-v1.5", 75 | "max_batch_size": 1024, 76 | 'vector_dim' :1024 77 | }, 78 | description="Text embedding engine using BGE model" 79 | ) 80 | ) 81 | 82 | self.register( 83 | EngineType.VECTOR_DB, 84 | EngineSpec( 85 | engine_class=VectorDBEngine, 86 | default_config={ 87 | "host": "localhost", 88 | "user": "asplos25", 89 | "password": "123456", 90 | "database": "database_asplos", 91 | "port": 5432, 92 | "max_batch_size": 1000 93 | }, 94 | description="Vector database engine using pgvector" 95 | ) 96 | ) 97 | 98 | self.register( 99 | EngineType.RERANKER, 100 | EngineSpec( 101 | engine_class=RerankerEngine, 102 | default_config={ 103 | "model_name": "BAAI/bge-reranker-large", 104 | "max_batch_size": 512 105 | }, 106 | description="Cross-encoder reranking engine" 107 | ) 108 | ) 109 | 110 | self.register( 111 | EngineType.LLM, 112 | EngineSpec( 113 | engine_class=LLMEngine, 114 | default_config={ 115 | "model_name": "meta-llama/Llama-2-7b-chat-hf", 116 | "tensor_parallel_size": 1, 117 | "max_num_seqs": 256, 118 | "max_queue_size": 1000, 119 | "trust_remote_code": False, 120 
| "dtype": "auto" 121 | }, 122 | description="Large language model engine" 123 | ) 124 | ) 125 | 126 | self.register( 127 | EngineType.AGGREGATOR, 128 | EngineSpec( 129 | engine_class=AggregateEngine, 130 | default_config={ 131 | "max_batch_size": 32, 132 | "max_queue_size": 1000 133 | }, 134 | description="Data Aggregation Engine, supports multiple aggregation modes" 135 | ) 136 | ) 137 | 138 | def register(self, engine_type: str, spec: EngineSpec) -> None: 139 | """Register a new engine type""" 140 | if engine_type in self._registry: 141 | raise ValueError(f"Engine type {engine_type} already registered") 142 | self._registry[engine_type] = spec 143 | 144 | def unregister(self, engine_type: str) -> None: 145 | """Unregister an engine type""" 146 | if engine_type not in self._registry: 147 | raise ValueError(f"Engine type {engine_type} not registered") 148 | del self._registry[engine_type] 149 | 150 | def get_spec(self, engine_type: str) -> Optional[EngineSpec]: 151 | """Get engine specifications""" 152 | return self._registry.get(engine_type) 153 | 154 | def get_engine_class(self, engine_type: str) -> Optional[Type[BaseEngine]]: 155 | """Get engine class for given type""" 156 | spec = self.get_spec(engine_type) 157 | return spec.engine_class if spec else None 158 | 159 | def get_default_config(self, engine_type: str) -> Optional[Dict]: 160 | """Get default configuration for engine type""" 161 | spec = self.get_spec(engine_type) 162 | return spec.default_config if spec else None 163 | 164 | def list_engines(self) -> Dict[str, str]: 165 | """List all registered engines and their descriptions""" 166 | return { 167 | engine_type: spec.description 168 | for engine_type, spec in self._registry.items() 169 | } 170 | 171 | # Global engine registry instance 172 | ENGINE_REGISTRY = EngineRegistry() -------------------------------------------------------------------------------- /Ayo/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | import time 5 | from datetime import datetime 6 | from typing import Optional, Dict, Any, Union 7 | import threading 8 | 9 | class AyoLogger: 10 | 11 | # ANSI color codes 12 | COLORS = { 13 | 'DEBUG': '\033[36m', # cyan 14 | 'INFO': '\033[32m', # green 15 | 'WARNING': '\033[33m', # yellow 16 | 'ERROR': '\033[31m', # red 17 | 'CRITICAL': '\033[35m', # purple 18 | 'RESET': '\033[0m' # reset 19 | } 20 | 21 | # singleton instance 22 | _instance = None 23 | _lock = threading.Lock() 24 | 25 | def __new__(cls, *args, **kwargs): 26 | with cls._lock: 27 | if cls._instance is None: 28 | cls._instance = super(AyoLogger, cls).__new__(cls) 29 | cls._instance._initialized = False 30 | return cls._instance 31 | 32 | def __init__(self, 33 | name: str = "ayo", 34 | level: str = "INFO", 35 | log_file: Optional[str] = None, 36 | use_colors: bool = True, 37 | log_format: Optional[str] = None): 38 | """ 39 | initialize the logger 40 | 41 | Args: 42 | name: logger name 43 | level: logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL) 44 | log_file: log file path, if None, only output to console 45 | use_colors: whether to use colors in console output 46 | log_format: custom log format, if None, use the default format 47 | """ 48 | if self._initialized: 49 | return 50 | 51 | self.name = name 52 | self.level = getattr(logging, level.upper()) 53 | self.log_file = log_file 54 | self.use_colors = use_colors 55 | 56 | # create the logger 57 | self.logger = logging.getLogger(name) 58 | 
self.logger.setLevel(self.level) 59 | self.logger.handlers = [] # clear the existing handlers 60 | 61 | # set the default log format 62 | if log_format is None: 63 | self.log_format = "%(asctime)s - %(levelname)s - %(name)s - [%(filename)s:%(lineno)d] - %(message)s" 64 | else: 65 | self.log_format = log_format 66 | 67 | # create the formatter 68 | self.formatter = logging.Formatter(self.log_format, datefmt="%Y-%m-%d %H:%M:%S") 69 | 70 | # add the console handler 71 | self._add_console_handler() 72 | 73 | # if the log file is specified, add the file handler 74 | if self.log_file: 75 | self._add_file_handler() 76 | 77 | self._initialized = True 78 | 79 | def _add_console_handler(self): 80 | """add the console handler""" 81 | console_handler = logging.StreamHandler(sys.stdout) 82 | console_handler.setLevel(self.level) 83 | 84 | if self.use_colors: 85 | # use the colored formatter 86 | colored_formatter = self._get_colored_formatter() 87 | console_handler.setFormatter(colored_formatter) 88 | else: 89 | console_handler.setFormatter(self.formatter) 90 | 91 | self.logger.addHandler(console_handler) 92 | 93 | def _add_file_handler(self): 94 | """add the file handler""" 95 | # ensure the log directory exists 96 | log_dir = os.path.dirname(self.log_file) 97 | if log_dir and not os.path.exists(log_dir): 98 | os.makedirs(log_dir) 99 | 100 | file_handler = logging.FileHandler(self.log_file) 101 | file_handler.setLevel(self.level) 102 | file_handler.setFormatter(self.formatter) 103 | self.logger.addHandler(file_handler) 104 | 105 | def _get_colored_formatter(self): 106 | """create the colored formatter""" 107 | class ColoredFormatter(logging.Formatter): 108 | def __init__(self, fmt, datefmt=None): 109 | super().__init__(fmt, datefmt) 110 | self.colors = AyoLogger.COLORS 111 | 112 | def format(self, record): 113 | levelname = record.levelname 114 | if levelname in self.colors: 115 | record.levelname = f"{self.colors[levelname]}{levelname}{self.colors['RESET']}" 116 | record.msg = f"{self.colors[levelname.upper()]}{record.msg}{self.colors['RESET']}" 117 | return super().format(record) 118 | 119 | return ColoredFormatter(self.log_format, datefmt="%Y-%m-%d %H:%M:%S") 120 | 121 | def set_level(self, level: str): 122 | """set the logging level""" 123 | level_upper = level.upper() 124 | if hasattr(logging, level_upper): 125 | self.level = getattr(logging, level_upper) 126 | self.logger.setLevel(self.level) 127 | for handler in self.logger.handlers: 128 | handler.setLevel(self.level) 129 | 130 | def debug(self, msg: str, *args, **kwargs): 131 | """record the DEBUG level log""" 132 | kwargs.setdefault('stacklevel', 2) 133 | self.logger.debug(msg, *args, **kwargs) 134 | 135 | def info(self, msg: str, *args, **kwargs): 136 | """record the INFO level log""" 137 | kwargs.setdefault('stacklevel', 2) 138 | self.logger.info(msg, *args, **kwargs) 139 | 140 | def warning(self, msg: str, *args, **kwargs): 141 | """record the WARNING level log""" 142 | kwargs.setdefault('stacklevel', 2) 143 | self.logger.warning(msg, *args, **kwargs) 144 | 145 | def error(self, msg: str, *args, **kwargs): 146 | """record the ERROR level log""" 147 | kwargs.setdefault('stacklevel', 2) 148 | self.logger.error(msg, *args, **kwargs) 149 | 150 | def critical(self, msg: str, *args, **kwargs): 151 | """record the CRITICAL level log""" 152 | kwargs.setdefault('stacklevel', 2) 153 | self.logger.critical(msg, *args, **kwargs) 154 | 155 | def exception(self, msg: str, *args, **kwargs): 156 | """record the exception information""" 157 | 
kwargs.setdefault('stacklevel', 2) 158 | self.logger.exception(msg, *args, **kwargs) 159 | 160 | def log_dict(self, level: str, data: Dict[str, Any], prefix: str = ""): 161 | """record the dictionary data""" 162 | level_method = getattr(self.logger, level.lower()) 163 | for key, value in data.items(): 164 | if prefix: 165 | key = f"{prefix}.{key}" 166 | if isinstance(value, dict): 167 | self.log_dict(level, value, key) 168 | else: 169 | level_method(f"{key}: {value}") 170 | 171 | 172 | 173 | 174 | 175 | def get_logger(name: str = "ayo", 176 | level: str = "INFO", 177 | log_file: Optional[str] = None, 178 | use_colors: bool = True) -> AyoLogger: 179 | """ 180 | get the Ayo logger instance 181 | 182 | Args: 183 | name: logger name 184 | level: logging level 185 | log_file: log file path 186 | use_colors: whether to use colored output 187 | 188 | Returns: 189 | AyoLogger instance 190 | """ 191 | return AyoLogger(name=name, level=level, log_file=log_file, use_colors=use_colors) 192 | 193 | 194 | GLOBAL_INFO_LEVEL = os.environ.get("AYO_INFO_LEVEL", "INFO") 195 | 196 | default_logger = AyoLogger(level=GLOBAL_INFO_LEVEL) 197 | 198 | debug = default_logger.debug 199 | info = default_logger.info 200 | warning = default_logger.warning 201 | error = default_logger.error 202 | critical = default_logger.critical 203 | exception = default_logger.exception 204 | -------------------------------------------------------------------------------- /Ayo/modules/base_module.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Any 2 | from Ayo.dags.node import Node 3 | from Ayo.dags.node_commons import NodeIOSchema 4 | 5 | class BaseModule: 6 | """ 7 | Base class for all modules 8 | """ 9 | 10 | def __init__(self, 11 | input_format: Dict[str, Any] = None, 12 | output_format: Dict[str, Any] = None, 13 | config: Dict[str, Any] = None): 14 | """ 15 | Initialize the module 16 | 17 | Args: 18 | input_format: Input format definition 19 | output_format: Output format definition 20 | config: Module configuration parameters 21 | """ 22 | self.input_format = input_format or {} 23 | self.output_format = output_format or {} 24 | self.config = config or {} 25 | 26 | self.pre_dependencies = [] 27 | self.post_dependencies = [] 28 | 29 | 30 | def __rshift__(self, other): 31 | self.post_dependencies.append(other) 32 | other.pre_dependencies.append(self) 33 | return other # return the downstream module so a >> b >> c builds a chain 34 | 35 | 36 | def to_primitive_nodes(self) -> List[Node]: 37 | """ 38 | Convert the module to a list of primitive nodes 39 | 40 | Returns: 41 | List[Node]: List of primitive nodes 42 | """ 43 | raise NotImplementedError("Subclasses must implement the to_primitive_nodes method") 44 | 45 | def validate_io_schema(self) -> bool: 46 | """ 47 | Validate the input and output format 48 | 49 | Returns: 50 | bool: Validation result 51 | """ 52 | # Default implementation, subclasses can override this method for more detailed validation 53 | return len(self.input_format) > 0 and len(self.output_format) > 0 54 | 55 | 56 | def __str__(self) -> str: 57 | """Return the string representation of the module""" 58 | return f"{self.__class__.__name__}(input={self.input_format}, output={self.output_format})" 59 | 60 | def __repr__(self) -> str: 61 | """Return the detailed string representation of the module""" 62 | return f"{self.__class__.__name__}(input={self.input_format}, output={self.output_format}, config={self.config})" --------------------------------------------------------------------------------
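To make the dependency bookkeeping concrete, a small sketch using a hypothetical no-op subclass (not part of the repo):

from Ayo.modules.base_module import BaseModule

class StubModule(BaseModule):
    def to_primitive_nodes(self):
        return []

a, b, c = StubModule(), StubModule(), StubModule()
a >> b >> c  # __rshift__ returns the right-hand module, so this wires a->b and b->c
assert b.pre_dependencies == [a] and b.post_dependencies == [c]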
/Ayo/modules/embedding.py: -------------------------------------------------------------------------------- 1 | from Ayo.dags.node import Node 2 | from Ayo.dags.node_commons import NodeType, NodeOps, NodeIOSchema 3 | from Ayo.engines.engine_types import EngineType 4 | from Ayo.modules.base_module import BaseModule 5 | 6 | 7 | class EmbeddingModule(BaseModule): 8 | def __init__(self, 9 | input_format: dict={ 10 | "text": list[str] 11 | }, 12 | output_format: dict={ 13 | "embeddings": list 14 | }, 15 | config: dict=None): 16 | """Initialize the Embedding Module. 17 | 18 | This module is responsible for converting text into vector embeddings 19 | using an embedding model. 20 | 21 | Args: 22 | input_format (dict): Input format definition, defaults to: 23 | - text (list[str]): List of text strings to be embedded 24 | output_format (dict): Output format definition, defaults to: 25 | - embeddings (list): List of vector embeddings 26 | config (dict, optional): Configuration parameters for the embedding process 27 | """ 28 | super().__init__(input_format, output_format, config) 29 | 30 | def to_primitive_nodes(self): 31 | return [ 32 | Node( 33 | name="Embedding", 34 | io_schema=NodeIOSchema( 35 | input_format=self.input_format, 36 | output_format=self.output_format 37 | ), 38 | op_type=NodeOps.EMBEDDING, 39 | engine_type=EngineType.EMBEDDER, 40 | node_type=NodeType.COMPUTE, 41 | config=self.config 42 | ) 43 | ] 44 | -------------------------------------------------------------------------------- /Ayo/modules/indexing.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from Ayo.dags.node import Node 3 | from Ayo.dags.node_commons import NodeType, NodeOps, NodeIOSchema 4 | from Ayo.engines.engine_types import EngineType 5 | from Ayo.modules.base_module import BaseModule 6 | 7 | class IndexingModule(BaseModule): 8 | def __init__(self, 9 | input_format: dict={ 10 | "passages": List[str] 11 | }, 12 | output_format: dict={ 13 | "index_status": bool 14 | }, 15 | config: dict=None): 16 | """Initialize the Indexing Module. 17 | 18 | This module is responsible for embedding passages and ingesting them into a vector database. 19 | It creates an index that can be used for vector similarity search.
20 | 21 | Args: 22 | input_format (dict): Input format definition, defaults to: 23 | - passages (List[str]): List of text passages to be indexed 24 | output_format (dict): Output format definition, defaults to: 25 | - index_status (bool): Status indicating whether indexing was successful 26 | config (dict, optional): Configuration parameters for the indexing process 27 | """ 28 | super().__init__(input_format, output_format, config) 29 | 30 | def to_primitive_nodes(self): 31 | # create embedding node 32 | embedding_node = Node( 33 | name="EmbeddingForIndex", 34 | io_schema=NodeIOSchema( 35 | input_format={"passages": List[str]}, 36 | output_format={"passages_embeddings": List[float]} 37 | ), 38 | op_type=NodeOps.EMBEDDING, 39 | engine_type=EngineType.EMBEDDER, 40 | node_type=NodeType.COMPUTE, 41 | config=self.config 42 | ) 43 | 44 | # create ingestion node 45 | ingestion_node = Node( 46 | name="IngestionForIndex", 47 | io_schema=NodeIOSchema( 48 | input_format={"passages": List[str], "passages_embeddings": List[float]}, 49 | output_format={"index_status": bool} 50 | ), 51 | op_type=NodeOps.VECTORDB_INGESTION, 52 | engine_type=EngineType.VECTOR_DB, 53 | node_type=NodeType.COMPUTE, 54 | config=self.config 55 | ) 56 | 57 | # connect nodes 58 | embedding_node >> ingestion_node 59 | 60 | return [embedding_node, ingestion_node] 61 | -------------------------------------------------------------------------------- /Ayo/modules/llm_syhthesizing.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from Ayo.dags.node import Node 3 | from Ayo.dags.node_commons import NodeType, NodeOps, NodeIOSchema 4 | from Ayo.engines.engine_types import EngineType 5 | from Ayo.modules.base_module import BaseModule 6 | from enum import Enum 7 | 8 | class LLMGenerationMode(Enum): 9 | NORMAL = "normal" 10 | SUMMARIZATION = "summarization" 11 | REFINEMENT = "refinement" 12 | 13 | class LLMSynthesizingModule(BaseModule): 14 | RAG_PROMPT_TEMPLATE = """\ 15 | You are an AI assistant specialized in Retrieval-Augmented Generation (RAG). Your responses 16 | must be based strictly on the retrieved documents provided to you. Follow these guidelines: 17 | 1. Use Retrieved Information Only - Your responses must rely solely on the retrieved documents. 18 | If the retrieved documents do not contain relevant information, explicitly state: 'Based on the 19 | available information, I cannot determine the answer.'\n" 20 | 2. Response Formatting - Directly answer the question using the retrieved data. If multiple 21 | sources provide information, synthesize them in a coherent manner. If no relevant information 22 | is found, clearly state that.\n" 23 | 3. Clarity and Precision - Avoid speculative language such as 'I think' or 'It might be.' 24 | Maintain a neutral and factual tone.\n" 25 | 4. Information Transparency - Do not fabricate facts or sources. If needed, summarize the 26 | retrieved information concisely.\n" 27 | 5. Handling Out-of-Scope Queries - If a question is outside the retrieved data (e.g., opinions, 28 | unverifiable claims), state: 'The retrieved documents do not provide information on this topic.'\n 29 | ---\n 30 | Example Interactions:\n 31 | User Question: Who founded Apple Inc.?\n 32 | Retrieved Context: 'Apple Inc. was co-founded in 1976 by Steve Jobs, Steve Wozniak, and Ronald Wayne.'\n 33 | Model Answer: 'Apple Inc. 
was co-founded in 1976 by Steve Jobs, Steve Wozniak, and Ronald Wayne.'\n 34 | ---\n 35 | User Question: When was the first iPhone released, and what were its key features?\n" 36 | Retrieved Context: 'The first iPhone was announced by Steve Jobs on January 9, 2007, and released on June 29, 2007.' " 37 | "'The original iPhone featured a 3.5-inch touchscreen display, a 2-megapixel camera, and ran on iOS.'\n" 38 | Model Answer: 'The first iPhone was announced on January 9, 2007, and released on June 29, 2007. " 39 | "It featured a 3.5-inch touchscreen display, a 2-megapixel camera, and ran on iOS.'\ 40 | This ensures accuracy, reliability, and transparency in all responses. And you should directly answer the question based on the retrieved context and keep it concise as possible. 41 | Here is the question: {question}? 42 | Here is the retrieved context: {context} 43 | Here is your answer: 44 | """ 45 | 46 | SUMMARIZATION_PROMPT_TEMPLATE = """\ 47 | You are an AI assistant specialized in Question Answering. You would be provided with a question and several candidate answers. 48 | Your task is to summarize the candidate answers into a single answer. You should keep the original meaning of the question and the candidate answers. 49 | You should select the most relevant answer from the candidate answers and summarize it. Always need to keep your answer concise and to the point. 50 | Here is the question: {question}? 51 | Here are the candidate answers: {answers} 52 | Here is your answer: 53 | """ 54 | 55 | def __init__(self, 56 | input_format: dict={ 57 | "question": str, 58 | "context": List[str] 59 | }, 60 | output_format: dict={ 61 | "answer": str}, 62 | config: dict={ 63 | "generation_mode":LLMGenerationMode.NORMAL, 64 | 'prompt_template': RAG_PROMPT_TEMPLATE, 65 | 'parse_json': True, 66 | 'prompt':RAG_PROMPT_TEMPLATE, 67 | 'partial_output': False, 68 | 'partial_prefilling': False, 69 | 'llm_partial_decoding_idx': -1 70 | }): 71 | """Initialize the LLM Synthesizing Module. 72 | 73 | This module is responsible for generating answers using a Large Language Model (LLM) 74 | based on retrieved context. It supports multiple generation modes: normal generation, 75 | summarization, and refinement. 
76 | 77 | Args: 78 | input_format (dict): Input format definition, defaults to: 79 | - question (str): User question 80 | - context (List[str]): List of retrieved context documents 81 | output_format (dict): Output format definition, defaults to: 82 | - answer (str): Generated response 83 | config (dict): Configuration parameters, including: 84 | - generation_mode (LLMGenerationMode): Generation mode, can be NORMAL, SUMMARIZATION, or REFINEMENT 85 | - prompt_template (str): Prompt template string 86 | - parse_json (bool): Whether to parse JSON output 87 | - prompt (str): Complete prompt string 88 | - partial_output (bool): Whether to enable partial output 89 | - partial_prefilling (bool): Whether to enable partial prefilling 90 | - llm_partial_decoding_idx (int): Partial decoding index 91 | 92 | Notes: 93 | - Summarization mode (SUMMARIZATION) requires 'context_num' to be specified in config 94 | - Refinement mode (REFINEMENT) also requires 'context_num' to be specified in config 95 | """ 96 | 97 | super().__init__(input_format, output_format, config) 98 | 99 | # TODO: support the below generation modes 100 | if config["generation_mode"]==LLMGenerationMode.SUMMARIZATION: 101 | assert 'context_num' in config, "context_num is required for summarization" 102 | elif config["generation_mode"]==LLMGenerationMode.REFINEMENT: 103 | assert 'context_num' in config, "context_num is required for refinement" 104 | 105 | 106 | def to_primitive_nodes(self): 107 | if self.config["generation_mode"]==LLMGenerationMode.NORMAL: 108 | return [ 109 | Node( 110 | name="LLMSynthesizingPrefilling", 111 | io_schema=NodeIOSchema(input_format=self.input_format, 112 | output_format=self.output_format), 113 | node_type=NodeType.COMPUTE, 114 | engine_type=EngineType.LLM, 115 | op_type=NodeOps.LLM_PREFILLING, 116 | config=self.config 117 | ), 118 | Node( 119 | name="LLMSynthesizingDecoding", 120 | io_schema=NodeIOSchema(input_format=self.input_format, 121 | output_format=self.output_format), 122 | node_type=NodeType.COMPUTE, 123 | engine_type=EngineType.LLM, 124 | op_type=NodeOps.LLM_DECODING, 125 | config=self.config 126 | ) 127 | ] 128 | 129 | elif self.config["generation_mode"]==LLMGenerationMode.SUMMARIZATION: 130 | return [ 131 | Node( 132 | name="LLMSynthesizingPrefilling", 133 | ) 134 | ] 135 | 136 | elif self.config["generation_mode"]==LLMGenerationMode.REFINEMENT: 137 | return [ 138 | Node( 139 | name="LLMSynthesizingPrefilling", 140 | ) 141 | ] 142 | 143 |
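The NORMAL branch above deliberately emits two primitive nodes, one LLM_PREFILLING and one LLM_DECODING, so that optimization passes such as opt_pass/prefilling_split.py and opt_pass/decoding_pipeling.py can schedule the two phases independently. A minimal sketch of what the split yields, assuming the module's default config (NORMAL mode) and a working Node implementation:

from Ayo.modules.llm_syhthesizing import LLMSynthesizingModule

module = LLMSynthesizingModule()  # defaults to LLMGenerationMode.NORMAL
prefill_node, decode_node = module.to_primitive_nodes()
print(prefill_node.op_type, decode_node.op_type)
# expected: NodeOps.LLM_PREFILLING and NodeOps.LLM_DECODING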
-------------------------------------------------------------------------------- /Ayo/modules/mod_to_prim.py: -------------------------------------------------------------------------------- 1 | from Ayo.modules.base_module import BaseModule 2 | from Ayo.modules.indexing import IndexingModule 3 | from Ayo.modules.query_expanding import QueryExpandingModule 4 | from Ayo.modules.searching import SearchingModule 5 | from Ayo.modules.reranking import RerankingModule 6 | from typing import List 7 | 8 | def transform_mod_to_prim(mods: List[BaseModule]): 9 | """ 10 | Transform a chain of modules to a list of primitive nodes 11 | """ 12 | mods_2_nodes={} 13 | for mod in mods: 14 | mods_2_nodes[mod]=mod.to_primitive_nodes() 15 | 16 | for mod in mods: 17 | for post_mod in mod.post_dependencies: 18 | mods_2_nodes[mod][-1]>>mods_2_nodes[post_mod][0] 19 | 20 | node_list=[] 21 | for mod in mods: 22 | node_list.extend(mods_2_nodes[mod]) 23 | return node_list 24 | 25 | 26 | if __name__=="__main__": 27 | indexing_module = IndexingModule(input_format={"passages": List[str]}, output_format={"index_status": bool}) 28 | query_expanding_module=QueryExpandingModule(input_format={"query": str}, output_format={"expanded_queries": List[str]},config={"expanded_query_num": 3}) 29 | searching_module = SearchingModule(input_format={"index_status": bool, "expanded_queries": List[str]}, output_format={"searching_results": List[str]}) 30 | reranking_module=RerankingModule(input_format={"searching_results": List[str]}, output_format={"reranking_results": List[str]}) 31 | 32 | 33 | indexing_module>>query_expanding_module>>searching_module>>reranking_module 34 | 35 | 36 | node_list=transform_mod_to_prim([indexing_module,query_expanding_module,searching_module,reranking_module]) 37 | 38 | 39 | print(node_list) 40 | 41 | -------------------------------------------------------------------------------- /Ayo/modules/prompt_template.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | QUERY_EXPANDING_PROMPT_TEMPLATE_STRING = """\ 4 | Please rewrite the following question into {refine_question_number} more refined ones. \ 5 | You should keep the original meaning of the question, but make it more suitable and clear for context retrieval. \ 6 | The original question is: {question}? \ 7 | Please output your answer in json format. \ 8 | It should contain {refine_question_number} new refined questions.\ 9 | For example, if the expanded number is 3, the json output should be like this: \ 10 | {{\ 11 | "revised question1": "[refined question 1]",\ 12 | "revised question2": "[refined question 2]",\ 13 | "revised question3": "[refined question 3]"\ 14 | }}\ 15 | You just need to output the json string, do not output any other information or additional text!!! \ 16 | The json output:""" 17 |
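Since the template above is a plain str.format template, where the doubled braces escape the literal JSON example, a minimal rendering sketch (the question text is made up):

from Ayo.modules.prompt_template import QUERY_EXPANDING_PROMPT_TEMPLATE_STRING

prompt = QUERY_EXPANDING_PROMPT_TEMPLATE_STRING.format(
    refine_question_number=3,
    question="What is the capital of France",  # illustrative input
)
print(prompt)  # {{ and }} survive as literal braces; the two placeholders are substituted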
" 41 | "'The original iPhone featured a 3.5-inch touchscreen display, a 2-megapixel camera, and ran on iOS.'\n" 42 | Model Answer: 'The first iPhone was announced on January 9, 2007, and released on June 29, 2007. " 43 | "It featured a 3.5-inch touchscreen display, a 2-megapixel camera, and ran on iOS.'\ 44 | This ensures accuracy, reliability, and transparency in all responses. And you should directly answer the question based on the retrieved context and keep it concise as possible. 45 | Here is the question: {question}? 46 | Here is the retrieved context: {context} 47 | Here is your answer, make sure it is concise: 48 | """ 49 | 50 | def replace_placeholders(prompt_template: str, **kwargs): 51 | for key,value in kwargs.items(): 52 | prompt_template = prompt_template.replace(f"{{{key}}}", f"{{{value}}}") 53 | print(prompt_template) 54 | return prompt_template 55 | 56 | #TODO: Currently, these classes have been actually used in the modules and the related payload-transformations 57 | # We should do these to make the prompt-template-transformation more flexible and reusable. 58 | 59 | class PromptTemplate: 60 | def __init__(self, prompt_template: str): 61 | # the template is a string with placeholders like {key} 62 | self.prompt_template = prompt_template 63 | 64 | self.placeholders = re.findall(r'\{([^{}]+)\}', self.prompt_template) 65 | 66 | def fill_template(self, **kwargs): 67 | raise NotImplementedError("This method should be implemented by the subclass") 68 | 69 | 70 | class QueryExpandingPromptTemplate(PromptTemplate): 71 | 72 | default_template = """\ 73 | Please rewrite the following question into {refine_question_number} more refined one. \ 74 | You should keep the original meaning of the question, but make it more suitable and clear for context retrieval. \ 75 | The original question is: {question}? \ 76 | Please output your answer in json format. \ 77 | It should contain {refine_question_number} new refined questions.\ 78 | For example, if the expaned number is 3, the json output should be like this: \ 79 | {{\ 80 | "revised question1": "[refined question 1]",\ 81 | "revised question2": "[refined question 2]",\ 82 | "revised question3": "[refined question 3]"\ 83 | }}\ 84 | You just need to output the json string, do not output any other information or additional text!!! 
The json output:"""
86 | 
87 |     def __init__(self, prompt_template: str = None):
88 |         if prompt_template is None:
89 |             prompt_template = self.default_template
90 |         super().__init__(prompt_template)
91 | 
92 |         # special case: here we do not need to check the placeholders
93 |         self.placeholders = re.findall(r'\{([^{}]+)\}', self.prompt_template)
94 | 
95 |     def fill_template(self, **kwargs):
96 |         refine_question_number = None
97 |         question = None
98 |         for key, value in kwargs.items():
99 |             if 'num' in key.lower():
100 |                 refine_question_number = value
101 |             elif key.lower() in ['query', 'question', 'question_']:
102 |                 question = value
103 | 
104 |         assert refine_question_number is not None, "refine_question_number is required"
105 |         assert question is not None, "question is required"
106 | 
107 |         keys = ", ".join([f"question{i+1}" for i in range(refine_question_number)])
108 |         json_example = "{\n  " + "\n  ".join([f"\"question{i+1}\": \"[refined version {i+1}]\"" + ("," if i < refine_question_number-1 else "") for i in range(refine_question_number)]) + "\n  }"
109 | 
110 |         return self.prompt_template.format(
111 |             refine_question_number=refine_question_number,
112 |             question=question,
113 |             keys=keys,
114 |             json_example=json_example
115 |         )
116 | 
117 | 
118 | class RAGQuestionAnsweringPromptTemplate(PromptTemplate):
119 | 
120 |     default_template = """\
121 | You are an AI assistant specialized in Retrieval-Augmented Generation (RAG). Your responses
122 | must be based strictly on the retrieved documents provided to you. Follow these guidelines:
123 | 1. Use Retrieved Information Only - Your responses must rely solely on the retrieved documents.
124 | If the retrieved documents do not contain relevant information, explicitly state: 'Based on the
125 | available information, I cannot determine the answer.'
126 | 2. Response Formatting - Directly answer the question using the retrieved data. If multiple
127 | sources provide information, synthesize them in a coherent manner. If no relevant information
128 | is found, clearly state that.
129 | 3. Clarity and Precision - Avoid speculative language such as 'I think' or 'It might be.'
130 | Maintain a neutral and factual tone.
131 | 4. Information Transparency - Do not fabricate facts or sources. If needed, summarize the
132 | retrieved information concisely.
133 | 5. Handling Out-of-Scope Queries - If a question is outside the retrieved data (e.g., opinions,
134 | unverifiable claims), state: 'The retrieved documents do not provide information on this topic.'
135 | ---
136 | Example Interactions:
137 | User Question: Who founded Apple Inc.?
138 | Retrieved Context: 'Apple Inc. was co-founded in 1976 by Steve Jobs, Steve Wozniak, and Ronald Wayne.'
139 | Model Answer: 'Apple Inc. was co-founded in 1976 by Steve Jobs, Steve Wozniak, and Ronald Wayne.'
140 | ---
141 | User Question: When was the first iPhone released, and what were its key features?
142 | Retrieved Context: 'The first iPhone was announced by Steve Jobs on January 9, 2007, and released on June 29, 2007.'
143 | 'The original iPhone featured a 3.5-inch touchscreen display, a 2-megapixel camera, and ran on iOS.'
144 | Model Answer: 'The first iPhone was announced on January 9, 2007, and released on June 29, 2007.
145 | It featured a 3.5-inch touchscreen display, a 2-megapixel camera, and ran on iOS.'
146 | This ensures accuracy, reliability, and transparency in all responses. And you should directly answer the question based on the retrieved context and keep it as concise as possible.
147 | Here is the question: {question}?
148 | Here is the retrieved context: {context}
149 | Here is your answer:
150 | """
151 |     def __init__(self, prompt_template: str = None):
152 | 
153 |         if prompt_template is None:
154 |             prompt_template = self.default_template
155 |         super().__init__(prompt_template)
156 | 
157 | # TODO: add more prompt templates
158 | 
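As a usage sketch of the class above: `QueryExpandingPromptTemplate.fill_template` matches keyword arguments by name, so any kwarg whose key contains `num` supplies the question count and a `query`/`question` kwarg supplies the text. The values below are illustrative.

```python
from Ayo.modules.prompt_template import QueryExpandingPromptTemplate

template = QueryExpandingPromptTemplate()  # falls back to default_template
prompt = template.fill_template(
    expanded_query_num=3,                       # picked up because the key contains "num"
    question="What is the capital of France?",  # picked up by the "question" key
)
print(prompt)  # the default template filled in, asking for 3 refined questions in JSON
```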
--------------------------------------------------------------------------------
/Ayo/modules/query_expanding.py:
--------------------------------------------------------------------------------
1 | import uuid
2 | from Ayo.dags.node import Node
3 | from Ayo.dags.node_commons import NodeType, NodeOps, NodeIOSchema
4 | from Ayo.engines.engine_types import EngineType
5 | from typing import List
6 | from Ayo.modules.base_module import BaseModule
7 | 
8 | class QueryExpandingModule(BaseModule):
9 | 
10 |     prompt_template = """Please rewrite the following question into {refine_question_number} more refined ones. \
11 | You should keep the original meaning of the question, but make it more suitable and clear for context retrieval. \
12 | The original question is: {question}? \
13 | Please output your answer in json format. \
14 | It should contain {refine_question_number} new refined questions.\
15 | For example, if the expanded number is 3, the json output should be like this: \
16 | {{\
17 | "revised question1": "[refined question 1]",\
18 | "revised question2": "[refined question 2]",\
19 | "revised question3": "[refined question 3]"\
20 | }}\
21 | You just need to output the json string, do not output any other information or additional text!!! \
22 | The json output:"""
23 | 
24 |     # temperature=0.9, top_p=0.95 for llama2-7b-chat-hf
25 | 
26 |     # Config keys read downstream by the LLM engine:
27 |     #   support_partial_output = node.config.get('partial_output', False)
28 |     #   support_partial_prefilling = node.config.get('partial_prefilling', False)
29 |     #   llm_partial_decoding_idx = node.config.get("llm_partial_decoding_idx", -1)
30 | 
31 |     def __init__(self,
32 |                  input_format: dict = {
33 |                      "query": str
34 |                  },
35 |                  output_format: dict = {
36 |                      "expanded_queries": List[str]
37 |                  },
38 |                  config: dict = {
39 |                      'expanded_query_num': 3,
40 |                      'prompt_template': prompt_template,
41 |                      'parse_json': True,
42 |                      'prompt': prompt_template,
43 |                      'partial_output': False,
44 |                      'partial_prefilling': False,
45 |                      'llm_partial_decoding_idx': -1
46 |                  }):
47 |         """Initialize the Query Expanding Module.
48 | 
49 |         This module is responsible for expanding a single query into multiple refined queries
50 |         using a language model, which can improve retrieval performance by capturing different
51 |         aspects of the original query.
54 | 55 | Args: 56 | input_format (dict): Input format definition, defaults to: 57 | - query (str): Original user query 58 | output_format (dict): Output format definition, defaults to: 59 | - expanded_queries (List[str]): List of expanded/refined queries 60 | config (dict): Configuration parameters, including: 61 | - expanded_query_num (int): Number of queries to generate 62 | - prompt_template (str): Template for query expansion prompt 63 | - parse_json (bool): Whether to parse JSON output 64 | - prompt (str): Complete prompt string 65 | - partial_output (bool): Whether to enable partial output 66 | - partial_prefilling (bool): Whether to enable partial prefilling 67 | - llm_partial_decoding_idx (int): Partial decoding index 68 | """ 69 | super().__init__(input_format, output_format, config) 70 | 71 | def to_primitive_nodes(self): 72 | # create LLM prefilling node 73 | 74 | llm_internal_id = f"query_expanding_{uuid.uuid4()}" 75 | 76 | llm_prefilling_node = Node( 77 | name="QueryExpandingPrefilling", 78 | io_schema=NodeIOSchema( 79 | input_format={"query": str}, 80 | output_format={"prefill_state": dict} 81 | ), 82 | op_type=NodeOps.LLM_PREFILLING, 83 | engine_type=EngineType.LLM, 84 | node_type=NodeType.COMPUTE, 85 | config={ 86 | 'prompt_template': self.config.get('prompt_template', self.prompt_template), 87 | 'prompt': self.config.get('prompt', self.prompt_template), 88 | 'expanded_query_num': self.config.get('expanded_query_num', 3), 89 | 'parse_json': self.config.get('parse_json', True), 90 | 'partial_output': self.config.get('partial_output', False), 91 | 'partial_prefilling': self.config.get('partial_prefilling', False), 92 | 'llm_partial_decoding_idx': self.config.get('llm_partial_decoding_idx', -1), 93 | 'llm_internal_id': llm_internal_id 94 | } 95 | ) 96 | 97 | # create LLM decoding node 98 | llm_decoding_node = Node( 99 | name="QueryExpandingDecoding", 100 | io_schema=NodeIOSchema( 101 | input_format={"query": str, "prefill_state": dict}, 102 | output_format={"expanded_queries": List[str]} 103 | ), 104 | op_type=NodeOps.LLM_DECODING, 105 | engine_type=EngineType.LLM, 106 | node_type=NodeType.COMPUTE, 107 | config={ 108 | 'prompt_template': self.config.get('prompt_template', self.prompt_template), 109 | 'prompt': self.config.get('prompt', self.prompt_template), 110 | 'expanded_query_num': self.config.get('expanded_query_num', 3), 111 | 'parse_json': self.config.get('parse_json', True), 112 | 'partial_output': self.config.get('partial_output', False), 113 | 'partial_prefilling': self.config.get('partial_prefilling', False), 114 | 'llm_partial_decoding_idx': self.config.get('llm_partial_decoding_idx', -1), 115 | 'llm_internal_id': llm_internal_id 116 | } 117 | ) 118 | 119 | # Connect the nodes 120 | llm_prefilling_node >> llm_decoding_node 121 | 122 | return [llm_prefilling_node, llm_decoding_node] 123 | 124 | def format_prompt(self, question): 125 | refine_question_number = self.config.get('expanded_query_num', None) 126 | 127 | assert refine_question_number is not None, "expanded_query_num is not set" 128 | #keys = ", ".join([f"question{i+1}" for i in range(refine_question_number)]) 129 | #json_example = "{\n " + "\n ".join([f"\"question{i+1}\": \"[refined version {i+1}]\"" + ("," if i < refine_question_number-1 else "") for i in range(refine_question_number)]) + "\n }" 130 | 131 | return self.prompt_template.format( 132 | refine_question_number=refine_question_number, 133 | question=question, 134 | ) 135 | 136 | 137 | if __name__ == "__main__": 138 | query_expanding_module = 
QueryExpandingModule()
139 |     question = "What is the capital of France?"
140 |     print(query_expanding_module.format_prompt(question))
--------------------------------------------------------------------------------
/Ayo/modules/reranking.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | from Ayo.dags.node import Node
3 | from Ayo.dags.node_commons import NodeType, NodeOps, NodeIOSchema
4 | from Ayo.engines.engine_types import EngineType
5 | from Ayo.modules.base_module import BaseModule
6 | 
7 | class RerankingModule(BaseModule):
8 |     def __init__(self,
9 |                  input_format: dict = {
10 |                      "query": str,
11 |                      "passages": List[str]
12 |                  },
13 |                  output_format: dict = {
14 |                      "passages": List[str]
15 |                  },
16 |                  config: dict = {
17 |                      'top_k': 10
18 |                  }):
19 |         """Initialize the Reranking Module.
20 | 
21 |         This module is responsible for reranking the passages based on the query.
22 | 
23 |         Args:
24 |             input_format (dict): Input format definition, defaults to:
25 |                 - query (str): The query used to rerank the passages
26 |                 - passages (List[str]): The passages to rerank
27 |             output_format (dict): Output format definition, defaults to:
28 |                 - passages (List[str]): The reranked passages
29 |             config (dict, optional): Configuration parameters, including:
30 |                 - top_k (int): Number of top passages to keep
31 |         """
32 | 
33 |         super().__init__(input_format, output_format, config)
34 | 
35 |     def to_primitive_nodes(self):
36 |         return [
37 |             Node(
38 |                 name="Reranking",
39 |                 io_schema=NodeIOSchema(
40 |                     input_format=self.input_format,
41 |                     output_format=self.output_format
42 |                 ),
43 |                 op_type=NodeOps.RERANKING,
44 |                 engine_type=EngineType.RERANKER,
45 |                 node_type=NodeType.COMPUTE,
46 |                 config=self.config
47 |             )
48 |         ]
49 | 
--------------------------------------------------------------------------------
/Ayo/modules/searching.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | from Ayo.dags.node import Node
3 | from Ayo.dags.node_commons import NodeType, NodeOps, NodeIOSchema
4 | from Ayo.engines.engine_types import EngineType
5 | from Ayo.modules.base_module import BaseModule
6 | 
7 | class SearchingModule(BaseModule):
8 |     '''
9 |     The searching module is used to calculate the embedding for the query and search the vector database for the most relevant documents.
10 |     '''
11 |     def __init__(self,
12 |                  input_format: dict = {
13 |                      "queries": List[str],
14 |                      'index_status': bool
15 |                  },
16 |                  output_format: dict = {
17 |                      'search_results': List[str]
18 |                  },
19 |                  config: dict = {
20 |                      'top_k': 3
21 |                  }):
22 | 
23 |         """Initialize the Searching Module.
24 | 
25 |         This module is responsible for calculating embeddings for queries and searching
26 |         the vector database for the most relevant documents based on embedding similarity.
27 | 28 | Args: 29 | input_format (dict): Input format definition, defaults to: 30 | - queries (List[str]): List of queries to search for 31 | - index_status (bool): Status indicating whether the index is ready 32 | output_format (dict): Output format definition, defaults to: 33 | - search_results (List[str]): List of retrieved documents 34 | config (dict): Configuration parameters, including: 35 | - top_k (int): Number of top results to return per query 36 | """ 37 | super().__init__(input_format, output_format, config) 38 | 39 | def to_primitive_nodes(self): 40 | 41 | query_embedd_input_key=None 42 | index_input_key=None 43 | for key in self.input_format.keys(): 44 | if 'query' in key or 'queries' in key or 'expanded_queries' in key or 'passages' in key or 'expanded_passages' in key or 'expanded_query' in key: 45 | query_embedd_input_key=key 46 | 47 | elif 'index' in key or 'index_status' in key: 48 | index_input_key=key 49 | 50 | 51 | query_embedding_node = Node( 52 | name="QueryEmbedding", 53 | io_schema=NodeIOSchema( 54 | input_format={query_embedd_input_key: self.input_format[query_embedd_input_key]}, 55 | output_format={"queries_embeddings": List[float]} 56 | ), 57 | op_type=NodeOps.EMBEDDING, 58 | engine_type=EngineType.EMBEDDER, 59 | node_type=NodeType.COMPUTE, 60 | config={} 61 | ) 62 | 63 | vector_db_searching_node = Node( 64 | name="VectorDBSearching", 65 | io_schema=NodeIOSchema( 66 | input_format={ 67 | "queries_embeddings": List[float], 68 | index_input_key: self.input_format[index_input_key] 69 | }, 70 | output_format=self.output_format 71 | ), 72 | op_type=NodeOps.VECTORDB_SEARCHING, 73 | engine_type=EngineType.VECTOR_DB, 74 | node_type=NodeType.COMPUTE, 75 | config={ 76 | "top_k": self.config["top_k"], 77 | } 78 | ) 79 | 80 | query_embedding_node >> vector_db_searching_node 81 | 82 | return [query_embedding_node, vector_db_searching_node] 83 | 84 | 85 | -------------------------------------------------------------------------------- /Ayo/opt_pass/base_pass.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Dict, List, Optional, Any, TYPE_CHECKING 3 | from Ayo.logger import get_logger, GLOBAL_INFO_LEVEL 4 | 5 | logger = get_logger(__name__, level=GLOBAL_INFO_LEVEL) 6 | 7 | if TYPE_CHECKING: 8 | from Ayo.dags.dag import DAG 9 | from Ayo.dags.node import Node 10 | 11 | class OPT_Pass(ABC): 12 | """Base class for all optimization passes 13 | 14 | An optimization pass takes a DAG as input, performs specific optimizations, 15 | and returns the optimized DAG. Each pass should focus on a specific type 16 | of optimization (e.g., pruning dependencies, batching, splitting). 17 | """ 18 | 19 | def __init__(self, name: str): 20 | """Initialize the optimization pass 21 | 22 | Args: 23 | name: Unique identifier for this optimization pass 24 | """ 25 | self.name = name 26 | self.enabled = True 27 | self.config: Dict[str, Any] = {} 28 | 29 | @abstractmethod 30 | def run(self, dag: 'DAG') -> 'DAG': 31 | """Execute the optimization pass on the given DAG 32 | 33 | Args: 34 | dag: Input DAG to optimize 35 | 36 | Returns: 37 | Optimized DAG 38 | 39 | This method must be implemented by all concrete optimization passes. 
40 | """ 41 | pass 42 | 43 | def configure(self, **kwargs) -> None: 44 | """Configure the optimization pass 45 | 46 | Args: 47 | **kwargs: Configuration parameters specific to this pass 48 | """ 49 | self.config.update(kwargs) 50 | 51 | def enable(self) -> None: 52 | """Enable this optimization pass""" 53 | self.enabled = True 54 | 55 | def disable(self) -> None: 56 | """Disable this optimization pass""" 57 | self.enabled = False 58 | 59 | def is_enabled(self) -> bool: 60 | """Check if this pass is enabled""" 61 | return self.enabled 62 | 63 | def get_config(self, key: str, default: Any = None) -> Any: 64 | """Get configuration value 65 | 66 | Args: 67 | key: Configuration key 68 | default: Default value if key doesn't exist 69 | """ 70 | return self.config.get(key, default) 71 | 72 | def validate_dag(self, dag: 'DAG') -> bool: 73 | """Validate DAG before optimization 74 | 75 | Args: 76 | dag: DAG to validate 77 | 78 | Returns: 79 | True if DAG is valid for this optimization 80 | """ 81 | return True 82 | 83 | def get_applicable_nodes(self, dag: 'DAG') -> List['Node']: 84 | """Get nodes that this pass can optimize 85 | 86 | Args: 87 | dag: Input DAG 88 | 89 | Returns: 90 | List of nodes that can be optimized by this pass 91 | """ 92 | return [] 93 | 94 | def log_optimization(self, message: str) -> None: 95 | """Log optimization information 96 | 97 | Args: 98 | message: Message to log 99 | """ 100 | logger.info(f"[{self.name}] {message}") 101 | 102 | def __str__(self) -> str: 103 | return f"OPT_Pass(name={self.name}, enabled={self.enabled})" 104 | 105 | def __repr__(self) -> str: 106 | return self.__str__() 107 | -------------------------------------------------------------------------------- /Ayo/opt_pass/pass_manager.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional 2 | from Ayo.opt_pass.base_pass import OPT_Pass 3 | from Ayo.dags.node import Node 4 | 5 | class PassManager: 6 | """Manager for handling optimization passes 7 | 8 | Features: 9 | - Pass registration and management 10 | - Pass selection based on node properties 11 | """ 12 | 13 | def __init__(self): 14 | self.passes: Dict[str, OPT_Pass] = {} 15 | 16 | def register_pass(self, opt_pass: OPT_Pass) -> None: 17 | """Register an optimization pass 18 | 19 | Args: 20 | opt_pass: Optimization pass to register 21 | """ 22 | self.passes[opt_pass.name] = opt_pass 23 | 24 | def get_passes(self) -> List[OPT_Pass]: 25 | """Get all registered passes 26 | 27 | Returns: 28 | List of registered optimization passes 29 | """ 30 | return list(self.passes.values()) 31 | 32 | def get_enabled_passes(self) -> List[OPT_Pass]: 33 | """Get all enabled passes 34 | 35 | Returns: 36 | List of enabled optimization passes 37 | """ 38 | return [p for p in self.passes.values() if p.is_enabled()] 39 | 40 | def get_pass(self, name: str) -> Optional[OPT_Pass]: 41 | """Get a specific pass by name 42 | 43 | Args: 44 | name: Name of the pass to retrieve 45 | 46 | Returns: 47 | The requested pass or None if not found 48 | """ 49 | return self.passes.get(name) 50 | 51 | def enable_pass(self, name: str) -> None: 52 | """Enable a specific pass 53 | 54 | Args: 55 | name: Name of the pass to enable 56 | """ 57 | if name in self.passes: 58 | self.passes[name].enable() 59 | 60 | def disable_pass(self, name: str) -> None: 61 | """Disable a specific pass 62 | 63 | Args: 64 | name: Name of the pass to disable 65 | """ 66 | if name in self.passes: 67 | self.passes[name].disable() 
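Putting `OPT_Pass` and `PassManager` together, a custom pass only has to implement `run`. A minimal sketch follows; the no-op pass below is illustrative and not part of Ayo.

```python
from Ayo.opt_pass.base_pass import OPT_Pass
from Ayo.opt_pass.pass_manager import PassManager


class NoOpPass(OPT_Pass):
    """Illustrative pass: logs the node count and returns the DAG unchanged."""

    def __init__(self):
        super().__init__(name="NoOpPass")

    def run(self, dag):
        self.log_optimization(f"visited {len(dag.nodes)} nodes")
        return dag


manager = PassManager()
manager.register_pass(NoOpPass())
manager.disable_pass("NoOpPass")          # passes can be toggled by name
assert manager.get_enabled_passes() == []  # excluded while disabled
```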
-------------------------------------------------------------------------------- /Ayo/opt_pass/pruning_dependency.py: -------------------------------------------------------------------------------- 1 | from typing import List, TYPE_CHECKING 2 | from Ayo.dags.node_commons import NodeType 3 | from Ayo.opt_pass.base_pass import OPT_Pass 4 | from Ayo.logger import get_logger, GLOBAL_INFO_LEVEL 5 | 6 | logger = get_logger(__name__, level=GLOBAL_INFO_LEVEL) 7 | 8 | if TYPE_CHECKING: 9 | from Ayo.dags.dag import DAG 10 | from Ayo.dags.node import Node 11 | 12 | class PruningDependencyPass(OPT_Pass): 13 | """Optimization Pass: Clean up invalid dependencies in DAG""" 14 | 15 | def __init__(self): 16 | super().__init__(name="PruningDependencyPass") 17 | 18 | def run(self, dag: 'DAG') -> 'DAG': 19 | """Execute the optimization pass""" 20 | # save the reference of DAG 21 | self.dag = dag 22 | 23 | # Ensure topological sort is up to date 24 | dag._ensure_topo_sort() 25 | 26 | # Check and add missing parent node connections 27 | self._add_missing_connections(dag) 28 | 29 | # Identify essential parents for each node based on topological structure 30 | self._identify_essential_parents(dag) 31 | 32 | # Finally prune invalid connections 33 | for node in dag.nodes: 34 | self._prune_node_dependencies(node) 35 | 36 | # clear the reference of DAG, avoid memory leak 37 | self.dag = None 38 | 39 | return dag 40 | 41 | def get_applicable_nodes(self, dag: 'DAG') -> List['Node']: 42 | """Get nodes that can be pruned""" 43 | return [node for node in dag.nodes if len(node.parents) > 0] 44 | 45 | def validate_dag(self, dag: 'DAG') -> bool: 46 | """Validate if DAG can be pruned""" 47 | return len(dag.nodes) > 0 48 | 49 | def _prune_node_dependencies(self, node: 'Node') -> None: 50 | """Clean up invalid dependencies for a single node""" 51 | invalid_parents = [] 52 | 53 | for parent in node.parents: 54 | if not self._has_valid_connection(parent, node): 55 | invalid_parents.append(parent) 56 | 57 | for parent in invalid_parents: 58 | self._remove_connection(parent, node) 59 | 60 | def _has_valid_connection(self, parent: 'Node', child: 'Node') -> bool: 61 | """Check if there is valid data flow between parent and child nodes""" 62 | if parent.name in child.input_key_from_parents: 63 | output_key = child.input_key_from_parents[parent.name] 64 | if (output_key in parent.output_names and 65 | output_key in child.input_names): 66 | return True 67 | return False 68 | 69 | def _remove_connection(self, parent: 'Node', child: 'Node') -> None: 70 | """Remove connection between two nodes""" 71 | if child in parent.children: 72 | parent.children.remove(child) 73 | 74 | if parent in child.parents: 75 | child.parents.remove(parent) 76 | 77 | if parent.name in child.input_key_from_parents: 78 | del child.input_key_from_parents[parent.name] 79 | 80 | # update the in_degree information of DAG 81 | if hasattr(self, 'dag'): 82 | if child in self.dag.in_degree: 83 | self.dag.in_degree[child] -= 1 84 | # mark the topological sort need to be recalculated 85 | self.dag._mark_topo_dirty() 86 | 87 | def _add_missing_connections(self, dag: 'DAG') -> None: 88 | """Check and add missing parent node connections""" 89 | # preprocess: create the mapping from output name to node 90 | output_providers = {} 91 | for node in dag.topo_list: 92 | for output_name in node.output_names: 93 | if output_name not in output_providers: 94 | output_providers[output_name] = [] 95 | output_providers[output_name].append(node) 96 | 97 | # process the nodes in topological 
order 98 | for node in dag.topo_list: 99 | if node.node_type == NodeType.INPUT: 100 | continue 101 | 102 | # check each input needed by the node 103 | for input_name in node.input_names: 104 | # if this input has no provider 105 | # here we assume the output name from different nodes are unique 106 | logger.info(f"{node.name}, {input_name}, {node.input_key_from_parents.values()}") 107 | if (input_name not in node.input_key_from_parents.values()): 108 | logger.warning(f"input_name: {input_name} not in node.input_key_from_parents.values(): {node.input_key_from_parents.values()}") 109 | # find the node that can provide this input 110 | potential_parents = output_providers.get(input_name, []) 111 | for potential_parent in potential_parents: 112 | if (potential_parent != node and 113 | potential_parent not in node.parents): 114 | # add the connection 115 | node.add_parent(potential_parent) 116 | node.input_key_from_parents[potential_parent.name] = input_name 117 | break 118 | 119 | def _identify_essential_parents(self, dag: 'DAG') -> None: 120 | """Identify essential parents for each node based on topological structure""" 121 | # process the nodes in topological order 122 | for node in dag.topo_list: 123 | if node.node_type == NodeType.INPUT or len(node.parents) <= 1: 124 | continue 125 | 126 | self._find_essential_parents(node) 127 | 128 | def _find_essential_parents(self, node: 'Node') -> None: 129 | """Find essential parents for a node, removing redundant parents""" 130 | # Record each input provider 131 | input_providers = {} # input name -> best parent node 132 | redundant_parents = [] 133 | 134 | # ensure we know each parent's input 135 | for parent in node.parents: 136 | if parent.name in node.input_key_from_parents: 137 | output_key = node.input_key_from_parents[parent.name] 138 | 139 | # if this input has no provider, record current parent 140 | if output_key not in input_providers: 141 | input_providers[output_key] = parent 142 | else: 143 | # there is already a node providing this input 144 | # here we can implement the logic to select the best parent node, for example: 145 | # 1. select the node with earlier topological order (may result in an earlier availability) 146 | # 2. 
select the node based on other criteria 147 | # default: keep the first encountered parent node 148 | redundant_parents.append(parent) 149 | 150 | # remove the redundant parent node connections 151 | for parent in redundant_parents: 152 | self._remove_connection(parent, node) 153 | 154 | 155 | if __name__ == "__main__": 156 | from Ayo.dags.dag import DAG 157 | from Ayo.dags.node import Node 158 | from Ayo.dags.node_commons import NodeType, NodeIOSchema, NodeAnnotation 159 | from Ayo.engines.engine_types import EngineType 160 | from typing import Any, Dict 161 | 162 | dag = DAG() 163 | 164 | dag.set_query_inputs({"query": "What is the capital of France?", "passages": ["Paris is the capital of France.", "Paris is the capital of France.", "Paris is the capital of France."]}) 165 | 166 | embedding_node = Node( 167 | name="Embedding", 168 | node_type=NodeType.COMPUTE, 169 | engine_type=EngineType.EMBEDDER, 170 | io_schema=NodeIOSchema( 171 | input_format={"passages": List[str]}, 172 | output_format={"embeddings_passages": List[Any]} 173 | ), 174 | anno=NodeAnnotation.BATCHABLE, 175 | config={ 176 | } 177 | ) 178 | 179 | reranker_node = Node( 180 | name="Reranker", 181 | node_type=NodeType.COMPUTE, 182 | engine_type=EngineType.RERANKER, 183 | io_schema=NodeIOSchema( 184 | input_format={ 185 | "query": str, 186 | "passages": List[str] 187 | }, 188 | output_format={ 189 | "ranked_results": List[str] 190 | } 191 | ), 192 | anno=NodeAnnotation.BATCHABLE, 193 | config={ 194 | } 195 | ) 196 | 197 | llm_node = Node( 198 | name="LLM", 199 | node_type=NodeType.COMPUTE, 200 | engine_type=EngineType.LLM, 201 | io_schema=NodeIOSchema( 202 | input_format={"query": str, "ranked_results": List[str]}, 203 | output_format={"answer": str} 204 | ), 205 | anno=NodeAnnotation.BATCHABLE, 206 | config={ 207 | } 208 | ) 209 | 210 | embedding_node>>reranker_node >> llm_node 211 | 212 | dag.register_nodes(embedding_node, reranker_node, llm_node) 213 | 214 | print(dag.get_full_dag_nodes_info()) 215 | 216 | dag.optimize([PruningDependencyPass()]) 217 | 218 | 219 | print(dag.get_full_dag_nodes_info()) 220 | 221 | -------------------------------------------------------------------------------- /Ayo/opt_pass/test_dag_llm_decoding_pipeling.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NetX-lab/Ayo/61607bcb31ed98e1101c5d60d8573105bf60d87c/Ayo/opt_pass/test_dag_llm_decoding_pipeling.pdf -------------------------------------------------------------------------------- /Ayo/opt_pass/test_dag_prefilling_split.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NetX-lab/Ayo/61607bcb31ed98e1101c5d60d8573105bf60d87c/Ayo/opt_pass/test_dag_prefilling_split.pdf -------------------------------------------------------------------------------- /Ayo/queries/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NetX-lab/Ayo/61607bcb31ed98e1101c5d60d8573105bf60d87c/Ayo/queries/__init__.py -------------------------------------------------------------------------------- /Ayo/queries/query.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional 2 | from enum import Enum 3 | import time 4 | from dataclasses import dataclass, field 5 | import ray 6 | from Ayo.dags.dag import DAG 7 | from Ayo.queries.query_state import QueryStates, QueryStatus 8 | 9 | 10 | class 
Query: 11 | """Base class for all query types in the system""" 12 | def __init__( 13 | self, 14 | uuid: str, 15 | query_id: str, 16 | query_inputs: Dict[str, Any], 17 | DAG: DAG, 18 | context: Optional[Dict] = None, 19 | uploaded_file: Optional[Any] = None, 20 | timeout: float = 30.0 21 | # here some attributes maybe not used... to clear some in the future 22 | ): 23 | # Basic query information 24 | self.uuid = uuid 25 | self.query_id = query_id 26 | self.query_inputs = query_inputs 27 | self.context = context or {} 28 | self.DAG = DAG 29 | self.uploaded_file = uploaded_file 30 | 31 | # Initialize DAG information 32 | self._init_dag() 33 | 34 | # Query state management 35 | self.status = QueryStatus.INIT 36 | self.query_state = QueryStates.remote() 37 | self.created_at = time.time() 38 | self.updated_at = self.created_at 39 | self.start_time: Optional[float] = None 40 | self.end_time: Optional[float] = None 41 | self.timeout = timeout 42 | self.error_message: Optional[str] = None 43 | 44 | # Results storage 45 | self.results: Dict[str, Any] = {} 46 | self.metadata: Dict[str, Any] = {} 47 | 48 | def _init_dag(self) -> None: 49 | """Initialize DAG query related information""" 50 | self.DAG.query_id = self.query_id 51 | self.DAG.set_query_inputs(self.query_inputs) 52 | self.DAG.create_input_nodes() 53 | 54 | def start(self): 55 | """Start query execution""" 56 | self.status = QueryStatus.RUNNING 57 | self.start_time = time.time() 58 | self.updated_at = self.start_time 59 | 60 | def complete(self): 61 | """Mark query as completed""" 62 | self.status = QueryStatus.COMPLETED 63 | self.end_time = time.time() 64 | self.updated_at = self.end_time 65 | 66 | def fail(self, error_message: str): 67 | """Mark query as failed""" 68 | self.status = QueryStatus.FAILED 69 | self.error_message = error_message 70 | self.end_time = time.time() 71 | self.updated_at = self.end_time 72 | 73 | def set_timeout(self): 74 | """Mark query as timed out""" 75 | self.status = QueryStatus.TIMEOUT 76 | self.error_message = "Query execution exceeded timeout" 77 | self.end_time = time.time() 78 | self.updated_at = self.end_time 79 | 80 | def is_timeout(self) -> bool: 81 | """Check if query has exceeded timeout""" 82 | if self.start_time is None: 83 | return False 84 | return (time.time() - self.start_time) > self.timeout 85 | 86 | def get_execution_time(self) -> Optional[float]: 87 | """Get query execution time in seconds""" 88 | if self.start_time is None: 89 | return None 90 | end = self.end_time or time.time() 91 | return end - self.start_time 92 | 93 | def _get_obj_name_recurse(self, name, obj): 94 | """Helper method for attribute access""" 95 | name = name.split(".", maxsplit=1) 96 | recurse = len(name) > 1 97 | next_name = name[1] if recurse else "" 98 | name = name[0] 99 | obj = self if obj is None else obj 100 | return obj, name, next_name, recurse 101 | 102 | def get_status(self) -> QueryStatus: 103 | """Get query status""" 104 | return self.status 105 | 106 | def get_remote_attr(self, __name: str, __obj: object = None): 107 | """Get remote attribute value""" 108 | obj, name, next_name, recurse = self._get_obj_name_recurse(__name, __obj) 109 | next_obj = getattr(obj, name) 110 | if recurse: 111 | next_obj = self.get_remote_attr(next_name, next_obj) 112 | return next_obj 113 | 114 | def set_remote_attr(self, __name: str, __value: Any, __obj: object = None): 115 | """Set remote attribute value""" 116 | obj, name, next_name, recurse = self._get_obj_name_recurse(__name, __obj) 117 | if recurse: 118 | next_obj = 
getattr(obj, name) 119 | self.set_remote_attr(next_name, __value, next_obj) 120 | else: 121 | if hasattr(obj, name): 122 | setattr(obj, name, __value) 123 | 124 | def to_dict(self) -> Dict[str, Any]: 125 | """Convert query to dictionary representation""" 126 | return { 127 | "uuid": self.uuid, 128 | "query_id": self.query_id, 129 | "query_inputs": self.query_inputs, 130 | "status": self.status.value, 131 | "created_at": self.created_at, 132 | "updated_at": self.updated_at, 133 | "execution_time": self.get_execution_time(), 134 | "error_message": self.error_message, 135 | "results": self.results, 136 | "context": self.context, 137 | "metadata": self.metadata 138 | } 139 | 140 | def __str__(self): 141 | return f"Query(uuid={self.uuid}, status={self.status.value}, query_id='{self.query_id}', query_inputs={self.query_inputs})" 142 | 143 | def __repr__(self): 144 | return self.__str__() 145 | -------------------------------------------------------------------------------- /Ayo/queries/query_state.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any, Optional 2 | import time 3 | import ray 4 | from enum import Enum 5 | 6 | class QueryStatus(Enum): 7 | """Query execution states""" 8 | INIT = "init" # Query is initialized 9 | PENDING = "pending" # Query is waiting to be processed 10 | RUNNING = "running" # Query is being processed 11 | COMPLETED = "completed"# Query completed successfully 12 | FAILED = "failed" # Query failed during execution 13 | TIMEOUT = "timeout" # Query exceeded time limit 14 | 15 | @ray.remote 16 | class QueryStates: 17 | """Ray Actor for managing query states and intermediate results 18 | 19 | Handles both: 20 | 1. Global variables and intermediate results for DAG nodes 21 | 2. 
Service-specific results and query states
22 |     """
23 |     def __init__(self):
24 |         # For storing node-specific variables and results
25 |         self.global_var_idx = {}
26 | 
27 |         # For storing query states and metadata
28 |         self.states: Dict[str, Dict[str, Any]] = {}
29 | 
30 |         # For storing node-specific results
31 |         self.node_results: Dict[str, Any] = {}
32 | 
33 |         # For storing service-specific results, keyed by service name then query id
34 |         self.service_results: Dict[str, Dict[str, Any]] = {}
35 | 
36 |     def set_global_var(self, var, node_name):
37 |         """Set global variable for a node"""
38 |         self.global_var_idx[node_name] = var
39 | 
40 |     def get_global_var(self, node_name):
41 |         """Get global variable for a node"""
42 |         if node_name not in self.global_var_idx:
43 |             return None
44 |         return self.global_var_idx[node_name]
45 | 
46 |     def get_node_results(self):
47 |         """Get all node results"""
48 |         return self.node_results
49 | 
50 |     def set_node_result(self, node_name: str, result):
51 |         """Set the result of a specific node"""
52 |         self.node_results[node_name] = result
53 | 
54 |     def get_node_result(self, node_name: str):
55 |         """Get the result of a specific node"""
56 |         if node_name not in self.node_results:
57 |             return None
58 |         return self.node_results[node_name]
59 | 
60 |     def clear_node_result(self, node_name: str):
61 |         """Clear the result of a specific node"""
62 |         self.node_results.pop(node_name, None)
63 | 
64 |     def clear_query(self, query_id: str):
65 |         """Clear all data related to a query"""
66 |         # Clear query state
67 |         self.states.pop(query_id, None)
68 | 
69 |         # Clear any service results
70 |         for service_results in self.service_results.values():
71 |             service_results.pop(query_id, None)
--------------------------------------------------------------------------------
/Ayo/schedulers/engine_scheduler.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Optional, Any
2 | import ray
3 | import asyncio
4 | from abc import ABC, abstractmethod
5 | from dataclasses import dataclass, field
6 | import time
7 | from Ayo.queries.query import Query
8 | from Ayo.configs.config import EngineConfig
9 | from Ayo.dags.node import NodeOps
10 | from Ayo.logger import get_logger, GLOBAL_INFO_LEVEL
11 | 
12 | logger = get_logger(__name__, level=GLOBAL_INFO_LEVEL)
13 | 
14 | 
15 | @dataclass
16 | class EngineRequest:
17 |     """Data class for engine requests"""
18 |     request_id: str  # unique id for each request, in the format {query_id}::{node_name}
19 |     query_id: str
20 |     query: Query
21 |     payload: Any  # Request payload (e.g., texts for embedding), generated by the payload transformer; may differ across engines
22 |     result_ref: Optional[ray.ObjectRef] = None
23 |     timestamp: float = field(default_factory=time.time)  # evaluated per request, not once at class definition
24 | 
25 | 
26 | class BaseEngineScheduler(ABC):
27 |     """Abstract base class for engine schedulers"""
28 | 
29 |     @abstractmethod
30 |     async def submit_request(self, request: EngineRequest):
31 |         """Submit a request to the engine"""
32 |         pass
33 | 
34 |     @abstractmethod
35 |     async def shutdown(self):
36 |         """Shutdown the scheduler and all engine instances"""
37 |         pass
38 | 
39 | class SchedulingStrategy(ABC):
40 |     """Abstract base class for scheduling strategies"""
41 |     @abstractmethod
42 |     def get_next_engine(self, engines: List[Any], current_idx: int) -> 'tuple[Any, int]':
43 |         """Get the next available engine
44 | 
45 |         Args:
46 |             engines: list of engine instances
47 |             current_idx: current engine index
48 | 
49 |         Returns:
50 | 
tuple[engine, new_idx]: selected engine and updated index 52 | """ 53 | pass 54 | 55 | class RoundRobinStrategy(SchedulingStrategy): 56 | """Round-robin scheduling strategy""" 57 | def get_next_engine(self, engines: List[Any], current_idx: int) -> 'tuple[Any, int]': 58 | if not engines: 59 | raise ValueError("No engines available") 60 | engine = engines[current_idx] 61 | new_idx = (current_idx + 1) % len(engines) 62 | return engine, new_idx 63 | 64 | @ray.remote 65 | class EngineScheduler(BaseEngineScheduler): 66 | def __init__(self, 67 | engine_class, 68 | engine_config: EngineConfig, # use EngineConfig as config 69 | **engine_kwargs): 70 | 71 | self.engine_class = engine_class 72 | self.name = engine_config.name 73 | self.num_instances = engine_config.instances 74 | self.num_gpus = engine_config.num_gpus 75 | self.num_cpus = engine_config.num_cpus 76 | self.resources = engine_config.resources 77 | self.engine_kwargs = { 78 | **engine_kwargs, 79 | **(engine_config.model_config or {}) # merge model_config 80 | } 81 | self.engines = [] 82 | self._create_engines() 83 | 84 | # set scheduling strategy, default using round robin 85 | self.scheduling_strategy = RoundRobinStrategy() 86 | 87 | # create request queue 88 | self.request_queue = asyncio.Queue(maxsize=1000) 89 | 90 | 91 | # initialize current engine index 92 | self.current_engine_idx = 0 93 | 94 | 95 | self.pending_requests: Dict[str, EngineRequest] = {} 96 | 97 | # Start processing tasks 98 | self.running = True 99 | self.loop = asyncio.get_event_loop() 100 | self.submit_task = self.loop.create_task(self._submit_requests()) 101 | self.result_task = self.loop.create_task(self._process_results()) 102 | self._is_ready = True # initialization complete flag 103 | 104 | # The queue for completed requests 105 | self.complete_queue = asyncio.Queue(maxsize=1000) 106 | 107 | def _create_engines(self): 108 | """create engine instances based on EngineConfig""" 109 | for i in range(self.num_instances): 110 | # create options dict, only include custom resources in resources 111 | options = {} 112 | if self.resources: 113 | custom_resources = { 114 | k: v for k, v in self.resources.items() 115 | if k.upper() not in ['CPU', 'GPU'] 116 | } 117 | if custom_resources: 118 | options['resources'] = custom_resources 119 | 120 | if self.num_gpus: 121 | options['num_gpus'] = self.num_gpus 122 | if self.num_cpus: 123 | options['num_cpus'] = self.num_cpus 124 | else: 125 | options['num_cpus'] = 1 126 | 127 | 128 | 129 | # directly use the current actor's handle 130 | scheduler_handle = ray.runtime_context.get_runtime_context().current_actor 131 | 132 | # when creating engine instance, pass in scheduler handle 133 | logger.info(f"try to create engine instance with name: {self.name}_{i} for class: {self.engine_class} with options: {options}") 134 | 135 | engine = self.engine_class.options(**options).remote( 136 | name=f"{self.name}_{i}", 137 | scheduler_ref=scheduler_handle, # directly use actor handle 138 | **self.engine_kwargs 139 | ) 140 | 141 | _= ray.get(engine.is_ready.remote()) 142 | self.engines.append(engine) 143 | logger.info(f"created engine instance with name: {self.name}_{i} for class: {self.engine_class}") 144 | 145 | def _get_next_engine(self): 146 | """use scheduling strategy to get next available engine""" 147 | engine, new_idx = self.scheduling_strategy.get_next_engine( 148 | self.engines, 149 | self.current_engine_idx 150 | ) 151 | self.current_engine_idx = new_idx 152 | return engine 153 | 154 | async def add_engine(self): 155 | """add a new 
engine instance dynamically"""
156 |         scheduler_handle = ray.runtime_context.get_runtime_context().current_actor
157 |         engine = self.engine_class.remote(
158 |             name=f"{self.name}_{self.num_instances}",
159 |             scheduler_ref=scheduler_handle,
160 |             **self.engine_kwargs
161 |         )
162 |         self.engines.append(engine)
163 |         self.num_instances += 1
164 |         return len(self.engines)
165 | 
166 |     async def remove_engine(self, index: Optional[int] = None):
167 |         """remove an engine instance
168 | 
169 |         Args:
170 |             index: the index of the engine to remove; if None, remove the last one
171 |         """
172 |         if not self.engines or self.num_instances <= 1:
173 |             raise ValueError("Cannot remove the last engine instance")
174 | 
175 |         if index is None:
176 |             index = len(self.engines) - 1
177 | 
178 |         if 0 <= index < len(self.engines):
179 |             engine = self.engines.pop(index)
180 |             await engine.shutdown.remote()
181 |             self.num_instances -= 1
182 | 
183 |             # adjust the current engine index
184 |             if self.current_engine_idx >= len(self.engines):
185 |                 self.current_engine_idx = 0
186 | 
187 |             return len(self.engines)
188 |         else:
189 |             raise ValueError(f"Invalid engine index: {index}")
190 | 
191 |     async def _submit_requests(self):
192 |         """Task for submitting requests to engines"""
193 |         while self.running:
194 |             try:
195 |                 # short timeout to avoid busy waiting
196 |                 request = await asyncio.wait_for(self.request_queue.get(), timeout=0.05)
197 |                 logger.info(f"Processing request: {request.request_id}")
198 | 
199 |                 # get an available engine
200 |                 engine = self._get_next_engine()
201 | 
202 |                 logger.info(f"submit request {request.request_id} with payload: {request.payload.keys()}")
203 |                 # submit the request to the engine
204 |                 await engine.submit_request.remote(
205 |                     request_id=request.request_id,
206 |                     query_id=request.query_id,
207 |                     **request.payload
208 |                 )
209 | 
210 |                 # track the result
211 |                 self.pending_requests[request.request_id] = request
212 |                 self.request_queue.task_done()
213 | 
214 |             except asyncio.TimeoutError:
215 |                 continue
216 |             except asyncio.CancelledError:
217 |                 break
218 |             except Exception as e:
219 |                 logger.error(f"Error submitting request for {request.request_id}: {e}")
220 |                 await asyncio.sleep(0.001)
221 | 
222 |     async def _process_results(self):
223 |         """Task for processing completed results"""
224 |         while self.running:
225 |             try:
226 |                 request = await self.complete_queue.get()
227 | 
228 |                 if request.result_ref is not None:
229 |                     try:
230 |                         result = await asyncio.get_event_loop().run_in_executor(
231 |                             None,
232 |                             ray.get,
233 |                             request.result_ref
234 |                         )
235 |                         logger.debug(f"result in engine scheduler: {result}")
236 | 
237 |                         query_states = request.query.query_state
238 |                         # Use ray.get to wait for result completion
239 |                         await asyncio.get_event_loop().run_in_executor(
240 |                             None,
241 |                             ray.get,
242 |                             query_states.set_node_result.remote(
243 |                                 request.request_id.split("::")[-1],  # here we assume the request_id is in the format of {query_id}::{node_name} as in the graph scheduler submit_node method
244 |                                 request.result_ref
245 |                             )
246 |                         )
247 | 
248 |                         logger.info(f"Successfully set result for request {request.request_id}")
249 | 
250 |                     except Exception as e:
251 |                         logger.error(f"Error processing result for request {request.request_id}: {e}")
252 | 
253 |                     finally:
254 |                         # clean up the completed request
255 |                         self.complete_queue.task_done()
256 |                         if request.request_id in self.pending_requests:
257 |                             del self.pending_requests[request.request_id]
258 | 
259 |                 await asyncio.sleep(0.01)
260 | 
261 |             except Exception as e:
262 |                 logger.error(f"Error in result processing loop: {e}")
{e}") 264 | await asyncio.sleep(0.01) 265 | 266 | async def submit_request(self, request: EngineRequest): 267 | """Submit a request to the engine""" 268 | #Enqueue request instead of submitting directly 269 | try: 270 | logger.info(f"Enqueued request: {request.request_id}") 271 | await self.request_queue.put(request) 272 | self.pending_requests[request.request_id] = request 273 | logger.info(f"Enqueued request: {request.request_id}") 274 | 275 | except asyncio.QueueFull: 276 | logger.error(f"Request queue full, cannot enqueue {request.request_id}") 277 | raise 278 | 279 | async def on_result(self, request_id: str, query_id: str, result_ref_from_engine: Any): 280 | """Handle result callback from engine""" 281 | if not isinstance(result_ref_from_engine, ray.ObjectRef): 282 | result_ref_from_engine = ray.put(result_ref_from_engine) 283 | try: 284 | if request_id in self.pending_requests: 285 | request = self.pending_requests[request_id] 286 | request.result_ref = result_ref_from_engine 287 | 288 | # put completed request into complete queue 289 | await self.complete_queue.put(request) 290 | logger.info(f"Moved request {request_id} to complete queue") 291 | 292 | # Clean up pending_requests 293 | del self.pending_requests[request_id] 294 | else: 295 | logger.warning(f"Received result for unknown request {request_id}") 296 | 297 | except Exception as e: 298 | logger.error(f"Error processing result callback: {e}") 299 | 300 | async def shutdown(self): 301 | """Shutdown the scheduler and all engine instances""" 302 | self.running = False 303 | 304 | # Cancel processing tasks 305 | for task in [self.submit_task, self.result_task]: 306 | if task: 307 | task.cancel() 308 | try: 309 | await task 310 | except asyncio.CancelledError: 311 | pass 312 | 313 | # Shutdown all engine instances 314 | for engine in self.engines: 315 | await engine.shutdown.remote() 316 | 317 | # Clear engine list 318 | self.engines.clear() 319 | self.num_instances = 0 320 | 321 | async def is_ready(self) -> bool: 322 | """Check if the scheduler is fully initialized 323 | 324 | Returns: 325 | bool: if the scheduler is fully initialized and ready to process requests 326 | """ 327 | # check if all necessary components are initialized 328 | if not self.engines: 329 | return False 330 | 331 | # check if async tasks are running 332 | if not self.submit_task or self.submit_task.done(): 333 | return False 334 | if not self.result_task or self.result_task.done(): 335 | return False 336 | 337 | return self._is_ready 338 | 339 | -------------------------------------------------------------------------------- /Ayo/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | def print_warning(message): 4 | #yellow color 5 | print(f"\033[93m{message}\033[0m") 6 | 7 | def print_key_info(message): 8 | #green color 9 | print(f"\033[92m{message}\033[0m") 10 | 11 | def print_error(message): 12 | #red color 13 | print(f"\033[91m{message}\033[0m") 14 | 15 | 16 | def format_query_expanding_prompt(question, prompt_template, expanded_query_num=3): 17 | keys = ", ".join([f"question{i+1}" for i in range(expanded_query_num)]) 18 | json_example = "{\n " + "\n ".join([f"\"question{i+1}\": \"[refined version {i+1}]\"" + ("," if i < expanded_query_num-1 else "") for i in range(expanded_query_num)]) + "\n }" 19 | 20 | return prompt_template.format( 21 | expanded_query_num=expanded_query_num, 22 | question=question, 23 | keys=keys, 24 | json_example=json_example 25 | ) 26 | 27 | def 
--------------------------------------------------------------------------------
/Ayo/utils.py:
--------------------------------------------------------------------------------
1 | import re
2 | 
3 | def print_warning(message):
4 |     # yellow color
5 |     print(f"\033[93m{message}\033[0m")
6 | 
7 | def print_key_info(message):
8 |     # green color
9 |     print(f"\033[92m{message}\033[0m")
10 | 
11 | def print_error(message):
12 |     # red color
13 |     print(f"\033[91m{message}\033[0m")
14 | 
15 | 
16 | def format_query_expanding_prompt(question, prompt_template, expanded_query_num=3):
17 |     keys = ", ".join([f"question{i+1}" for i in range(expanded_query_num)])
18 |     json_example = "{\n  " + "\n  ".join([f"\"question{i+1}\": \"[refined version {i+1}]\"" + ("," if i < expanded_query_num-1 else "") for i in range(expanded_query_num)]) + "\n  }"
19 | 
20 |     return prompt_template.format(
21 |         expanded_query_num=expanded_query_num,
22 |         question=question,
23 |         keys=keys,
24 |         json_example=json_example
25 |     )
26 | 
27 | def rename_template_placeholders(template, placeholder_mapping):
28 |     """
29 |     Modify the placeholder names in the template string
30 | 
31 |     Args:
32 |         template (str): The template string containing placeholders
33 |         placeholder_mapping (dict): The placeholder mapping, in the format {old_name: new_name}
34 | 
35 |     Returns:
36 |         str: The modified template string
37 |     """
38 |     result = template
39 |     for old_name, new_name in placeholder_mapping.items():
40 |         result = result.replace(f"{{{old_name}}}", f"{{{new_name}}}")
41 |     return result
42 | 
43 | def fill_prompt_template_with_placeholdersname_approximations(prompt_template, input_kwargs):
44 |     """
45 |     Fill the prompt template, handling mismatches between placeholder names and input parameter names
46 | 
47 |     Args:
48 |         prompt_template (str): The prompt template containing placeholders
49 |         input_kwargs (dict): The input parameters dictionary
50 | 
51 |     Returns:
52 |         str: The filled prompt template
53 |     """
54 |     from difflib import get_close_matches
55 | 
56 |     # find all placeholders in the prompt template
57 |     placeholders = re.findall(r'\{([^{}]+)\}', prompt_template)
58 |     result = prompt_template
59 | 
60 |     # cache the input keys to avoid recomputing them per placeholder
61 |     input_keys = list(input_kwargs.keys())
62 | 
63 |     # find the best match for each placeholder
64 |     for placeholder in placeholders:
65 |         if placeholder in input_kwargs:
66 |             # exact match
67 |             value = str(input_kwargs[placeholder])
68 |             result = result.replace(f"{{{placeholder}}}", value)
69 |         else:
70 |             # fall back to approximate matching
71 |             close_matches = get_close_matches(placeholder, input_keys, n=1, cutoff=0.1)
72 | 
73 |             if close_matches:
74 |                 print_warning(f"Using approximate matching: placeholder '{placeholder}' is matched to input parameter '{close_matches[0]}'")
75 |                 value = str(input_kwargs[close_matches[0]])
76 |                 result = result.replace(f"{{{placeholder}}}", value)
77 |             else:
78 |                 print_error(f"No match found for placeholder '{placeholder}'")
79 | 
80 |     return result
81 | 
82 | 
83 | def check_unfilled_placeholders_in_prompt_template(prompt_template):
84 |     """
85 |     Check that the prompt template is complete, i.e., no {placeholder} patterns remain unfilled
86 |     """
87 |     placeholders = re.findall(r'\{([^{}]+)\}', prompt_template)
88 |     if placeholders:
89 |         raise ValueError(f"Prompt template is not complete; unfilled placeholders remain: {placeholders}")
90 |     else:
91 |         return True
92 | 
93 | def check_prompt_template_and_placeholders_match(prompt_template, input_kwargs):
94 |     """
95 |     Check that every placeholder in the prompt template has a matching key in input_kwargs
96 | 
97 |     Args:
98 |         prompt_template (str): The prompt template to check
99 |         input_kwargs (dict): The input kwargs to check
100 | 
101 |     Returns:
102 |         bool: True if every placeholder is covered
103 |     """
104 |     placeholders = re.findall(r'\{([^{}]+)\}', prompt_template)
105 |     for placeholder in placeholders:
106 |         if placeholder not in input_kwargs:
107 |             raise ValueError(f"Placeholder {placeholder} not found in input_kwargs")
108 | 
109 |     return True
110 | 
111 | def fill_prompt_template(prompt_template, input_kwargs):
112 |     """
113 |     Fill the placeholders in the prompt template with the input kwargs
114 |     """
115 |     for key, value in input_kwargs.items():
116 |         prompt_template = prompt_template.replace(f"{{{key}}}", str(value))
117 | 
118 |     # callers can validate the result with check_unfilled_placeholders_in_prompt_template
119 |     return prompt_template
120 | 
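A usage sketch of the placeholder helpers above (values illustrative): `rename_template_placeholders` rewires placeholder names without filling them, while the approximate filler tolerates near-miss kwarg names via `difflib`.

```python
from Ayo.utils import (
    fill_prompt_template_with_placeholdersname_approximations,
    rename_template_placeholders,
)

template = "Answer {question} using {context}."
renamed = rename_template_placeholders(template, {"question": "queries"})
# renamed == "Answer {queries} using {context}."

filled = fill_prompt_template_with_placeholdersname_approximations(
    renamed,
    {"queries": "What is the capital of France?", "contexts": ["Paris is the capital of France."]},
)
# "queries" matches exactly; "contexts" approximately matches {context}
print(filled)
```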
--------------------------------------------------------------------------------
/Ayo/vis/test_dag_node_types.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NetX-lab/Ayo/61607bcb31ed98e1101c5d60d8573105bf60d87c/Ayo/vis/test_dag_node_types.png
10 | 
11 | ## Key Features
12 | - Fine-grained task orchestration for LLM workflows
13 | - Dynamic optimization for performance (e.g., parallelism, pipelining)
14 |   - Dependency Pruning
15 |   - Stage decomposition parallelization
16 |   - LLM Prefilling Splitting
17 |   - LLM Decoding pipelining
18 | - Distributed Two-level Scheduling
19 |   - A graph scheduler for the task-primitive scheduling of each query graph
20 |   - Several distinct engine schedulers for handling different types of engines and managing their different operations
21 | 
22 | ## Quick Start
23 | 
24 | 1. Install dependencies:
25 | 
26 | Install PostgreSQL and pgvector:
27 | ```bash
28 | sudo apt-get install postgresql postgresql-contrib libpq-dev # install postgresql
29 | git clone https://github.com/pgvector/pgvector.git # compile and install pgvector; you can also install it in other ways
30 | cd pgvector
31 | make
32 | sudo make install
33 | sudo -u postgres psql template1 -c "CREATE EXTENSION vector;" # test
34 | ```
35 | 
36 | Install our modified vLLM:
37 | ```bash
38 | git clone --recurse-submodules https://github.com/NetX-lab/Ayo.git # clone the repo and submodules
39 | cd vllm
40 | pip install -e .
41 | ```
42 | 
43 | Install Ayo:
44 | ```bash
45 | cd ..
46 | pip install -r requirements.txt
47 | pip install -e .
48 | ```
49 | 
50 | 2. Define the workflow with Nodes (Task Primitives) and optimize the workflow with Ayo
51 | 
52 | 
53 | <details>
54 | Click to expand the code 55 | 56 | ```python 57 | from Ayo.app import APP 58 | from Ayo.configs.config import EngineConfig 59 | from Ayo.engines.engine_types import EngineType 60 | 61 | app = APP.init() # initialize the app entry 62 | 63 | llm_config = EngineConfig( 64 | name="llm_service", 65 | engine_type=EngineType.LLM, 66 | resources={}, 67 | num_gpus=1, 68 | num_cpus=1, 69 | instances=1, 70 | model_config={ 71 | "model_name": "meta-llama/Llama-2-7b-chat-hf", 72 | "tensor_parallel_size": 1, 73 | #other config ... 74 | }, 75 | latency_profile={ 76 | "timeout": 300, 77 | } 78 | ) 79 | 80 | app.register_engine(llm_config) 81 | #register other engines ... 82 | 83 | 84 | # define the primitive nodes 85 | llm_prefilling_node = Node( 86 | name="LLMPrefilling", 87 | node_type=NodeType.COMPUTE, 88 | engine_type=EngineType.LLM, 89 | io_schema=NodeIOSchema( 90 | input_format={"queries": List[str], "reranked_results": List[List[str]]}, 91 | output_format={"prefill_state": bool} 92 | ), 93 | op_type=NodeOps.LLM_PREFILLING, 94 | config={ 95 | 'prompt_template': replace_placeholders(RAG_QUESTION_ANSWERING_PROMPT_TEMPLATE_STRING, question="queries", context="reranked_results"), 96 | 'parse_json': True, 97 | #other config ... 98 | } 99 | ) 100 | 101 | llm_decoding_node = Node( 102 | name="LLMDecoding", 103 | node_type=NodeType.COMPUTE, 104 | engine_type=EngineType.LLM, 105 | io_schema=NodeIOSchema( 106 | input_format={"prefill_state": bool}, 107 | output_format={"result": str} 108 | ), 109 | op_type=NodeOps.LLM_DECODING, 110 | config={ 111 | 'prompt_template': replace_placeholders(RAG_QUESTION_ANSWERING_PROMPT_TEMPLATE_STRING, question="queries", context="reranked_results"), 112 | 'parse_json': True, 113 | #other config ... 114 | } 115 | ) 116 | #define other nodes ... 117 | 118 | # create the DAG 119 | dag = DAG(dag_id="rag_workflow") 120 | dag.register_nodes(llm_prefilling_node, llm_decoding_node, ...) 121 | # set the query inputs 122 | dag.set_query_inputs( 123 | { 124 | 'queries': ['What is the capital of France?'], ## set the query inputs 125 | } 126 | ) 127 | 128 | from Ayo.opt_pass.pruning_dependency import PruningDependencyPass 129 | from Ayo.opt_pass.stage_decomposition import StageDecompositionPass 130 | from Ayo.opt_pass.prefilling_split import PrefillingSpiltPass 131 | from Ayo.opt_pass.decoding_pipeling import LLMDecodingPipeliningPass 132 | 133 | dag.optimize([PruningDependencyPass(), StageDecompositionPass(), PrefillingSpiltPass(), LLMDecodingPipeliningPass()]) 134 | 135 | query=Query( 136 | uuid=f"random-test-{query_id}", 137 | query_id=f"random-test-{query_id}", 138 | DAG=deepcopy(dag) 139 | ) 140 | 141 | future = await app.submit_query( 142 | query=query, 143 | timeout=300 144 | ) 145 | 146 | result = await asyncio.wait_for(future, timeout=300) 147 | 148 | 149 | ``` 150 | 151 |
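
> The snippet above elides its imports for brevity. A plausible header for a runnable version, mirroring the scripts under `examples/`, is sketched below; `RAG_QUESTION_ANSWERING_PROMPT_TEMPLATE_STRING` and `query_id` are illustrative placeholders you would define yourself, and the `await` calls must run inside an `async` function (see the `run_app` pattern in the example scripts):

```python
import asyncio
from copy import deepcopy
from typing import List

from Ayo.app import APP
from Ayo.configs.config import EngineConfig
from Ayo.dags.dag import DAG
from Ayo.dags.node import Node, NodeIOSchema, NodeOps, NodeType
from Ayo.engines.engine_types import EngineType
from Ayo.modules.prompt_template import replace_placeholders
from Ayo.queries.query import Query

query_id = 0  # illustrative request identifier
```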
152 | 153 | 154 | 3. Define the high-level task modules, then transform and optimize the workflow with Ayo (some parts are still under construction) 155 | 156 | 157 |
158 | Click to expand the code 159 | 160 | ```python 161 | from Ayo.modules import IndexingModule, QueryExpandingModule, SearchingModule, RerankingModule 162 | from Ayo.modules_to_primitives import transform_mod_to_prim 163 | 164 | indexing_module = IndexingModule( 165 | input_format={"passages": List[str]}, 166 | output_format={"index_status": bool} 167 | ) 168 | 169 | query_expanding_module = QueryExpandingModule( 170 | input_format={"query": str}, 171 | output_format={"expanded_queries": List[str]}, 172 | config={"expanded_query_num": 3} 173 | ) 174 | 175 | searching_module = SearchingModule( 176 | input_format={"index_status": bool, "expanded_queries": List[str]}, 177 | output_format={"searching_results": List[str]} 178 | ) 179 | 180 | reranking_module = RerankingModule( 181 | input_format={"searching_results": List[str]}, 182 | output_format={"reranking_results": List[str]} 183 | ) 184 | 185 | 186 | indexing_module>>query_expanding_module>>searching_module>>reranking_module 187 | 188 | 189 | node_list=transform_mod_to_prim([indexing_module,query_expanding_module,searching_module,reranking_module]) 190 | 191 | ### Then optimize the workflow with Ayo as above 192 | 193 | ``` 194 | 195 | 196 | 197 |
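
> In this repo the module-to-primitive transformation lives in `Ayo/modules/mod_to_prim.py`, so the import path above may need adjusting to match your checkout. As a minimal sketch of the equivalent manual route (taken from `examples/modules_to_primitives_indexing_searching.py`), each module can also be lowered and chained by hand:

```python
# Lower each high-level module into its primitive task nodes...
indexing_nodes = indexing_module.to_primitive_nodes()
searching_nodes = searching_module.to_primitive_nodes()

# ...then connect the tail of one module's primitive chain to the head of the next.
indexing_nodes[-1] >> searching_nodes[0]
```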
198 | 199 | 200 | ## Examples 201 | 202 | 203 | Some examples are in the `examples` folder. 204 | 205 | The testbed is a server with 4x NVIDIA 3090 GPUs and 52 cores Intel(R) Xeon(R) Gold 5320 CPU. 206 | 207 | For instance, in file `examples/optimized_embedding_ingestion_searching_reranking_llm.py`, we provide the optimized workflow for the naive RAG workflow with Ayo and the unoptimized workflow is in file `examples/unoptimized_embedding_ingestion_searching_reranking_llm.py`. 208 | 209 | We could see the visualization comparison of the unoptimized (left) and optimized (right) workflow under the same folder. 210 | 211 |
212 | unoptimized workflow 213 | optimized workflow 214 |
215 | 216 | The execution latency is: 217 | 218 | | Workflow Type | Latency | 219 | |---------------|---------| 220 | | Unoptimized | 3.72s | 221 | | Optimized | 1.97s | 222 | 223 | 224 | ## To-Do List 225 | 226 | - [x] Add the multiple-LLM calling example 227 | - [ ] Refine the support for more LLM prompt templates 228 | - [ ] Add the unified multi-request scheduling 229 | 230 | 231 | 232 | 233 | ## Acknowledgements 234 | 235 | We list open-source projects used by us and our modifications to them (if any). 236 | 237 | - [vLLM](https://github.com/vllm-project/vllm) 238 | - [Ray](https://github.com/ray-project/ray) 239 | - [postgresql](https://www.postgresql.org/) 240 | - [pgvector](https://github.com/pgvector/pgvector) 241 | - [sentence-transformers](https://github.com/UKPLab/sentence-transformers) 242 | 243 | 244 | 245 | ## Citation 246 | 247 | If you find this work useful, please cite our paper: 248 | 249 | ```bibtex 250 | @inproceedings{tan2025ayo, 251 | title = {Towards End-to-End Optimization of LLM-based Applications with Ayo}, 252 | author = {Xin Tan and Yimin Jiang and Yitao Yang and Hong Xu}, 253 | booktitle = {Proceedings of the 30th ACM International Conference on Architectural Support for Programming Languages and Operating Systems, Volume 2}, 254 | year = {2025} 255 | } 256 | 257 | ``` 258 | 259 | ## Contact 260 | 261 | If you have any questions or feedback, please email Xin Tan ([xtan22@cse.cuhk.edu.hk](mailto:xtan22@cse.cuhk.edu.hk)). 262 | 263 | 264 | 265 | ## License 266 | 267 | This project is licensed under the MIT License. See the [LICENSE](LICENSE) file for details. 268 | 269 | 270 | 271 | -------------------------------------------------------------------------------- /examples/modules_to_primitives_embedding_ingestion_searching_reranking.py: -------------------------------------------------------------------------------- 1 | import time 2 | from typing import List 3 | from Ayo.modules.indexing import IndexingModule 4 | from Ayo.modules.searching import SearchingModule 5 | from Ayo.modules.query_expanding import QueryExpandingModule 6 | from Ayo.modules.reranking import RerankingModule 7 | from Ayo.dags.dag import DAG 8 | from Ayo.dags.node import Node 9 | from Ayo.dags.node_commons import NodeType, NodeOps, NodeIOSchema 10 | from Ayo.engines.engine_types import EngineType 11 | from Ayo.opt_pass.pruning_dependency import PruningDependencyPass 12 | from Ayo.opt_pass.stage_decomposition import StageDecompositionPass 13 | from Ayo.opt_pass.decoding_pipeling import LLMDecodingPipeliningPass 14 | from Ayo.utils import print_key_info 15 | 16 | 17 | indexing_module = IndexingModule(input_format={"passages": List[str]}, output_format={"index_status": bool}) 18 | query_expanding_module=QueryExpandingModule(input_format={"query": str}, output_format={"expanded_queries": List[str]},config={"expanded_query_num": 3}) 19 | searching_module = SearchingModule(input_format={"index_status": bool, "expanded_queries": List[str]}, output_format={"searching_results": List[str]}) 20 | reranking_module=RerankingModule(input_format={"query": str,"searching_results": List[str]}, output_format={"reranking_results": List[str]}) 21 | 22 | 23 | indexing_nodes = indexing_module.to_primitive_nodes() 24 | query_expanding_nodes=query_expanding_module.to_primitive_nodes() 25 | searching_nodes = searching_module.to_primitive_nodes() 26 | reranking_nodes=reranking_module.to_primitive_nodes() 27 | 28 | 29 | indexing_nodes[-1] >> query_expanding_nodes[0] 30 | 31 | query_expanding_nodes[-1] >> 
searching_nodes[0] 32 | 33 | searching_nodes[-1] >> reranking_nodes[0] 34 | 35 | dag=DAG(dag_id="test_embed_ingest_search_reranking") 36 | 37 | dag.register_nodes(*indexing_nodes,*query_expanding_nodes,*searching_nodes,*reranking_nodes) 38 | 39 | dag.set_query_inputs( 40 | { 41 | 'passages': [ 42 | 'passages1', 43 | 'passages2', 44 | 'passages3', 45 | ]*100, 46 | 'query': 'What is the capital of France?' 47 | } 48 | ) 49 | 50 | print(dag.get_full_dag_nodes_info()) 51 | 52 | begin_time=time.time() 53 | 54 | dag.optimize([PruningDependencyPass()]) 55 | 56 | dag.optimize([StageDecompositionPass()]) 57 | 58 | dag.optimize([LLMDecodingPipeliningPass()]) 59 | 60 | end_time=time.time() 61 | 62 | print_key_info(f"Time taken: {end_time - begin_time} seconds") 63 | 64 | from Ayo.utils import print_key_info 65 | from Ayo.vis.vis_graph import visualize_dag_with_node_types 66 | 67 | visualize_dag_with_node_types(dag, output_path="test_embed_ingest_search_reranking.png") 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | -------------------------------------------------------------------------------- /examples/modules_to_primitives_indexing_searching.py: -------------------------------------------------------------------------------- 1 | import time 2 | from Ayo.modules.indexing import IndexingModule 3 | from Ayo.modules.searching import SearchingModule 4 | from Ayo.dags.dag import DAG 5 | from Ayo.dags.node import Node 6 | from Ayo.dags.node_commons import NodeType, NodeOps, NodeIOSchema 7 | from Ayo.engines.engine_types import EngineType 8 | from Ayo.opt_pass.pruning_dependency import PruningDependencyPass 9 | from Ayo.opt_pass.stage_decomposition import StageDecompositionPass 10 | 11 | 12 | 13 | indexing_module = IndexingModule() 14 | searching_module = SearchingModule() 15 | 16 | 17 | indexing_nodes = indexing_module.to_primitive_nodes() 18 | searching_nodes = searching_module.to_primitive_nodes() 19 | 20 | indexing_nodes[-1] >> searching_nodes[0] 21 | 22 | chained_nodes = indexing_nodes + searching_nodes 23 | print(chained_nodes) 24 | 25 | dag=DAG(dag_id="test_module_to_primitives") 26 | 27 | dag.register_nodes(*chained_nodes) 28 | 29 | dag.set_query_inputs( 30 | { 31 | 'passages': [ 32 | 'passages1', 33 | 'passages2', 34 | 'passages3', 35 | 'passages4', 36 | 'passages5', 37 | ]*100, 38 | 'queries': [ 39 | 'query1', 40 | ] 41 | } 42 | ) 43 | 44 | print(dag.get_full_dag_nodes_info()) 45 | 46 | begin_time = time.time() 47 | dag.optimize([PruningDependencyPass(), StageDecompositionPass()]) 48 | end_time = time.time() 49 | 50 | print(f"\033[91mTime taken: {end_time - begin_time} seconds\033[0m") 51 | 52 | print(dag.get_full_dag_nodes_info()) 53 | 54 | 55 | from Ayo.vis.vis_graph import visualize_dag_with_node_types 56 | 57 | visualize_dag_with_node_types(dag, "test_module_to_primitives.png") 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | -------------------------------------------------------------------------------- /examples/optimized_dag_for_embedding_ingestion_searching_reranking_llm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NetX-lab/Ayo/61607bcb31ed98e1101c5d60d8573105bf60d87c/examples/optimized_dag_for_embedding_ingestion_searching_reranking_llm.png -------------------------------------------------------------------------------- /examples/optimized_embedding_ingestion_searching.py: -------------------------------------------------------------------------------- 1 | from Ayo.app import APP 2 | from Ayo.dags.node import Node, 
NodeAnnotation, NodeType, NodeIOSchema, NodeOps 3 | from Ayo.dags.dag import DAG 4 | from Ayo.queries.query import Query 5 | from Ayo.configs.config import EngineConfig 6 | from Ayo.engines.engine_types import ENGINE_REGISTRY, EngineType 7 | import time 8 | import asyncio 9 | import ray 10 | from copy import deepcopy 11 | from typing import List, Any, Dict 12 | import traceback 13 | 14 | # Create embedding engine config 15 | default_embed_config = ENGINE_REGISTRY.get_default_config(EngineType.EMBEDDER) 16 | default_embed_config.update( 17 | { 18 | "model_name": "BAAI/bge-large-en-v1.5", 19 | "max_batch_size": 1024, 20 | } 21 | ) 22 | embed_config = EngineConfig( 23 | name="embedding_service", 24 | engine_type=EngineType.EMBEDDER, 25 | resources={}, 26 | num_gpus=1, 27 | num_cpus=1, 28 | instances=1, 29 | model_config={ 30 | **default_embed_config, 31 | "device": "cuda" 32 | }, 33 | latency_profile={ 34 | "timeout": 300, 35 | "batch_wait": 0.1 36 | } 37 | ) 38 | 39 | # Create search engine config 40 | vectordb_config = EngineConfig( 41 | name="vector_db_service", 42 | engine_type=EngineType.VECTOR_DB, 43 | resources={}, 44 | num_gpus=0, 45 | num_cpus=4, 46 | instances=1, 47 | model_config={ 48 | "host": "localhost", 49 | "port": 5432, 50 | "user": "asplos25", 51 | "password": "123456", 52 | "database": "database_asplos", 53 | "vector_dim": 1024, 54 | "max_batch_size": 1000 55 | }, 56 | latency_profile={ 57 | "timeout": 60, 58 | "batch_wait": 0.05 59 | } 60 | ) 61 | 62 | 63 | 64 | 65 | 66 | def create_base_dag(): 67 | """Create basic DAG""" 68 | base_dag = DAG(dag_id="random-test") 69 | 70 | # Create embedding node 71 | passages_embedding_node = Node( 72 | name="Embedding", 73 | node_type=NodeType.COMPUTE, 74 | engine_type=EngineType.EMBEDDER, 75 | io_schema=NodeIOSchema( 76 | input_format={"passages": List[str]}, 77 | output_format={"passages_embeddings": List[List[float]]} 78 | ), 79 | op_type=NodeOps.EMBEDDING, 80 | anno=NodeAnnotation.BATCHABLE, 81 | config={ 82 | "batch_size": embed_config.model_config.get("max_batch_size", 1024) 83 | } 84 | ) 85 | 86 | ingestion_node = Node( 87 | name="Ingestion", 88 | node_type=NodeType.COMPUTE, 89 | engine_type=EngineType.VECTOR_DB, 90 | io_schema=NodeIOSchema( 91 | input_format={"passages": List[str], "passages_embeddings": List[List[float]]}, 92 | output_format={"index_status": bool} 93 | ), 94 | op_type=NodeOps.VECTORDB_INGESTION, 95 | anno=NodeAnnotation.BATCHABLE, 96 | config={ 97 | "batch_size": embed_config.model_config.get("max_batch_size", 256) 98 | } 99 | ) 100 | 101 | query_embedding_node = Node( 102 | name="QueryEmbedding", 103 | node_type=NodeType.COMPUTE, 104 | engine_type=EngineType.EMBEDDER, 105 | io_schema=NodeIOSchema( 106 | input_format={"queries": List[str]}, 107 | output_format={"queries_embeddings": List[List[float]]} 108 | ), 109 | op_type=NodeOps.EMBEDDING, 110 | anno=NodeAnnotation.BATCHABLE, 111 | config={ 112 | "batch_size": embed_config.model_config.get("max_batch_size", 256) 113 | } 114 | ) 115 | 116 | # Create search node 117 | search_node = Node( 118 | name="Search", 119 | node_type=NodeType.COMPUTE, 120 | engine_type=EngineType.VECTOR_DB, 121 | io_schema=NodeIOSchema( 122 | input_format={"queries_embeddings": List[List[float]], "index_status": bool}, 123 | output_format={"search_results": List[List[str]]} 124 | ), 125 | op_type=NodeOps.VECTORDB_SEARCHING, 126 | anno=NodeAnnotation.BATCHABLE, 127 | config={ 128 | "batch_size": 16, 129 | "top_k": 5 130 | } 131 | ) 132 | 133 | 134 | base_dag.set_query_inputs( 135 | { 136 
| 137 | "passages": 138 | [ 139 | "OSDI is a conference about operating systems..." * 20, 140 | "MICRO is a conference about computer architecture..." * 20, 141 | "HPCA is a conference about computer architecture..." * 20, 142 | "MLSYS is a conference about machine learning..." * 20, 143 | "Machine learning system design is a conference about ..." * 20, 144 | "AI is a branch of computer science..." * 20, 145 | "Machine learning is a subset of AI using statistical models " * 20, 146 | "Deep learning revolutionized AI using neural networks " * 20, 147 | "The sun is the largest planet in the solar system..." * 20, 148 | "The moon is the only natural satellite of the earth..." * 20, 149 | "The earth is the third planet from the sun..." * 20, 150 | ]*80, 151 | "queries": [ 152 | "I want to know some knowledge about the top computer system conferences." 153 | ] 154 | } 155 | ) 156 | 157 | passages_embedding_node >> ingestion_node >> query_embedding_node>> search_node 158 | #query_embedding_node >> search_node 159 | base_dag.register_nodes(passages_embedding_node, ingestion_node, query_embedding_node, search_node) 160 | 161 | return base_dag 162 | 163 | 164 | async def process_query(app, queries: List[str], passages: List[str], dag, query_id: str): 165 | """Add a query to the app and process it""" 166 | try: 167 | query = Query( 168 | uuid=f"random-test-{query_id}", 169 | query_id=f"random-test-{query_id}", 170 | query_inputs={ 171 | "passages": passages, 172 | "queries": queries 173 | }, 174 | DAG=deepcopy(dag) 175 | ) 176 | 177 | future = await app.submit_query( 178 | query=query, 179 | timeout=300 180 | ) 181 | 182 | result = await asyncio.wait_for(future, timeout=300) 183 | return result 184 | 185 | except Exception as e: 186 | print(f"Query {query_id} processing failed:\n{traceback.format_exc()}") 187 | raise Exception(f"Query {query_id} processing failed: {str(e)}") 188 | 189 | 190 | async def run_app(dag): 191 | try: 192 | # initialize the app 193 | app = APP.init() 194 | app.register_engine(embed_config) 195 | 196 | app.register_engine(vectordb_config) 197 | 198 | app.update_template(dag) 199 | 200 | # start the app 201 | await app.start() 202 | await asyncio.sleep(5) 203 | 204 | async def delayed_query(query_data, index): 205 | if index > 0: 206 | await asyncio.sleep(3 * index) 207 | return await process_query( 208 | app, 209 | query_data["queries"], 210 | query_data["passages"], 211 | dag, 212 | str(index) 213 | ) 214 | 215 | # prepare test data 216 | test_queries = [ 217 | { 218 | "passages": 219 | [ 220 | "OSDI is a conference about operating systems..." * 20, 221 | "MICRO is a conference about computer architecture..." * 20, 222 | "HPCA is a conference about computer architecture..." * 20, 223 | "MLSYS is a conference about machine learning..." * 20, 224 | "Machine learning system design is a conference about ..." * 20, 225 | "AI is a branch of computer science..." * 20, 226 | "Machine learning is a subset of AI using statistical models " * 20, 227 | "Deep learning revolutionized AI using neural networks " * 20, 228 | "The sun is the largest planet in the solar system..." * 20, 229 | "The moon is the only natural satellite of the earth..." * 20, 230 | "The earth is the third planet from the sun..." * 20, 231 | ]*80, 232 | "queries": [ 233 | "I want to know some knowledge about the top computer system conferences." 
234 | ] 235 | }, 236 | 237 | ] 238 | 239 | # Create all tasks 240 | tasks = [ 241 | delayed_query(query_data, i) 242 | for i, query_data in enumerate(test_queries) 243 | ] 244 | 245 | # Wait for all queries to complete 246 | results = await asyncio.gather(*tasks) 247 | 248 | # Print results 249 | for i, result in enumerate(results): 250 | print(f"\nQuery {i} results:") 251 | print(f"Results: {result}") 252 | 253 | except Exception as e: 254 | print(f"Main program error stack:\n{traceback.format_exc()}") 255 | print(f"Main program error: {e}") 256 | raise 257 | finally: 258 | # Cleanup 259 | try: 260 | await cleanup(app) 261 | except Exception as e: 262 | print(f"Cleanup process error: {e}") 263 | 264 | async def cleanup(app): 265 | """clear resources""" 266 | try: 267 | await app.stop() 268 | app.shutdown() 269 | except Exception as e: 270 | print(f"Cleanup failed:\n{traceback.format_exc()}") 271 | print(f"Cleanup failed: {str(e)}") 272 | 273 | 274 | if __name__ == "__main__": 275 | from Ayo.vis.vis_graph import visualize_dag_with_node_types 276 | 277 | dag = create_base_dag() 278 | 279 | print(dag.get_full_dag_nodes_info()) 280 | 281 | visualize_dag_with_node_types(dag, "before_optimize_embedd_ingest_search.png") 282 | 283 | #optimize the dag 284 | from Ayo.opt_pass.pruning_dependency import PruningDependencyPass 285 | from Ayo.opt_pass.stage_decomposition import StageDecompositionPass 286 | dag.optimize([PruningDependencyPass(), StageDecompositionPass()]) 287 | 288 | #print the dag 289 | print(dag.get_full_dag_nodes_info()) 290 | 291 | visualize_dag_with_node_types(dag, "optimize_embedd_ingest_search.png") 292 | 293 | 294 | asyncio.run(run_app(dag)) -------------------------------------------------------------------------------- /examples/optimized_embedding_ingestion_searching_reranking.py: -------------------------------------------------------------------------------- 1 | from Ayo.app import APP 2 | from Ayo.dags.node import Node, NodeAnnotation, NodeType, NodeIOSchema, NodeOps 3 | from Ayo.dags.dag import DAG 4 | from Ayo.queries.query import Query 5 | from Ayo.configs.config import EngineConfig 6 | from Ayo.engines.engine_types import ENGINE_REGISTRY, EngineType 7 | import time 8 | import asyncio 9 | import ray 10 | from copy import deepcopy 11 | from typing import List, Any, Dict, Optional, Union 12 | import traceback 13 | 14 | # Create embedding engine config 15 | default_embed_config = ENGINE_REGISTRY.get_default_config(EngineType.EMBEDDER) 16 | default_embed_config.update( 17 | { 18 | "model_name": "BAAI/bge-large-en-v1.5", 19 | "max_batch_size": 1024, 20 | } 21 | ) 22 | embed_config = EngineConfig( 23 | name="embedding_service", 24 | engine_type=EngineType.EMBEDDER, 25 | resources={}, 26 | num_gpus=1, 27 | num_cpus=1, 28 | instances=1, 29 | model_config={ 30 | **default_embed_config, 31 | "device": "cuda" 32 | }, 33 | latency_profile={ 34 | "timeout": 300, 35 | "batch_wait": 0.1 36 | } 37 | ) 38 | 39 | # Create search engine config 40 | vectordb_config = EngineConfig( 41 | name="vector_db_service", 42 | engine_type=EngineType.VECTOR_DB, 43 | resources={}, 44 | num_gpus=0, 45 | num_cpus=4, 46 | instances=1, 47 | model_config={ 48 | "host": "localhost", 49 | "port": 5432, 50 | "user": "asplos25", 51 | "password": "123456", 52 | "database": "database_asplos", 53 | "vector_dim": 1024, 54 | "max_batch_size": 1000 55 | }, 56 | latency_profile={ 57 | "timeout": 60, 58 | "batch_wait": 0.05 59 | } 60 | ) 61 | 62 | # Create reranker engine config 63 | default_reranker_config = 
ENGINE_REGISTRY.get_default_config(EngineType.RERANKER) 64 | default_reranker_config.update( 65 | { 66 | "model_name": "BAAI/bge-reranker-large", 67 | "max_batch_size": 32, 68 | } 69 | ) 70 | reranker_config = EngineConfig( 71 | name="reranker_service", 72 | engine_type=EngineType.RERANKER, 73 | resources={}, 74 | num_gpus=1, 75 | num_cpus=1, 76 | instances=1, 77 | model_config={ 78 | **default_reranker_config, 79 | "device": "cuda" 80 | }, 81 | latency_profile={ 82 | "timeout": 60, 83 | "batch_wait": 0.1 84 | } 85 | ) 86 | 87 | def create_base_dag(): 88 | """Create basic DAG""" 89 | base_dag = DAG(dag_id="random-test") 90 | 91 | # Create embedding node 92 | passages_embedding_node = Node( 93 | name="Embedding", 94 | node_type=NodeType.COMPUTE, 95 | engine_type=EngineType.EMBEDDER, 96 | io_schema=NodeIOSchema( 97 | input_format={"passages": List[str]}, 98 | output_format={"passages_embeddings": List[List[float]]} 99 | ), 100 | op_type=NodeOps.EMBEDDING, 101 | anno=NodeAnnotation.BATCHABLE, 102 | config={ 103 | 104 | } 105 | ) 106 | 107 | ingestion_node = Node( 108 | name="Ingestion", 109 | node_type=NodeType.COMPUTE, 110 | engine_type=EngineType.VECTOR_DB, 111 | io_schema=NodeIOSchema( 112 | input_format={"passages": List[str], "passages_embeddings": List[List[float]]}, 113 | output_format={"index_status": bool} 114 | ), 115 | op_type=NodeOps.VECTORDB_INGESTION, 116 | anno=NodeAnnotation.BATCHABLE, 117 | config={ 118 | 119 | } 120 | ) 121 | 122 | query_embedding_node = Node( 123 | name="QueryEmbedding", 124 | node_type=NodeType.COMPUTE, 125 | engine_type=EngineType.EMBEDDER, 126 | io_schema=NodeIOSchema( 127 | input_format={"queries": List[str]}, 128 | output_format={"queries_embeddings": List[List[float]]} 129 | ), 130 | op_type=NodeOps.EMBEDDING, 131 | anno=NodeAnnotation.BATCHABLE, 132 | config={ 133 | 134 | } 135 | ) 136 | 137 | # Create search node 138 | search_node = Node( 139 | name="Search", 140 | node_type=NodeType.COMPUTE, 141 | engine_type=EngineType.VECTOR_DB, 142 | io_schema=NodeIOSchema( 143 | input_format={"queries_embeddings": List[List[float]], "index_status": bool}, 144 | output_format={"search_results": List[List[str]]} 145 | ), 146 | op_type=NodeOps.VECTORDB_SEARCHING, 147 | anno=NodeAnnotation.BATCHABLE, 148 | config={ 149 | "top_k": 30 # Add more retrieval to provide more candidates for reranking 150 | } 151 | ) 152 | 153 | # Create reranking node 154 | reranking_node = Node( 155 | name="Reranking", 156 | node_type=NodeType.COMPUTE, 157 | engine_type=EngineType.RERANKER, 158 | io_schema=NodeIOSchema( 159 | input_format={"queries": str, "search_results": Union[List[List[str]],List[str]]}, 160 | output_format={"reranked_results": List[List[str]]} 161 | ), 162 | op_type=NodeOps.RERANKING, 163 | anno=NodeAnnotation.BATCHABLE, 164 | config={ 165 | "top_k": 10 # The number of results to return 166 | } 167 | ) 168 | 169 | base_dag.set_query_inputs( 170 | { 171 | "passages": [ 172 | "AI is a branch of computer science..." * 15, 173 | "Machine learning is a subset of AI..." * 15, 174 | "Deep learning revolutionized AI..." * 15, 175 | "I have a question about AI..." * 15, 176 | "I have a question about machine learning..." * 15, 177 | 'What is the latest news about AI?' * 15, 178 | 'What is the latest news about machine learning?' * 15, 179 | 'What is the latest news about deep learning?' * 15, 180 | 'What is the latest news about AI?' * 15, 181 | 'What is the latest news about machine learning?' * 15, 182 | 'What is the latest news about deep learning?' 
* 15, 183 | ]*60 + [ 184 | "OSDI is a conference about operating systems..." * 15, 185 | "ASPLOS is a conference about computer architecture..." * 15, 186 | "MICRO is a conference about computer architecture..." * 15, 187 | "HPCA is a conference about computer architecture..." * 15, 188 | "MLSYS is a conference about machine learning..." * 15, 189 | "Machine learning system design is a conference about machine ..." * 15, 190 | ], 191 | "queries": [ 192 | "I want to know some system conferences." 193 | ] 194 | }, 195 | ) 196 | 197 | 198 | passages_embedding_node >> ingestion_node >> query_embedding_node >> search_node >> reranking_node 199 | base_dag.register_nodes(passages_embedding_node, ingestion_node, query_embedding_node, search_node, reranking_node) 200 | 201 | return base_dag 202 | 203 | 204 | async def process_query(app, queries: List[str], passages: List[str], dag, query_id: str): 205 | """Add a query to the app and process it""" 206 | try: 207 | query = Query( 208 | uuid=f"random-test-{query_id}", 209 | query_id=f"random-test-{query_id}", 210 | query_inputs={ 211 | "passages": passages, 212 | "queries": queries 213 | }, 214 | DAG=deepcopy(dag) 215 | ) 216 | 217 | future = await app.submit_query( 218 | query=query, 219 | timeout=300 220 | ) 221 | 222 | result = await asyncio.wait_for(future, timeout=300) 223 | return result 224 | 225 | except Exception as e: 226 | print(f"Query {query_id} processing failed:\n{traceback.format_exc()}") 227 | raise Exception(f"Query {query_id} processing failed: {str(e)}") 228 | 229 | 230 | async def run_app(dag): 231 | try: 232 | # Initialize the application 233 | app = APP.init() 234 | app.register_engine(embed_config) 235 | app.register_engine(vectordb_config) 236 | app.register_engine(reranker_config) # Register the reranker engine 237 | 238 | 239 | # Start the application 240 | await app.start() 241 | await asyncio.sleep(5) 242 | 243 | async def delayed_query(query_data, index): 244 | if index > 0: 245 | await asyncio.sleep(3 * index) 246 | return await process_query( 247 | app, 248 | query_data["queries"], 249 | query_data["passages"], 250 | dag, 251 | str(index) 252 | ) 253 | 254 | # Prepare test data 255 | test_queries = [ 256 | { 257 | "passages": [ 258 | "AI is a branch of computer science..." * 15, 259 | "Machine learning is a subset of AI..." * 15, 260 | "Deep learning revolutionized AI..." * 15, 261 | "I have a question about AI..." * 15, 262 | "I have a question about machine learning..." * 15, 263 | 'What is the latest news about AI?' * 15, 264 | 'What is the latest news about machine learning?' * 15, 265 | 'What is the latest news about deep learning?' * 15, 266 | 'What is the latest news about AI?' * 15, 267 | 'What is the latest news about machine learning?' * 15, 268 | 'What is the latest news about deep learning?' * 15, 269 | ]*60 + [ 270 | "OSDI is a conference about operating systems..." * 15, 271 | "ASPLOS is a conference about computer architecture..." * 15, 272 | "MICRO is a conference about computer architecture..." * 15, 273 | "HPCA is a conference about computer architecture..." * 15, 274 | "MLSYS is a conference about machine learning..." * 15, 275 | "Machine learning system design is a conference about machine ..." * 15, 276 | ], 277 | "queries": [ 278 | "I want to know some system conferences." 
279 | ] 280 | }, 281 | ] 282 | 283 | # create all tasks 284 | tasks = [ 285 | delayed_query(query_data, i) 286 | for i, query_data in enumerate(test_queries) 287 | ] 288 | 289 | # wait for all queries to complete 290 | results = await asyncio.gather(*tasks) 291 | 292 | # print results 293 | for i, result in enumerate(results): 294 | print(f"\nQuery {i} results: {result}") 295 | for key, value in result.items(): 296 | if 'search' in key.lower(): 297 | print(f"Search results length: {len(value)}") 298 | elif 'reranking' in key.lower(): 299 | print(f"Reranking results length: {len(value)}") 300 | 301 | except Exception as e: 302 | print(f"Main program error stack:\n{traceback.format_exc()}") 303 | print(f"Main program error: {e}") 304 | raise 305 | finally: 306 | # clean up 307 | try: 308 | await cleanup(app) 309 | except Exception as e: 310 | print(f"Cleanup process error: {e}") 311 | 312 | async def cleanup(app): 313 | """Clean up resources""" 314 | try: 315 | await app.stop() 316 | app.shutdown() 317 | except Exception as e: 318 | print(f"Cleanup failed:\n{traceback.format_exc()}") 319 | print(f"Cleanup failed: {str(e)}") 320 | 321 | 322 | if __name__ == "__main__": 323 | from Ayo.vis.vis_graph import visualize_dag_with_node_types 324 | 325 | dag = create_base_dag() 326 | 327 | print(dag.get_full_dag_nodes_info()) 328 | 329 | visualize_dag_with_node_types(dag, "before_optimize_embedd_ingest_search_reranking.png") 330 | 331 | # optimize DAG 332 | from Ayo.opt_pass.pruning_dependency import PruningDependencyPass 333 | from Ayo.opt_pass.stage_decomposition import StageDecompositionPass 334 | dag.optimize([PruningDependencyPass(), StageDecompositionPass()]) 335 | 336 | # print the optimized DAG 337 | print(dag.get_full_dag_nodes_info()) 338 | 339 | visualize_dag_with_node_types(dag, "optimize_embedd_ingest_search_reranking.png") 340 | 341 | asyncio.run(run_app(dag)) 342 | -------------------------------------------------------------------------------- /examples/test_dag.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from Ayo.dags.node import Node, NodeType, NodeAnnotation, NodeIOSchema 3 | from Ayo.dags.dag import DAG 4 | from Ayo.engines.engine_types import EngineType 5 | def test_simple_linear_dag(): 6 | """Test simple linear DAG""" 7 | # Create node IO schema 8 | embed_schema = NodeIOSchema( 9 | input_format={"texts": str}, 10 | output_format={"embeddings": list} 11 | ) 12 | 13 | rerank_schema = NodeIOSchema( 14 | input_format={"embeddings": list}, 15 | output_format={"scores": list} 16 | ) 17 | 18 | # Create compute node 19 | node1 = Node( 20 | name="Embedding1", 21 | node_type=NodeType.COMPUTE, 22 | engine_type=EngineType.EMBEDDER, 23 | io_schema=embed_schema, 24 | anno=NodeAnnotation.BATCHABLE, 25 | in_kwargs={"texts": None}, 26 | out_kwargs={"embeddings": None} 27 | ) 28 | 29 | node2 = Node( 30 | name="Rerank1", 31 | node_type=NodeType.COMPUTE, 32 | engine_type=EngineType.RERANKER, 33 | io_schema=rerank_schema, 34 | anno=NodeAnnotation.BATCHABLE, 35 | in_kwargs={"embeddings": None}, 36 | out_kwargs={"scores": None} 37 | ) 38 | 39 | # Create DAG 40 | dag = DAG(dag_id="test_linear") 41 | 42 | # Set query inputs 43 | dag.set_query_inputs({ 44 | "texts": ["This is a test text"] 45 | }) 46 | 47 | # Build dependencies 48 | node1 >> node2 49 | 50 | # Register nodes 51 | dag.register_nodes(node1, node2) 52 | 53 | # Validate topological sort 54 | sorted_nodes = dag.topological_sort() 55 | print(sorted_nodes) 56 | assert 
len(sorted_nodes) == 4 # input_node + 2 compute nodes + output_node
57 |     assert sorted_nodes[0].node_type == NodeType.INPUT
58 |     assert sorted_nodes[1] == node1
59 |     assert sorted_nodes[2] == node2
60 |     assert sorted_nodes[3].node_type == NodeType.OUTPUT
61 | 
62 | def test_diamond_dag():
63 |     """Test diamond DAG"""
64 |     # Create node IO schema
65 |     embed_schema = NodeIOSchema(
66 |         input_format={"texts": str},
67 |         output_format={"embeddings": list}
68 |     )
69 | 
70 |     process_schema = NodeIOSchema(
71 |         input_format={"embeddings": list},
72 |         output_format={"processed": list}
73 |     )
74 | 
75 |     merge_schema = NodeIOSchema(
76 |         input_format={"processed1": list, "processed2": list},
77 |         output_format={"final": list}
78 |     )
79 | 
80 |     # Create compute nodes
81 |     node1 = Node(
82 |         name="Embedding1",
83 |         node_type=NodeType.COMPUTE,
84 |         engine_type=EngineType.EMBEDDER,
85 |         io_schema=embed_schema,
86 |         in_kwargs={"texts": None},
87 |         out_kwargs={"embeddings": None}
88 |     )
89 | 
90 |     node2 = Node(
91 |         name="Process1",
92 |         node_type=NodeType.COMPUTE,
93 |         engine_type=EngineType.DUMMY,
94 |         io_schema=process_schema,
95 |         in_kwargs={"embeddings": None},
96 |         out_kwargs={"processed1": None}
97 |     )
98 | 
99 |     node3 = Node(
100 |         name="Process2",
101 |         node_type=NodeType.COMPUTE,
102 |         engine_type=EngineType.DUMMY,
103 |         io_schema=process_schema,
104 |         in_kwargs={"embeddings": None},
105 |         out_kwargs={"processed2": None}
106 |     )
107 | 
108 |     node4 = Node(
109 |         name="Merge",
110 |         node_type=NodeType.COMPUTE,
111 |         engine_type=EngineType.AGGREGATOR,
112 |         io_schema=merge_schema,
113 |         in_kwargs={"processed1": None, "processed2": None},
114 |         out_kwargs={"final": None}
115 |     )
116 | 
117 |     # Create DAG
118 |     dag = DAG(dag_id="test_diamond")
119 | 
120 |     # Set query inputs
121 |     dag.set_query_inputs({
122 |         "texts": ["This is a test text"]
123 |     })
124 | 
125 |     # Build dependencies
126 |     node1 >> node2
127 |     node1 >> node3
128 |     node2 >> node4
129 |     node3 >> node4
130 | 
131 |     # Register nodes
132 |     dag.register_nodes(node1, node2, node3, node4)
133 | 
134 |     # Validate topological sort
135 |     sorted_nodes = dag.topological_sort()
136 |     print(sorted_nodes)
137 |     assert len(sorted_nodes) == 6 # input_node + 4 compute nodes + output_node
138 |     assert sorted_nodes[0].node_type == NodeType.INPUT
139 |     assert sorted_nodes[1] == node1
140 |     assert set(sorted_nodes[2:4]) == {node2, node3}
141 |     assert sorted_nodes[4] == node4
142 |     assert sorted_nodes[5].node_type == NodeType.OUTPUT
143 | 
144 | def test_cyclic_dag():
145 |     """Test cyclic DAG (should raise an exception)"""
146 |     # Create node IO schema
147 |     process_schema = NodeIOSchema(
148 |         input_format={"input": str},
149 |         output_format={"output": str}
150 |     )
151 | 
152 |     # Create compute nodes
153 |     node1 = Node(
154 |         name="Node1",
155 |         node_type=NodeType.COMPUTE,
156 |         engine_type=EngineType.DUMMY,
157 |         io_schema=process_schema,
158 |         in_kwargs={"input": None},
159 |         out_kwargs={"output": None}
160 |     )
161 | 
162 |     node2 = Node(
163 |         name="Node2",
164 |         node_type=NodeType.COMPUTE,
165 |         engine_type=EngineType.DUMMY,
166 |         io_schema=process_schema,
167 |         in_kwargs={"input": None},
168 |         out_kwargs={"output": None}
169 |     )
170 | 
171 |     # Create DAG
172 |     dag = DAG(dag_id="test_cyclic")
173 | 
174 |     # Set query inputs
175 |     dag.set_query_inputs({
176 |         "input": "test"
177 |     })
178 | 
179 |     # Build dependencies (intentionally forming a cycle)
180 |     node1 >> node2
181 |     node2 >> node1
182 | 
183 |     # Register nodes
184 |     dag.register_nodes(node1, node2)
185 | 
186 |     # Validate that an exception is raised
187 |     with 
pytest.raises(ValueError, match="Cycle detected in DAG"): 188 | dag.topological_sort() 189 | 190 | def test_dag_with_query_inputs(): 191 | """Test DAG with query inputs""" 192 | # Create node IO schema 193 | embed_schema = NodeIOSchema( 194 | input_format={"texts": str}, 195 | output_format={"embeddings": list} 196 | ) 197 | 198 | # Create compute node 199 | compute_node = Node( 200 | name="Embedding", 201 | node_type=NodeType.COMPUTE, 202 | engine_type=EngineType.EMBEDDER, 203 | io_schema=embed_schema, 204 | in_kwargs={"texts": None}, 205 | out_kwargs={"embeddings": None} 206 | ) 207 | 208 | # Create DAG 209 | dag = DAG(dag_id="test_with_inputs") 210 | 211 | # Set query inputs 212 | dag.set_query_inputs({ 213 | "texts": ["This is a test text"] 214 | }) 215 | 216 | # Register nodes 217 | dag.register_nodes(compute_node) 218 | 219 | # Validate if the input node is created correctly 220 | assert len(dag.input_nodes) == 1 221 | print(dag.input_nodes) 222 | print(list(dag.input_nodes.values())[0].input_kwargs) 223 | assert "texts" in list(dag.input_nodes.values())[0].input_kwargs 224 | 225 | # Validate topological sort 226 | sorted_nodes = dag.topological_sort() 227 | assert len(sorted_nodes) == 3 # input_node + compute_node + output_node 228 | assert sorted_nodes[0].node_type == NodeType.INPUT 229 | assert sorted_nodes[1] == compute_node 230 | assert sorted_nodes[2].node_type == NodeType.OUTPUT 231 | 232 | if __name__ == "__main__": 233 | pytest.main([__file__]) 234 | -------------------------------------------------------------------------------- /examples/test_embedding_service.py: -------------------------------------------------------------------------------- 1 | from Ayo.app import APP 2 | from Ayo.dags.node import Node, NodeAnnotation, NodeType, NodeIOSchema 3 | from Ayo.dags.dag import DAG 4 | from Ayo.queries.query import Query 5 | from Ayo.configs.config import EngineConfig 6 | from Ayo.engines.engine_types import ENGINE_REGISTRY, EngineType 7 | import time 8 | import asyncio 9 | import ray 10 | from copy import deepcopy 11 | from typing import List, Any 12 | import traceback 13 | 14 | # Create engine config, using the default config of ENGINE_REGISTRY 15 | default_embed_config = ENGINE_REGISTRY.get_default_config(EngineType.EMBEDDER) 16 | embed_config = EngineConfig( 17 | name="embedding_service", 18 | engine_type=EngineType.EMBEDDER, 19 | resources={}, 20 | num_gpus=1, 21 | num_cpus=1, 22 | instances=2, 23 | model_config={ 24 | **default_embed_config, 25 | "device": "cuda" 26 | }, 27 | latency_profile={ 28 | "timeout": 300, 29 | "batch_wait": 0.1 30 | } 31 | ) 32 | 33 | def create_base_dag(): 34 | """Create base DAG""" 35 | base_dag = DAG(dag_id="embedding_service") 36 | 37 | # Create embedding node with correct parameters 38 | embedding_node = Node( 39 | name="Embedding", 40 | node_type=NodeType.COMPUTE, 41 | engine_type=EngineType.EMBEDDER, 42 | io_schema=NodeIOSchema( 43 | input_format={"texts": List[str]}, 44 | output_format={"embeddings": List[Any]} 45 | ), 46 | anno=NodeAnnotation.BATCHABLE, 47 | config={ 48 | "batch_size": embed_config.model_config.get("max_batch_size", 256) 49 | } 50 | ) 51 | 52 | # Register node to DAG 53 | base_dag.register_nodes(embedding_node) 54 | 55 | return base_dag 56 | 57 | async def get_embeddings(app, texts, dag, query_id): 58 | try: 59 | # Create query 60 | query = Query( 61 | uuid=f"embed_{query_id}", 62 | query_id=f"embed_{query_id}", 63 | query_inputs={"texts": texts}, 64 | DAG=deepcopy(dag) 65 | ) 66 | 67 | # Submit query and get result 68 | 
future = await app.submit_query( 69 | query=query, 70 | timeout=embed_config.latency_profile.get("timeout", 30) 71 | ) 72 | 73 | try: 74 | # Wait for result 75 | result = await asyncio.wait_for( 76 | future, 77 | timeout=embed_config.latency_profile.get("timeout", 30) 78 | ) 79 | return result 80 | 81 | except asyncio.TimeoutError: 82 | raise TimeoutError(f"Query timed out after {embed_config.latency_profile.get('timeout', 30)} seconds") 83 | 84 | except Exception as e: 85 | print(f"Embedding processing error stack:\n{traceback.format_exc()}") 86 | raise Exception(f"Embedding processing failed: {str(e)}") 87 | 88 | async def cleanup(app): 89 | """Cleanup resources""" 90 | try: 91 | await app.stop() 92 | app.shutdown() 93 | except Exception as e: 94 | print(f"Cleanup error stack:\n{traceback.format_exc()}") 95 | print(f"Cleanup failed: {str(e)}") 96 | 97 | async def main(): 98 | try: 99 | # Initialize application 100 | app = APP.init() 101 | app.register_engine(embed_config) 102 | dag = create_base_dag() 103 | app.update_template(dag) 104 | 105 | # Start application 106 | await app.start() 107 | await asyncio.sleep(5) # Wait for system initialization 108 | 109 | # Prepare test text groups 110 | text_groups = [ 111 | [ 112 | "Artificial Intelligence is the future" *50, 113 | "Machine Learning is the important part of Artificial Intelligence" *50, 114 | "Deep Learning is the important part of Machine Learning" *50 115 | ]*80, 116 | [ 117 | "Natural Language Processing is essential" *50, 118 | "Computer Vision has many applications" *50, 119 | "Reinforcement Learning is fascinating" *50 120 | ]*80 121 | ] 122 | 123 | # Create delayed query tasks 124 | async def delayed_embedding(texts, index): 125 | if index > 0: 126 | await asyncio.sleep(2 * index) 127 | return await get_embeddings(app, texts, dag, str(index)) 128 | 129 | # Create all tasks 130 | tasks = [ 131 | delayed_embedding(texts, i) 132 | for i, texts in enumerate(text_groups) 133 | ] 134 | 135 | # Wait for all queries to complete 136 | results = await asyncio.gather(*tasks) 137 | 138 | # Print results 139 | for i, embeddings in enumerate(results): 140 | print(f"Query {i} embedding result: {embeddings}") 141 | 142 | except Exception as e: 143 | print(f"Main program error stack:\n{traceback.format_exc()}") 144 | print(f"Main program error: {e}") 145 | raise 146 | finally: 147 | # Cleanup 148 | try: 149 | await cleanup(app) 150 | except Exception as e: 151 | print(f"Cleanup process error: {e}") 152 | 153 | if __name__ == "__main__": 154 | try: 155 | asyncio.run(main()) 156 | except KeyboardInterrupt: 157 | print("\nReceived keyboard interrupt, shutting down...") 158 | except Exception as e: 159 | print(f"Fatal error stack:\n{traceback.format_exc()}") 160 | print(f"Fatal error: {e}") 161 | finally: 162 | if ray.is_initialized(): 163 | ray.shutdown() -------------------------------------------------------------------------------- /examples/test_multiple_llm_calls.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | import time 3 | import asyncio 4 | import traceback 5 | from typing import List, Dict, Any, Optional 6 | from copy import deepcopy 7 | 8 | from Ayo.app import APP 9 | from Ayo.dags.dag import DAG 10 | from Ayo.dags.node import Node, NodeAnnotation, NodeType, NodeIOSchema, NodeOps 11 | from Ayo.engines.engine_types import EngineType, ENGINE_REGISTRY 12 | from Ayo.configs.config import EngineConfig 13 | from Ayo.queries.query import Query 14 | from Ayo.modules.prompt_template 
import replace_placeholders 15 | from Ayo.opt_pass.pruning_dependency import PruningDependencyPass 16 | from Ayo.opt_pass.stage_decomposition import StageDecompositionPass 17 | from Ayo.opt_pass.prefilling_split import PrefillingSpiltPass 18 | from Ayo.opt_pass.decoding_pipeling import LLMDecodingPipeliningPass 19 | from Ayo.utils import print_key_info 20 | 21 | # Create LLM engine configuration 22 | default_llm_config = ENGINE_REGISTRY.get_default_config(EngineType.LLM) 23 | default_llm_config.update({ 24 | "model_name": "meta-llama/Llama-2-7b-chat-hf", 25 | "max_batch_size": 8, 26 | }) 27 | 28 | llm_config = EngineConfig( 29 | name="llm_service", 30 | engine_type=EngineType.LLM, 31 | resources={}, 32 | num_gpus=1, 33 | num_cpus=4, 34 | instances=1, 35 | model_config={ 36 | **default_llm_config, 37 | "device": "cuda" 38 | }, 39 | latency_profile={ 40 | "timeout": 300, 41 | "batch_wait": 0.1 42 | } 43 | ) 44 | 45 | def create_base_dag(): 46 | """Create base DAG""" 47 | dag = DAG(dag_id="dual_llm_workflow") 48 | 49 | # Generate unique IDs for each LLM 50 | answer_llm_id = str(uuid.uuid4()) 51 | enrich_llm_id = str(uuid.uuid4()) 52 | 53 | # First group: answer the query 54 | answer_prefilling_node = Node( 55 | name="AnswerPrefilling", 56 | node_type=NodeType.COMPUTE, 57 | engine_type=EngineType.LLM, 58 | io_schema=NodeIOSchema( 59 | input_format={"query": str}, 60 | output_format={"answer_prefill_state": bool} 61 | ), 62 | op_type=NodeOps.LLM_PREFILLING, 63 | anno=NodeAnnotation.BATCHABLE, 64 | config={ 65 | 'prompt_template': "Please answer the following question:\n\n{query}\n\nAnswer:", 66 | 'prompt': "Please answer the following question:\n\n{query}\n\nAnswer:", 67 | 'parse_json': False, 68 | 'partial_output': False, 69 | 'partial_prefilling': False, 70 | 'llm_internal_id': answer_llm_id, 71 | 'max_tokens': 512 72 | } 73 | ) 74 | 75 | answer_decoding_node = Node( 76 | name="AnswerDecoding", 77 | node_type=NodeType.COMPUTE, 78 | engine_type=EngineType.LLM, 79 | io_schema=NodeIOSchema( 80 | input_format={"answer_prefill_state": bool}, 81 | output_format={"answer": str} 82 | ), 83 | op_type=NodeOps.LLM_DECODING, 84 | anno=NodeAnnotation.BATCHABLE, 85 | config={ 86 | 'prompt_template': "Please answer the following question:\n\n{query}\n\nAnswer:", 87 | 'prompt': "Please answer the following question:\n\n{query}\n\nAnswer:", 88 | 'parse_json': False, 89 | 'partial_output': False, 90 | 'partial_prefilling': False, 91 | 'llm_internal_id': answer_llm_id, 92 | 'max_tokens': 512 93 | } 94 | ) 95 | 96 | # Second group: enrich the answer 97 | enrich_prefilling_node = Node( 98 | name="EnrichPrefilling", 99 | node_type=NodeType.COMPUTE, 100 | engine_type=EngineType.LLM, 101 | io_schema=NodeIOSchema( 102 | input_format={"query": str, "answer": str}, 103 | output_format={"enrich_prefill_state": bool} 104 | ), 105 | op_type=NodeOps.LLM_PREFILLING, 106 | anno=NodeAnnotation.BATCHABLE, 107 | config={ 108 | 'prompt_template': "Original question: {query}\n\nInitial answer: {answer}\n\nPlease enrich and expand the above answer, adding more details and examples:", 109 | 'prompt': "Original question: {query}\n\nInitial answer: {answer}\n\nPlease enrich and expand the above answer, adding more details and examples:", 110 | 'parse_json': False, 111 | 'partial_output': False, 112 | 'partial_prefilling': False, 113 | 'llm_internal_id': enrich_llm_id, 114 | 'max_tokens': 1024 115 | } 116 | ) 117 | 118 | enrich_decoding_node = Node( 119 | name="EnrichDecoding", 120 | node_type=NodeType.COMPUTE, 121 | 
engine_type=EngineType.LLM, 122 | io_schema=NodeIOSchema( 123 | input_format={"enrich_prefill_state": bool}, 124 | output_format={"enriched_answer": str} 125 | ), 126 | op_type=NodeOps.LLM_DECODING, 127 | anno=NodeAnnotation.BATCHABLE, 128 | config={ 129 | 'prompt_template': "Original question: {query}\n\nInitial answer: {answer}\n\nPlease enrich and expand the above answer, adding more details and examples:", 130 | 'prompt': "Original question: {query}\n\nInitial answer: {answer}\n\nPlease enrich and expand the above answer, adding more details and examples:", 131 | 'parse_json': False, 132 | 'partial_output': False, 133 | 'partial_prefilling': False, 134 | 'llm_internal_id': enrich_llm_id, 135 | 'max_tokens': 1024 136 | } 137 | ) 138 | 139 | # Connect nodes 140 | answer_prefilling_node >> answer_decoding_node >> enrich_prefilling_node >> enrich_decoding_node 141 | 142 | # Register all nodes 143 | dag.register_nodes( 144 | answer_prefilling_node, 145 | answer_decoding_node, 146 | enrich_prefilling_node, 147 | enrich_decoding_node 148 | ) 149 | 150 | # Set input 151 | dag.set_query_inputs({ 152 | 'query': "What is the impact of artificial intelligence on future society?" 153 | }) 154 | 155 | return dag 156 | 157 | async def process_query(app, query_input: str, dag, query_id: str): 158 | """Add a query to the app and process it""" 159 | try: 160 | query = Query( 161 | uuid=f"dual-llm-workflow-{query_id}", 162 | query_id=f"dual-llm-workflow-{query_id}", 163 | query_inputs={ 164 | "query": query_input 165 | }, 166 | DAG=deepcopy(dag) 167 | ) 168 | 169 | future = await app.submit_query( 170 | query=query, 171 | timeout=300 172 | ) 173 | 174 | result = await asyncio.wait_for(future, timeout=300) 175 | return result 176 | 177 | except Exception as e: 178 | print(f"Query {query_id} processing failed:\n{traceback.format_exc()}") 179 | raise Exception(f"Query {query_id} processing failed: {str(e)}") 180 | 181 | async def run_app(dag): 182 | try: 183 | # Initialize the application 184 | app = APP.init() 185 | app.register_engine(llm_config) 186 | 187 | # Start the application 188 | await app.start() 189 | await asyncio.sleep(5) 190 | 191 | async def delayed_query(query_text, index): 192 | if index > 0: 193 | await asyncio.sleep(3 * index) 194 | return await process_query( 195 | app, 196 | query_text, 197 | dag, 198 | str(index) 199 | ) 200 | 201 | # Prepare test data 202 | test_queries = [ 203 | "What is the impact of artificial intelligence on future society?", 204 | # "How can we balance technological advancement with ethical considerations?", 205 | # "What are the limitations of large language models?" 
206 |         ]
207 | 
208 |         # Create all tasks
209 |         tasks = [
210 |             delayed_query(query_text, i)
211 |             for i, query_text in enumerate(test_queries)
212 |         ]
213 | 
214 |         # Wait for all queries to complete
215 |         results = await asyncio.gather(*tasks)
216 | 
217 |         # Print results
218 |         for i, result in enumerate(results):
219 |             print(f"\nQuery {i} results:")
220 |             print(result)
221 | 
222 |     except Exception as e:
223 |         print(f"Main program error stack:\n{traceback.format_exc()}")
224 |         print(f"Main program error: {e}")
225 |         raise
226 |     finally:
227 |         # Clean up resources
228 |         try:
229 |             await cleanup(app)
230 |         except Exception as e:
231 |             print(f"Cleanup process error: {e}")
232 | 
233 | async def cleanup(app):
234 |     """Clean up resources"""
235 |     try:
236 |         await app.stop()
237 |         app.shutdown()
238 |     except Exception as e:
239 |         print(f"Cleanup failed:\n{traceback.format_exc()}")
240 |         print(f"Cleanup failed: {str(e)}")
241 | 
242 | if __name__ == "__main__":
243 |     from Ayo.vis.vis_graph import visualize_dag_with_node_types
244 |     import os
245 | 
246 |     # Create DAG
247 |     dag = create_base_dag()
248 | 
249 |     print("Original DAG node information:")
250 |     print(dag.get_full_dag_nodes_info())
251 | 
252 |     visualize_dag_with_node_types(dag, "before_optimize_dual_llm_workflow.png")
253 | 
254 |     # remove the generated png file
255 |     os.remove("before_optimize_dual_llm_workflow.png")
256 | 
257 |     # Optimize DAG
258 |     begin_time = time.time()
259 |     dag.optimize([
260 |         PruningDependencyPass(),
261 |         StageDecompositionPass(),
262 |         PrefillingSpiltPass(),
263 |         LLMDecodingPipeliningPass()
264 |     ])
265 |     end_time = time.time()
266 | 
267 |     print_key_info(f"Optimization time: {end_time - begin_time} seconds")
268 |     print("Optimized DAG node information:")
269 |     print(dag.get_full_dag_nodes_info())
270 | 
271 |     visualize_dag_with_node_types(dag, "optimized_dual_llm_workflow.png")
272 |     os.remove("optimized_dual_llm_workflow.png")
273 | 
274 |     # Run the application
275 |     asyncio.run(run_app(dag))
276 | 
--------------------------------------------------------------------------------
/examples/test_reranking_service.py:
--------------------------------------------------------------------------------
1 | from Ayo.app import APP
2 | from Ayo.dags.node import Node, NodeAnnotation, NodeType, NodeIOSchema
3 | from Ayo.dags.dag import DAG
4 | from Ayo.queries.query import Query
5 | from Ayo.configs.config import EngineConfig
6 | from Ayo.engines.engine_types import ENGINE_REGISTRY, EngineType
7 | import time
8 | import asyncio
9 | import ray
10 | from copy import deepcopy
11 | from typing import List, Any, Dict
12 | import traceback
13 | 
14 | # Create reranker engine config
15 | default_rerank_config = ENGINE_REGISTRY.get_default_config(EngineType.RERANKER)
16 | rerank_config = EngineConfig(
17 |     name="reranker_service",
18 |     engine_type=EngineType.RERANKER,
19 |     resources={},
20 |     num_gpus=1,  # GPU number for each instance
21 |     num_cpus=1,  # CPU number for each instance
22 |     instances=1,  # Number of instances to run
23 |     model_config={
24 |         **default_rerank_config,  # Use default config
25 |         "device": "cuda",  # Override specific config
26 |         "max_batch_size": 128  # Batch size
27 |     },
28 |     latency_profile={
29 |         "timeout": 300,  # Timeout (seconds)
30 |         "batch_wait": 0.1  # Batch wait time (seconds)
31 |     }
32 | )
33 | 
34 | def create_base_dag():
35 |     """Create base DAG"""
36 |     base_dag = DAG(dag_id="reranker_service")
37 | 
38 |     # Create reranker node
39 |     reranker_node = Node(
40 |         name="Reranker",
41 |         node_type=NodeType.COMPUTE,
42 | 
engine_type=EngineType.RERANKER, 43 | io_schema=NodeIOSchema( 44 | input_format={ 45 | "query": str, 46 | "passages": List[str] 47 | }, 48 | output_format={ 49 | "ranked_results": List[Dict[str, Any]] 50 | } 51 | ), 52 | anno=NodeAnnotation.BATCHABLE, 53 | config={ 54 | 'top_k': 5 55 | } 56 | ) 57 | 58 | # Register nodes to DAG 59 | base_dag.register_nodes(reranker_node) 60 | 61 | return base_dag 62 | 63 | async def get_reranking(app, query: str, passages: List[str], dag, query_id: str): 64 | """Get reranking results""" 65 | try: 66 | # Create query 67 | query = Query( 68 | uuid=f"rerank_{query_id}", 69 | query_id=f"rerank_{query_id}", 70 | query_inputs={ 71 | "query": query, 72 | "passages": passages 73 | }, 74 | DAG=deepcopy(dag) 75 | ) 76 | 77 | # Submit query and get result 78 | future = await app.submit_query( 79 | query=query, 80 | timeout=rerank_config.latency_profile.get("timeout", 30) 81 | ) 82 | 83 | try: 84 | # Wait for result 85 | result = await asyncio.wait_for( 86 | future, 87 | timeout=rerank_config.latency_profile.get("timeout", 30) 88 | ) 89 | return result 90 | 91 | except asyncio.TimeoutError: 92 | raise TimeoutError(f"Query timed out after {rerank_config.latency_profile.get('timeout', 30)} seconds") 93 | 94 | except Exception as e: 95 | print(f"Reranking processing error stack:\n{traceback.format_exc()}") 96 | raise Exception(f"Reranking processing failed: {str(e)}") 97 | 98 | async def cleanup(app): 99 | """Cleanup resources""" 100 | try: 101 | await app.stop() 102 | app.shutdown() 103 | except Exception as e: 104 | print(f"Cleanup error stack:\n{traceback.format_exc()}") 105 | print(f"Cleanup failed: {str(e)}") 106 | 107 | async def main(): 108 | try: 109 | # Initialize application 110 | app = APP.init() 111 | app.register_engine(rerank_config) 112 | dag = create_base_dag() 113 | app.update_template(dag) 114 | 115 | # Start application 116 | await app.start() 117 | await asyncio.sleep(5) # Wait for system initialization 118 | 119 | # Prepare test data 120 | test_queries = [ 121 | { 122 | "query": "What is artificial intelligence?" * 5, 123 | "passages": [ 124 | "AI is a branch of computer science..." * 50, 125 | "Machine learning is a subset of AI..." * 50, 126 | "Deep learning revolutionized AI..." * 50 127 | ] * 12 128 | }, 129 | { 130 | "query": "How does natural language processing work?" * 5, 131 | "passages": [ 132 | "NLP combines linguistics and machine learning..." * 50, 133 | "Language models are key to NLP..." * 50, 134 | "Transformers architecture changed NLP..." 
* 50 135 | ] * 12 136 | } 137 | ] 138 | 139 | # Create delayed query tasks 140 | async def delayed_reranking(query_data, index): 141 | if index > 0: 142 | await asyncio.sleep(0 * index) # zero factor: both test queries are submitted at once; raise it to stagger submissions 143 | return await get_reranking( 144 | app, 145 | query_data["query"], 146 | query_data["passages"], 147 | dag, 148 | str(index) 149 | ) 150 | 151 | # Create all tasks 152 | tasks = [ 153 | delayed_reranking(query_data, i) 154 | for i, query_data in enumerate(test_queries) 155 | ] 156 | 157 | # Wait for all queries to complete 158 | results = await asyncio.gather(*tasks) 159 | 160 | # Print results 161 | for i, ranked_results in enumerate(results): 162 | print(f"\nQuery {i} reranking results:") 163 | print(ranked_results) 164 | 165 | except Exception as e: 166 | print(f"Main program error stack:\n{traceback.format_exc()}") 167 | print(f"Main program error: {e}") 168 | raise 169 | finally: 170 | # Cleanup 171 | try: 172 | await cleanup(app) 173 | except Exception as e: 174 | print(f"Cleanup process error: {e}") 175 | 176 | if __name__ == "__main__": 177 | try: 178 | asyncio.run(main()) 179 | except KeyboardInterrupt: 180 | print("\nReceived keyboard interrupt, shutting down...") 181 | except Exception as e: 182 | print(f"Fatal error stack:\n{traceback.format_exc()}") 183 | print(f"Fatal error: {e}") 184 | finally: 185 | if ray.is_initialized(): 186 | ray.shutdown() 187 | -------------------------------------------------------------------------------- /examples/unoptimized_dag_for_embedding_ingestion_search_reranking_llm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NetX-lab/Ayo/61607bcb31ed98e1101c5d60d8573105bf60d87c/examples/unoptimized_dag_for_embedding_ingestion_search_reranking_llm.png -------------------------------------------------------------------------------- /examples/unoptimized_embedding_ingestion_searching.py: -------------------------------------------------------------------------------- 1 | from Ayo.app import APP 2 | from Ayo.dags.node import Node, NodeAnnotation, NodeType, NodeIOSchema, NodeOps 3 | from Ayo.dags.dag import DAG 4 | from Ayo.queries.query import Query 5 | from Ayo.configs.config import EngineConfig 6 | from Ayo.engines.engine_types import ENGINE_REGISTRY, EngineType 7 | import time 8 | import asyncio 9 | import ray 10 | from copy import deepcopy 11 | from typing import List, Any, Dict 12 | import traceback 13 | 14 | default_embed_config = ENGINE_REGISTRY.get_default_config(EngineType.EMBEDDER) 15 | default_embed_config.update( 16 | { 17 | "model_name": "BAAI/bge-large-en-v1.5", 18 | "max_batch_size": 1024, 19 | } 20 | ) 21 | embed_config = EngineConfig( 22 | name="embedding_service", 23 | engine_type=EngineType.EMBEDDER, 24 | resources={}, 25 | num_gpus=1, 26 | num_cpus=1, 27 | instances=1, 28 | model_config={ 29 | **default_embed_config, 30 | "device": "cuda" 31 | }, 32 | latency_profile={ 33 | "timeout": 300, 34 | "batch_wait": 0.1 35 | } 36 | ) 37 | 38 | vectordb_config = EngineConfig( 39 | name="vector_db_service", 40 | engine_type=EngineType.VECTOR_DB, 41 | resources={}, 42 | num_gpus=0, 43 | num_cpus=4, 44 | instances=1, 45 | model_config={ 46 | "host": "localhost", 47 | "port": 5432, 48 | "user": "asplos25", 49 | "password": "123456", 50 | "database": "database_asplos", 51 | "vector_dim": 1024, 52 | "max_batch_size": 1000 53 | }, 54 | latency_profile={ 55 | "timeout": 60, 56 | "batch_wait": 0.05 57 | } 58 | ) 59 | 60 | 61 | 62 | def create_base_dag(): 63 | """Create basic DAG""" 64 | base_dag 
= DAG(dag_id="random-test") 65 | 66 | passages_embedding_node = Node( 67 | name="Embedding", 68 | node_type=NodeType.COMPUTE, 69 | engine_type=EngineType.EMBEDDER, 70 | io_schema=NodeIOSchema( 71 | input_format={"passages": List[str]}, 72 | output_format={"passages_embeddings": List[List[float]]} 73 | ), 74 | op_type=NodeOps.EMBEDDING, 75 | anno=NodeAnnotation.BATCHABLE, 76 | config={ 77 | "batch_size": embed_config.model_config.get("max_batch_size", 1024) 78 | } 79 | ) 80 | 81 | ingestion_node = Node( 82 | name="Ingestion", 83 | node_type=NodeType.COMPUTE, 84 | engine_type=EngineType.VECTOR_DB, 85 | io_schema=NodeIOSchema( 86 | input_format={"passages": List[str], "passages_embeddings": List[List[float]]}, 87 | output_format={"index_status": bool} 88 | ), 89 | op_type=NodeOps.VECTORDB_INGESTION, 90 | anno=NodeAnnotation.BATCHABLE, 91 | config={ 92 | "batch_size": embed_config.model_config.get("max_batch_size", 256) 93 | } 94 | ) 95 | 96 | query_embedding_node = Node( 97 | name="QueryEmbedding", 98 | node_type=NodeType.COMPUTE, 99 | engine_type=EngineType.EMBEDDER, 100 | io_schema=NodeIOSchema( 101 | input_format={"queries": List[str]}, 102 | output_format={"queries_embeddings": List[List[float]]} 103 | ), 104 | op_type=NodeOps.EMBEDDING, 105 | anno=NodeAnnotation.BATCHABLE, 106 | config={ 107 | "batch_size": embed_config.model_config.get("max_batch_size", 256) 108 | } 109 | ) 110 | 111 | search_node = Node( 112 | name="Search", 113 | node_type=NodeType.COMPUTE, 114 | engine_type=EngineType.VECTOR_DB, 115 | io_schema=NodeIOSchema( 116 | input_format={"queries_embeddings": List[List[float]], "index_status": bool}, 117 | output_format={"search_results": List[List[str]]} 118 | ), 119 | op_type=NodeOps.VECTORDB_SEARCHING, 120 | anno=NodeAnnotation.BATCHABLE, 121 | config={ 122 | "batch_size": 16, 123 | "top_k": 5 124 | } 125 | ) 126 | 127 | passages_embedding_node >> ingestion_node 128 | ingestion_node >> search_node 129 | query_embedding_node >> search_node 130 | base_dag.register_nodes(passages_embedding_node, ingestion_node, query_embedding_node, search_node) 131 | 132 | return base_dag 133 | 134 | async def process_query(app, queries: List[str], passages: List[str], dag, query_id: str): 135 | try: 136 | query = Query( 137 | uuid=f"random-test-{query_id}", 138 | query_id=f"random-test-{query_id}", 139 | query_inputs={ 140 | "passages": passages, 141 | "queries": queries 142 | }, 143 | DAG=deepcopy(dag) 144 | ) 145 | 146 | future = await app.submit_query( 147 | query=query, 148 | timeout=300 149 | ) 150 | 151 | result = await asyncio.wait_for(future, timeout=300) 152 | return result 153 | 154 | except Exception as e: 155 | print(f"Query {query_id} processing failed:\n{traceback.format_exc()}") 156 | raise Exception(f"Query {query_id} processing failed: {str(e)}") 157 | 158 | async def cleanup(app): 159 | """clear resources""" 160 | try: 161 | await app.stop() 162 | app.shutdown() 163 | except Exception as e: 164 | print(f"Cleanup failed:\n{traceback.format_exc()}") 165 | print(f"Cleanup failed: {str(e)}") 166 | 167 | async def main(): 168 | try: 169 | # initialize the app 170 | app = APP.init() 171 | app.register_engine(embed_config) 172 | 173 | app.register_engine(vectordb_config) 174 | 175 | dag = create_base_dag() 176 | app.update_template(dag) 177 | 178 | # start the app 179 | await app.start() 180 | await asyncio.sleep(5) # wait for the system to initialize 181 | 182 | # prepare test data 183 | test_queries = [ 184 | { 185 | "passages": 186 | [ 187 | "OSDI is a conference about 
operating systems..." * 20, 188 | "MICRO is a conference about computer architecture..." * 20, 189 | "HPCA is a conference about computer architecture..." * 20, 190 | "MLSYS is a conference about machine learning..." * 20, 191 | "Machine learning system design is a conference about ..." * 20, 192 | "AI is a branch of computer science..." * 20, 193 | "Machine learning is a subset of AI using statistical models " * 20, 194 | "Deep learning revolutionized AI using neural networks " * 20, 195 | "The sun is the largest object in the solar system..." * 20, 196 | "The moon is the only natural satellite of the earth..." * 20, 197 | "The earth is the third planet from the sun..." * 20, 198 | ]*80, 199 | "queries": [ 200 | "I want to know some knowledge about the top computer system conferences." 201 | ] 202 | }, 203 | 204 | 205 | ] 206 | 207 | # create delayed query tasks 208 | async def delayed_query(query_data, index): 209 | if index > 0: 210 | await asyncio.sleep(3 * index) 211 | return await process_query( 212 | app, 213 | query_data["queries"], 214 | query_data["passages"], 215 | dag, 216 | str(index) 217 | ) 218 | 219 | tasks = [ 220 | delayed_query(query_data, i) 221 | for i, query_data in enumerate(test_queries) 222 | ] 223 | 224 | results = await asyncio.gather(*tasks) 225 | 226 | for i, result in enumerate(results): 227 | print(f"\nQuery {i} results:") 228 | print(f"Results: {result}") 229 | 230 | except Exception as e: 231 | print(f"Main program error stack:\n{traceback.format_exc()}") 232 | print(f"Main program error: {e}") 233 | raise 234 | finally: 235 | try: 236 | await cleanup(app) 237 | except Exception as e: 238 | print(f"Cleanup process error: {e}") 239 | 240 | if __name__ == "__main__": 241 | try: 242 | asyncio.run(main()) 243 | except KeyboardInterrupt: 244 | print("\nReceived keyboard interrupt, shutting down...") 245 | except Exception as e: 246 | print(f"Fatal error stack:\n{traceback.format_exc()}") 247 | print(f"Fatal error: {e}") 248 | finally: 249 | if ray.is_initialized(): 250 | ray.shutdown() 251 | -------------------------------------------------------------------------------- /examples/unoptimized_embedding_ingestion_searching_reranking.py: -------------------------------------------------------------------------------- 1 | from Ayo.app import APP 2 | from Ayo.dags.node import Node, NodeAnnotation, NodeType, NodeIOSchema, NodeOps 3 | from Ayo.dags.dag import DAG 4 | from Ayo.queries.query import Query 5 | from Ayo.configs.config import EngineConfig 6 | from Ayo.engines.engine_types import ENGINE_REGISTRY, EngineType 7 | import time 8 | import asyncio 9 | import ray 10 | from copy import deepcopy 11 | from typing import List, Any, Dict, Optional, Union 12 | import traceback 13 | 14 | 15 | default_embed_config = ENGINE_REGISTRY.get_default_config(EngineType.EMBEDDER) 16 | default_embed_config.update( 17 | { 18 | "model_name": "BAAI/bge-large-en-v1.5", 19 | "max_batch_size": 1024, 20 | } 21 | ) 22 | embed_config = EngineConfig( 23 | name="embedding_service", 24 | engine_type=EngineType.EMBEDDER, 25 | resources={}, 26 | num_gpus=1, 27 | num_cpus=1, 28 | instances=2, 29 | model_config={ 30 | **default_embed_config, 31 | "device": "cuda" 32 | }, 33 | latency_profile={ 34 | "timeout": 300, 35 | "batch_wait": 0.1 36 | } 37 | ) 38 | 39 | 40 | vectordb_config = EngineConfig( 41 | name="vector_db_service", 42 | engine_type=EngineType.VECTOR_DB, 43 | resources={}, 44 | num_gpus=0, 45 | num_cpus=4, 46 | instances=1, 47 | model_config={ 48 | "host": "localhost", 49 | "port": 5432, 
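# connection settings for a local Postgres + pgvector instance (pgvector is pinned in requirements.txt); vector_dim matches the 1024-dim bge-large embeddings configured above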
50 | "user": "asplos25", 51 | "password": "123456", 52 | "database": "database_asplos", 53 | "vector_dim": 1024, 54 | "max_batch_size": 1000 55 | }, 56 | latency_profile={ 57 | "timeout": 60, 58 | "batch_wait": 0.05 59 | } 60 | ) 61 | 62 | 63 | default_reranker_config = ENGINE_REGISTRY.get_default_config(EngineType.RERANKER) 64 | default_reranker_config.update( 65 | { 66 | "model_name": "BAAI/bge-reranker-large", 67 | "max_batch_size": 32, 68 | } 69 | ) 70 | reranker_config = EngineConfig( 71 | name="reranker_service", 72 | engine_type=EngineType.RERANKER, 73 | resources={}, 74 | num_gpus=1, 75 | num_cpus=1, 76 | instances=1, 77 | model_config={ 78 | **default_reranker_config, 79 | "device": "cuda" 80 | }, 81 | latency_profile={ 82 | "timeout": 60, 83 | "batch_wait": 0.1 84 | } 85 | ) 86 | 87 | def create_base_dag(): 88 | """Create basic DAG""" 89 | base_dag = DAG(dag_id="random-test") 90 | 91 | passages_embedding_node = Node( 92 | name="Embedding", 93 | node_type=NodeType.COMPUTE, 94 | engine_type=EngineType.EMBEDDER, 95 | io_schema=NodeIOSchema( 96 | input_format={"passages": List[str]}, 97 | output_format={"passages_embeddings": List[List[float]]} 98 | ), 99 | op_type=NodeOps.EMBEDDING, 100 | anno=NodeAnnotation.BATCHABLE, 101 | config={ 102 | 103 | } 104 | ) 105 | 106 | ingestion_node = Node( 107 | name="Ingestion", 108 | node_type=NodeType.COMPUTE, 109 | engine_type=EngineType.VECTOR_DB, 110 | io_schema=NodeIOSchema( 111 | input_format={"passages": List[str], "passages_embeddings": List[List[float]]}, 112 | output_format={"index_status": bool} 113 | ), 114 | op_type=NodeOps.VECTORDB_INGESTION, 115 | anno=NodeAnnotation.BATCHABLE, 116 | config={ 117 | 118 | } 119 | ) 120 | 121 | query_embedding_node = Node( 122 | name="QueryEmbedding", 123 | node_type=NodeType.COMPUTE, 124 | engine_type=EngineType.EMBEDDER, 125 | io_schema=NodeIOSchema( 126 | input_format={"queries": List[str]}, 127 | output_format={"queries_embeddings": List[List[float]]} 128 | ), 129 | op_type=NodeOps.EMBEDDING, 130 | anno=NodeAnnotation.BATCHABLE, 131 | config={ 132 | 133 | } 134 | ) 135 | 136 | search_node = Node( 137 | name="Search", 138 | node_type=NodeType.COMPUTE, 139 | engine_type=EngineType.VECTOR_DB, 140 | io_schema=NodeIOSchema( 141 | input_format={"queries_embeddings": List[List[float]], "index_status": bool}, 142 | output_format={"search_results": List[List[str]]} 143 | ), 144 | op_type=NodeOps.VECTORDB_SEARCHING, 145 | anno=NodeAnnotation.BATCHABLE, 146 | config={ 147 | "top_k": 30 # Retrieve extra candidates so the reranker has a larger pool to choose from 148 | } 149 | ) 150 | 151 | reranking_node = Node( 152 | name="Reranking", 153 | node_type=NodeType.COMPUTE, 154 | engine_type=EngineType.RERANKER, 155 | io_schema=NodeIOSchema( 156 | input_format={"queries": str, "search_results": Union[List[List[str]],List[str]]}, 157 | output_format={"reranked_results": List[List[str]]} 158 | ), 159 | op_type=NodeOps.RERANKING, 160 | anno=NodeAnnotation.BATCHABLE, 161 | config={ 162 | "top_k": 10 # The number of results to return 163 | } 164 | ) 165 | 166 | base_dag.set_query_inputs( 167 | { 168 | "passages": [ 169 | "AI is a branch of computer science..." * 15, 170 | "Machine learning is a subset of AI..." * 15, 171 | "Deep learning revolutionized AI..." * 15, 172 | "I have a question about AI..." * 15, 173 | "I have a question about machine learning..." * 15, 174 | 'What is the latest news about AI?' * 15, 175 | 'What is the latest news about machine learning?' * 15, 176 | 'What is the latest news about deep learning?' 
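# the *60 below replicates these filler passages into a large distractor corpus; the conference passages appended after it are the intended hits for the query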
* 15, 177 | 'What is the latest news about AI?' * 15, 178 | 'What is the latest news about machine learning?' * 15, 179 | 'What is the latest news about deep learning?' * 15, 180 | ]*60 + [ 181 | "OSDI is a conference about operating systems..." * 15, 182 | "ASPLOS is a conference about computer architecture..." * 15, 183 | "MICRO is a conference about computer architecture..." * 15, 184 | "HPCA is a conference about computer architecture..." * 15, 185 | "MLSYS is a conference about machine learning..." * 15, 186 | "Machine learning system design is a conference about machine ..." * 15, 187 | ], 188 | "queries": [ 189 | "I want to know some system conferences." 190 | ] 191 | }, 192 | ) 193 | 194 | # NOTE: this fully sequential chain also makes query embedding wait for ingestion; the unoptimized baseline keeps that conservative ordering 195 | passages_embedding_node >> ingestion_node >> query_embedding_node >> search_node >> reranking_node 196 | ingestion_node >> search_node 197 | base_dag.register_nodes(passages_embedding_node, ingestion_node, query_embedding_node, search_node, reranking_node) 198 | 199 | return base_dag 200 | 201 | 202 | async def process_query(app, queries: List[str], passages: List[str], dag, query_id: str): 203 | """Add a query to the app and process it""" 204 | try: 205 | query = Query( 206 | uuid=f"random-test-{query_id}", 207 | query_id=f"random-test-{query_id}", 208 | query_inputs={ 209 | "passages": passages, 210 | "queries": queries 211 | }, 212 | DAG=deepcopy(dag) 213 | ) 214 | 215 | future = await app.submit_query( 216 | query=query, 217 | timeout=300 218 | ) 219 | 220 | result = await asyncio.wait_for(future, timeout=300) 221 | return result 222 | 223 | except Exception as e: 224 | print(f"Query {query_id} processing failed:\n{traceback.format_exc()}") 225 | raise Exception(f"Query {query_id} processing failed: {str(e)}") 226 | 227 | 228 | async def run_app(dag): 229 | try: 230 | # Initialize the application 231 | app = APP.init() 232 | app.register_engine(embed_config) 233 | app.register_engine(vectordb_config) 234 | app.register_engine(reranker_config) # Register the reranker engine 235 | 236 | 237 | # Start the application 238 | await app.start() 239 | await asyncio.sleep(5) 240 | 241 | async def delayed_query(query_data, index): 242 | if index > 0: 243 | await asyncio.sleep(3 * index) 244 | return await process_query( 245 | app, 246 | query_data["queries"], 247 | query_data["passages"], 248 | dag, 249 | str(index) 250 | ) 251 | 252 | test_queries = [ 253 | { 254 | "passages": [ 255 | "AI is a branch of computer science..." * 15, 256 | "Machine learning is a subset of AI..." * 15, 257 | "Deep learning revolutionized AI..." * 15, 258 | "I have a question about AI..." * 15, 259 | "I have a question about machine learning..." * 15, 260 | 'What is the latest news about AI?' * 15, 261 | 'What is the latest news about machine learning?' * 15, 262 | 'What is the latest news about deep learning?' * 15, 263 | 'What is the latest news about AI?' * 15, 264 | 'What is the latest news about machine learning?' * 15, 265 | 'What is the latest news about deep learning?' * 15, 266 | ]*60 + [ 267 | "OSDI is a conference about operating systems..." * 15, 268 | "ASPLOS is a conference about computer architecture..." * 15, 269 | "MICRO is a conference about computer architecture..." * 15, 270 | "HPCA is a conference about computer architecture..." * 15, 271 | "MLSYS is a conference about machine learning..." * 15, 272 | "Machine learning system design is a conference about machine ..." * 15, 273 | ], 274 | "queries": [ 275 | "I want to know some system conferences." 
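# with top_k=30 at the search node and top_k=10 after reranking, the conference passages above should dominate the final results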
276 | ] 277 | }, 278 | ] 279 | 280 | # create all tasks 281 | tasks = [ 282 | delayed_query(query_data, i) 283 | for i, query_data in enumerate(test_queries) 284 | ] 285 | 286 | # wait for all queries to complete 287 | results = await asyncio.gather(*tasks) 288 | 289 | # print results 290 | for i, result in enumerate(results): 291 | print(f"\nQuery {i} results: {result}") 292 | for key, value in result.items(): 293 | if 'search' in key.lower(): 294 | print(f"Search results length: {len(value)}") 295 | elif 'reranking' in key.lower(): 296 | print(f"Reranking results length: {len(value)}") 297 | 298 | except Exception as e: 299 | print(f"Main program error stack:\n{traceback.format_exc()}") 300 | print(f"Main program error: {e}") 301 | raise 302 | finally: 303 | # clean up 304 | try: 305 | await cleanup(app) 306 | except Exception as e: 307 | print(f"Cleanup process error: {e}") 308 | 309 | async def cleanup(app): 310 | """Clean up resources""" 311 | try: 312 | await app.stop() 313 | app.shutdown() 314 | except Exception as e: 315 | print(f"Cleanup failed:\n{traceback.format_exc()}") 316 | print(f"Cleanup failed: {str(e)}") 317 | 318 | 319 | if __name__ == "__main__": 320 | from Ayo.vis.vis_graph import visualize_dag_with_node_types 321 | 322 | dag = create_base_dag() 323 | 324 | 325 | visualize_dag_with_node_types(dag, "unoptimized_embedding_ingestion_search_reranking.png") 326 | 327 | 328 | asyncio.run(run_app(dag)) 329 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | asyncpg==0.30.0 2 | fastapi==0.115.8 3 | google_api_python_client==2.125.0 4 | llama_index==0.12.19 5 | numpy==1.24.3 6 | pandas==1.1.5 7 | pgvector==0.2.5 8 | psycopg==3.1.18 9 | pydantic==1.10.13 10 | ray==2.10.0 11 | Requests==2.32.3 12 | sentence_transformers==2.6.1 13 | torch==2.4.0+cu121  # +cu121 wheels are not on PyPI; install with: pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu121 14 | tqdm==4.66.2 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | print(find_packages()) 4 | 5 | setup( 6 | name='Ayo', 7 | version='0.1', 8 | packages=find_packages(), 9 | ) 10 | 11 | --------------------------------------------------------------------------------