├── .gitignore
├── .gitmodules
├── Ayo
│   ├── __init__.py
│   ├── app.py
│   ├── configs
│   │   ├── __init__.py
│   │   ├── config.py
│   │   └── model_config.py
│   ├── dags
│   │   ├── __init__.py
│   │   ├── dag.py
│   │   ├── node.py
│   │   └── node_commons.py
│   ├── engines
│   │   ├── __init__.py
│   │   ├── aggregator.py
│   │   ├── base_engine.py
│   │   ├── embedder.py
│   │   ├── engine_types.py
│   │   ├── llm.py
│   │   ├── payload_transformers.py
│   │   ├── reranker.py
│   │   └── vector_db.py
│   ├── logger.py
│   ├── modules
│   │   ├── base_module.py
│   │   ├── embedding.py
│   │   ├── indexing.py
│   │   ├── llm_syhthesizing.py
│   │   ├── mod_to_prim.py
│   │   ├── prompt_template.py
│   │   ├── query_expanding.py
│   │   ├── reranking.py
│   │   └── searching.py
│   ├── opt_pass
│   │   ├── base_pass.py
│   │   ├── decoding_pipeling.py
│   │   ├── pass_manager.py
│   │   ├── prefilling_split.py
│   │   ├── pruning_dependency.py
│   │   ├── stage_decomposition.py
│   │   ├── test_dag_llm_decoding_pipeling.pdf
│   │   └── test_dag_prefilling_split.pdf
│   ├── queries
│   │   ├── __init__.py
│   │   ├── query.py
│   │   └── query_state.py
│   ├── schedulers
│   │   ├── engine_scheduler.py
│   │   └── graph_scheduler.py
│   ├── utils.py
│   └── vis
│       ├── test_dag_node_types.png
│       └── vis_graph.py
├── LICENSE
├── README.md
├── examples
│   ├── modules_to_primitives_embedding_ingestion_searching_reranking.py
│   ├── modules_to_primitives_indexing_searching.py
│   ├── optimized_dag_for_embedding_ingestion_searching_reranking_llm.png
│   ├── optimized_embedding_ingestion_rewriting_searching_reranking_llm.py
│   ├── optimized_embedding_ingestion_searching.py
│   ├── optimized_embedding_ingestion_searching_reranking.py
│   ├── optimized_embedding_ingestion_searching_reranking_llm.py
│   ├── test_dag.py
│   ├── test_embedding_service.py
│   ├── test_multiple_llm_calls.py
│   ├── test_reranking_service.py
│   ├── unoptimized_dag_for_embedding_ingestion_search_reranking_llm.png
│   ├── unoptimized_embedding_ingestion_rewriting_searching_reranking_llm.py
│   ├── unoptimized_embedding_ingestion_searching.py
│   ├── unoptimized_embedding_ingestion_searching_reranking.py
│   └── unoptimized_embedding_ingestion_searching_reranking_llm.py
├── requirements.txt
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | */__pycache__
3 | *.pyc
4 | *.pyo
5 | .DS_Store
6 | venv/
7 | env/
8 | .env
9 | .venv
10 | *.py[cod]
11 | *$py.class
12 | *.so
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | *.egg-info/
27 | .installed.cfg
28 | *.egg
29 | .pytest_cache/
30 |
31 | */profiles
32 | */ray_worker
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "vllm"]
2 | path = vllm
3 | url = https://github.com/Txxx926/vllm
4 | branch = ayo
5 |
--------------------------------------------------------------------------------
/Ayo/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/Ayo/configs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NetX-lab/Ayo/61607bcb31ed98e1101c5d60d8573105bf60d87c/Ayo/configs/__init__.py
--------------------------------------------------------------------------------
/Ayo/configs/config.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, asdict
2 | from typing import Dict, Optional, List, Any
3 |
4 | @dataclass
5 | class EngineConfig:
6 | """configuration for execution engines"""
7 | name: str
8 | engine_type: str # e.g., "embedder", "llm", "vector_db"
9 | num_gpus: int = 0
10 | num_cpus: int = 1
11 | resources: Optional[Dict[str, int]] = None # e.g., {"GPU": 2}
12 | instances: int = 1
13 | model_config: Optional[Dict] = None
14 | latency_profile: Optional[Dict] = None
15 |
16 | def dict(self) -> Dict:
17 | """Convert config to dictionary"""
18 | return asdict(self)
19 |
20 |
21 | @dataclass
22 | class AppConfig:
23 | """application config"""
24 | engines: Dict[str, EngineConfig]
25 | optimization_passes: Optional[List[str]] = None
26 | workflow_template: Optional[Dict[str, Any]] = None
--------------------------------------------------------------------------------
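
A minimal, illustrative sketch (not part of the repository) of how these two dataclasses might be wired together. The engine name, model name, and pass name below are assumptions for demonstration, not values mandated by the codebase.

```python
# Hypothetical usage of EngineConfig / AppConfig (names and values are illustrative).
from Ayo.configs.config import EngineConfig, AppConfig

embedder_cfg = EngineConfig(
    name="embedder-0",                 # assumed engine name
    engine_type="embedder",
    num_gpus=1,
    instances=1,
    model_config={"model_name": "BAAI/bge-large-en-v1.5"},
)

app_cfg = AppConfig(
    engines={"embedder": embedder_cfg},
    optimization_passes=["pruning_dependency"],  # assumed pass name
)

print(app_cfg.engines["embedder"].dict())  # asdict() view of the engine config
```
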
/Ayo/configs/model_config.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from enum import Enum
3 | from Ayo.dags.node_commons import NodeOps
4 | from Ayo.dags.node import Node
5 | from typing import Optional, Dict, List
6 |
7 | @dataclass
8 | class EmbeddingModelConfig:
9 | """embedding model config"""
10 | model_name: str
11 | dimension: int
12 | max_length: int = 512
13 | device: str = "cuda"
14 | batch_size: int = 1024
15 | vector_dim: int = 1024
16 |
17 | @dataclass
18 | class LLMConfig:
19 | """llm model config"""
20 | model_name: str
21 | temperature: float = 0.9
22 | top_p: float = 0.95
23 | device: str = "cuda"
24 |
25 | @dataclass
26 | class VectorDBConfig:
27 | """vector database config"""
28 | db_path: str
29 | dimension: int
30 | index_type: str = "HNSW" # e.g., "HNSW", "IVF"
31 | metric_type: str = "cosine" # e.g., "cosine", "IP", "L2"
32 | nprobe: int = 10
33 |
34 |
35 | class AggMode(str, Enum):
36 | """aggregator mode"""
37 | DUMMY = "dummy"
38 | MERGE = "merge"
39 | TOP_K = "top_k"
40 |
41 |
42 | def get_aggregator_config(node: Node, **kwargs) -> Dict:
43 | """get the aggregator config"""
44 |
45 | assert node.op_type == NodeOps.AGGREGATOR, f"node {node.name} is not an aggregator node"
46 | if node.parents[0].op_type == NodeOps.EMBEDDING:
47 | return {"agg_mode": AggMode.DUMMY}
48 | elif node.parents[0].op_type == NodeOps.VECTORDB_SEARCHING:
49 | return {"agg_mode": AggMode.MERGE}
50 | elif node.parents[0].op_type == NodeOps.RERANKING:
51 |
52 | agg_config = {"agg_mode": AggMode.TOP_K}
53 | agg_config.update(
54 | {
55 | "top_k": node.config.get("topk") or node.config.get("top_k") or node.config.get("k") or 5
56 | }
57 | )
58 | return agg_config
59 | elif node.parents[0].op_type == NodeOps.VECTORDB_INGESTION:
60 | return {"agg_mode": AggMode.DUMMY}
61 | else:
62 | raise ValueError(f"Unsupported parent node op type: {node.parents[0].op_type}")
63 |
64 |
65 | def get_aggregator_config_for_parent_node(node: Node, **kwargs) -> Dict:
66 | if node.op_type == NodeOps.EMBEDDING:
67 | return {"agg_mode": AggMode.DUMMY}
68 | elif node.op_type == NodeOps.VECTORDB_SEARCHING:
69 | return {"agg_mode": AggMode.MERGE}
70 | elif node.op_type == NodeOps.RERANKING:
71 | return {"agg_mode": AggMode.TOP_K}
72 | elif node.op_type == NodeOps.VECTORDB_INGESTION:
73 | return {"agg_mode": AggMode.DUMMY}
74 | else:
75 | raise ValueError(f"Unsupported node op type: {node.op_type}")
--------------------------------------------------------------------------------
/Ayo/dags/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NetX-lab/Ayo/61607bcb31ed98e1101c5d60d8573105bf60d87c/Ayo/dags/__init__.py
--------------------------------------------------------------------------------
/Ayo/dags/node_commons.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | from enum import Enum
4 | from dataclasses import dataclass
5 | from typing import Dict, Optional
6 |
7 | class NodeStatus(Enum):
8 | PENDING = "pending"
9 | RUNNING = "running"
10 | COMPLETED = "completed"
11 | FAILED = "failed"
12 |
13 | class NodeAnnotation(str, Enum):
14 | """Node annotations for optimization hints"""
15 | SPLITTABLE = "splittable" # Node output can be split
16 | BATCHABLE = "batchable" # Node can process batched inputs
17 | NONE = "none" # No special optimization
18 |
19 | class NodeOps(str, Enum):
20 | """Node operations"""
21 | INPUT = "input"
22 | OUTPUT = "output"
23 | EMBEDDING = "embedding"
24 | VECTORDB_INGESTION = "vectordb_ingestion"
25 | VECTORDB_SEARCHING = "vectordb_searching"
26 | RERANKING = "reranking"
27 | LLM_PREFILLING = "llm_prefilling"
28 | LLM_DECODING = "llm_decoding"
29 | LLM_PARTIAL_PREFILLING = "llm_partial_prefilling"
30 | LLM_FULL_PREFILLING = "llm_full_prefilling"
31 | LLM_PARTIAL_DECODING = "llm_parallel_decoding"
32 | AGGREGATOR = "aggregator"
33 |
34 |
35 |
36 | class NodeType(Enum):
37 | """Node types in DAG"""
38 | INPUT = "input" # Input node that holds query inputs
39 | COMPUTE = "compute" # Computation node that performs operations
40 | OUTPUT = "output" # Output node that collects results
41 |
42 |
43 | @dataclass
44 | class NodeConfig:
45 | """Configuration for node execution"""
46 | batch_size: Optional[int] = None
47 | max_tokens: Optional[int] = None
48 | temperature: Optional[float] = None
49 | top_p: Optional[float] = None
50 | # Add other configuration parameters as needed
51 |
52 | @dataclass
53 | class NodeIOSchema:
54 | """Define the input and output schema for the node"""
55 | input_format: Dict[str, type] # input field name and type
56 | output_format: Dict[str, type] # output field name and type
57 |
--------------------------------------------------------------------------------
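
A small sketch (field names are assumptions) of how NodeIOSchema and the enums above can describe an embedding step's contract; it only exercises the dataclasses and enums defined in this file.

```python
# Illustrative only: describe the I/O contract of an embedding step.
from typing import List
from Ayo.dags.node_commons import NodeIOSchema, NodeOps, NodeType, NodeStatus

schema = NodeIOSchema(
    input_format={"texts": List[str]},           # field name is an assumption
    output_format={"embeddings": List[float]},
)

op, kind, status = NodeOps.EMBEDDING, NodeType.COMPUTE, NodeStatus.PENDING
print(schema, op.value, kind.value, status.value)
```
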
/Ayo/engines/__init__.py:
--------------------------------------------------------------------------------
1 | from Ayo.engines.engine_types import EngineType, EngineRegistry, ENGINE_REGISTRY
2 | from Ayo.engines.base_engine import BaseEngine
3 | from Ayo.engines.embedder import EmbeddingEngine
4 | from Ayo.engines.vector_db import VectorDBEngine
5 | from Ayo.engines.reranker import RerankerEngine
6 | from Ayo.engines.llm import LLMEngine
7 | from Ayo.engines.aggregator import AggregateEngine
8 |
9 |
--------------------------------------------------------------------------------
/Ayo/engines/aggregator.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Optional, Any, Tuple, Union
2 | import ray
3 | import asyncio
4 | import time
5 | import numpy as np
6 | from dataclasses import dataclass, field
7 | from collections import defaultdict
8 |
9 | # This engine is not actually used in the pipeline;
10 | # rather, we do the in-place aggregation on the engine scheduler side
11 |
12 | @dataclass
13 | class AggregateRequest:
14 | """Data class for aggregate requests"""
15 | request_id: str
16 | query_id: str # Group requests by query ID
17 | agg_mode: str # Aggregation mode: concat, merge_dicts, select_best, topk, custom
18 | data_sources: List[Any] # Data sources to aggregate
19 | callback_ref: Any = None # Ray ObjectRef for result callback
20 | timestamp: float = field(default_factory=time.time)
21 |
22 | @ray.remote
23 | class AggregateEngine:
24 | """Ray Actor, used to handle aggregate requests
25 |
26 | Features:
27 | - Asynchronous request processing
28 | - Supports multiple aggregation modes
29 | - Groups requests by query ID
30 | """
31 |
32 | def __init__(self,
33 | max_batch_size: int = 32,
34 | max_queue_size: int = 1000,
35 | scheduler_ref: Optional[ray.actor.ActorHandle] = None,
36 | **kwargs):
37 |
38 | self.max_batch_size = max_batch_size
39 | self.max_queue_size = max_queue_size
40 |
41 | # Asynchronous queue
42 | self.request_queue = asyncio.Queue(maxsize=max_queue_size)
43 | self.batch_queue = asyncio.Queue(maxsize=max_queue_size)
44 |
45 | # Track requests by query ID
46 | self.query_requests: Dict[str, List[AggregateRequest]] = {}
47 |
48 | # Start processing tasks
49 | self.running = True
50 | self.tasks = [
51 | asyncio.create_task(self._batch_requests()),
52 | asyncio.create_task(self._process_batches())
53 | ]
54 |
55 | self.scheduler_ref = scheduler_ref
56 |
57 | async def submit_request(self,
58 | request_id: str,
59 | query_id: str,
60 | agg_mode: str,
61 | data_sources: List[Any]) -> None:
62 | """Submit new aggregate request"""
63 | request = AggregateRequest(
64 | request_id=request_id,
65 | query_id=query_id,
66 | agg_mode=agg_mode,
67 | data_sources=data_sources,
68 | callback_ref=None
69 | )
70 |
71 | if self.request_queue.qsize() >= self.max_queue_size:
72 | raise RuntimeError("Request queue is full")
73 |
74 | await self.request_queue.put(request)
75 |
76 | if query_id not in self.query_requests:
77 | self.query_requests[query_id] = []
78 | self.query_requests[query_id].append(request)
79 |
80 | async def _batch_requests(self):
81 | """Asynchronous task for batching requests"""
82 | while self.running:
83 | try:
84 | batch_requests = await self._get_next_batch()
85 | if batch_requests:
86 | await self.batch_queue.put(batch_requests)
87 | else:
88 | await asyncio.sleep(0.01) # Avoid busy waiting
89 | except Exception as e:
90 | print(f"Error in batch processing task: {e}")
91 | continue
92 |
93 | async def _process_batches(self):
94 | """Asynchronous task for processing batches"""
95 | while self.running:
96 | try:
97 | try:
98 | batch_requests = await asyncio.wait_for(
99 | self.batch_queue.get(),
100 | timeout=0.1
101 | )
102 | except asyncio.TimeoutError:
103 | continue
104 |
105 | # Process each request
106 | for request in batch_requests:
107 | try:
108 | # Process data based on aggregation mode
109 | result = await self._aggregate_data(
110 | request.agg_mode,
111 | request.data_sources
112 | )
113 |
114 | # Create ObjectRef for result
115 | result_ref = ray.put(result)
116 |
117 | # If scheduler is set, send result to scheduler
118 | if self.scheduler_ref is not None:
119 | await self.scheduler_ref.on_result.remote(
120 | request.request_id,
121 | request.query_id,
122 | result_ref
123 | )
124 |
125 | # Clean up request records
126 | if request.query_id in self.query_requests:
127 | self.query_requests[request.query_id].remove(request)
128 | if not self.query_requests[request.query_id]:
129 | del self.query_requests[request.query_id]
130 |
131 | except Exception as e:
132 | import traceback
133 | traceback.print_exc()
134 | print(f"Error in processing single request: {e}")
135 | continue
136 |
137 | except Exception as e:
138 | import traceback
139 | traceback.print_exc()
140 | print(f"Error in processing batch: {e}")
141 | continue
142 |
143 | async def _get_next_batch(self) -> List[AggregateRequest]:
144 | """Get next batch of requests to process"""
145 | batch_requests = []
146 | processed_queries = set()
147 |
148 | while len(batch_requests) < self.max_batch_size:
149 | try:
150 | request = await asyncio.wait_for(
151 | self.request_queue.get(),
152 | timeout=0.01
153 | )
154 |
155 | if request.query_id in processed_queries:
156 | # Process requests from the same query
157 | pending_requests = self.query_requests[request.query_id]
158 | batch_requests.extend(pending_requests)
159 | else:
160 | batch_requests.append(request)
161 | processed_queries.add(request.query_id)
162 |
163 | except asyncio.TimeoutError:
164 | break
165 |
166 | return batch_requests
167 |
168 | async def _aggregate_data(self, agg_mode: str, data_sources: List[Any]) -> Any:
169 | """Aggregate data based on aggregation mode"""
170 | if agg_mode == "concat":
171 | # Simple list concatenation
172 | result = []
173 | for source in data_sources:
174 | if isinstance(source, list):
175 | result.extend(source)
176 | else:
177 | result.append(source)
178 | return result
179 |
180 | elif agg_mode == "merge_dicts":
181 | # Merge multiple dictionaries
182 | result = {}
183 | for source in data_sources:
184 | if isinstance(source, dict):
185 | result.update(source)
186 | return result
187 |
188 | elif agg_mode == "select_best":
189 | # Select best result (assuming each source has a score field)
190 | if not data_sources:
191 | return None
192 |
193 | best_source = None
194 | best_score = float('-inf')
195 |
196 | for source in data_sources:
197 | if isinstance(source, dict) and 'score' in source:
198 | if source['score'] > best_score:
199 | best_score = source['score']
200 | best_source = source
201 |
202 | return best_source
203 |
204 | elif agg_mode == "topk":
205 | # Select top k results with highest scores
206 | # datasource format:
207 | if not data_sources:
208 | return []
209 |
210 | # Use n parameter from request as k value, default to 3 if not provided
211 | k = 3
212 |
213 | # Filter valid data sources ( must be dictionaries and contain score field)
214 | valid_sources = [
215 | source for source in data_sources
216 | if isinstance(source, dict) and 'score' in source
217 | ]
218 |
219 | # Sort by score in descending order and return top k
220 | sorted_sources = sorted(
221 | valid_sources,
222 | key=lambda x: x['score'],
223 | reverse=True
224 | )
225 |
226 | return sorted_sources[:k]
227 |
228 | elif agg_mode == "custom":
229 | # Custom aggregation function (needs function and data in data_sources)
230 | if len(data_sources) >= 2 and callable(data_sources[0]):
231 | custom_func = data_sources[0]
232 | data = data_sources[1:]
233 | return custom_func(*data)
234 | return None
235 |
236 | else:
237 | # Default return original data sources
238 | return data_sources
239 |
240 | async def shutdown(self):
241 | """Shutdown service"""
242 | self.running = False
243 | for task in self.tasks:
244 | task.cancel()
245 | try:
246 | await task
247 | except asyncio.CancelledError:
248 | pass
--------------------------------------------------------------------------------
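
For reference, a standalone sketch of what the "topk" and "concat" modes compute. It re-implements the core of `_aggregate_data` outside the actor so it can be run without Ray; the sample scores are made up.

```python
# Standalone illustration of two aggregation modes (mirrors _aggregate_data, no Ray needed).
data = [{"text": "a", "score": 0.2}, {"text": "b", "score": 0.9}, {"text": "c", "score": 0.5}]

# "topk": keep the k highest-scoring dicts (k is fixed to 3 in the engine as written)
topk = sorted([d for d in data if isinstance(d, dict) and "score" in d],
              key=lambda x: x["score"], reverse=True)[:3]

# "concat": flatten list sources and append scalars
concat = []
for source in [[1, 2], 3, [4]]:
    concat.extend(source) if isinstance(source, list) else concat.append(source)

print(topk, concat)
```
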
/Ayo/engines/base_engine.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Optional, Any
2 | import ray
3 | import asyncio
4 | import time
5 | from abc import ABC, abstractmethod
6 | from dataclasses import dataclass, field
7 |
8 | @dataclass
9 | class BaseRequest:
10 | """Base data class for all engine requests"""
11 | request_id: str
12 | query_id: str # Group requests from same query
13 | callback_ref: Any # Ray ObjectRef for result
14 | timestamp: float = field(default_factory=time.time)
15 |
16 |
17 | class BaseEngine(ABC):
18 | """Base class for all Ray Actor engines
19 |
20 | Features:
21 | - Async request handling
22 | - Request queuing and batching
23 | - Request tracking by query_id
24 | """
25 |
26 | def __init__(self,
27 | max_batch_size: int = 32,
28 | max_queue_size: int = 1000,
29 | scheduler_ref: Optional[ray.actor.ActorHandle] = None,
30 | **kwargs):
31 |
32 | self.max_batch_size = max_batch_size
33 | self.max_queue_size = max_queue_size
34 | self.scheduler_ref = scheduler_ref
35 |
36 | # Async queues
37 | self.request_queue = asyncio.Queue(maxsize=max_queue_size)
38 | self.batch_queue = asyncio.Queue(maxsize=max_queue_size)
39 |
40 | # Track requests by query_id
41 | self.query_requests: Dict[str, List[BaseRequest]] = {}
42 |
43 | # Create event loop
44 | self.loop = asyncio.get_event_loop()
45 |
46 | # Start processing tasks
47 | self.running = True
48 | self.tasks = [
49 | self.loop.create_task(self._batch_requests()),
50 | self.loop.create_task(self._process_batches())
51 | ]
52 |
53 | @abstractmethod
54 | def _load_model(self):
55 | """Load the model - must be implemented by subclasses"""
56 | pass
57 |
58 | @abstractmethod
59 | async def submit_request(self, request_id: str, query_id: str, **kwargs) -> ray.ObjectRef:
60 | """Submit a new request - must be implemented by subclasses"""
61 | pass
62 |
63 | @abstractmethod
64 | async def _get_next_batch(self):
65 | """Get next batch of requests - must be implemented by subclasses"""
66 | pass
67 |
68 | @abstractmethod
69 | async def _process_batch(self, batch_data):
70 | """Process a batch of requests - must be implemented by subclasses"""
71 | pass
72 |
73 | async def _batch_requests(self):
74 | """Async task for batching requests"""
75 | while self.running:
76 | try:
77 | batch_data = await self._get_next_batch()
78 | if batch_data:
79 | await self.batch_queue.put(batch_data)
80 | else:
81 | await asyncio.sleep(0.01)
82 | except Exception as e:
83 | print(f"Error in batching task: {e}")
84 | continue
85 |
86 | async def _process_batches(self):
87 | """Async task for processing batches"""
88 | while self.running:
89 | try:
90 | try:
91 | batch_data = await asyncio.wait_for(
92 | self.batch_queue.get(),
93 | timeout=0.1
94 | )
95 | except asyncio.TimeoutError:
96 | continue
97 |
98 | await self._process_batch(batch_data)
99 |
100 | except Exception as e:
101 | print(f"Error in process loop: {e}")
102 | continue
103 |
104 | async def shutdown(self):
105 | """Shutdown the engine"""
106 | self.running = False
107 | for task in self.tasks:
108 | task.cancel()
109 | try:
110 | await task
111 | except asyncio.CancelledError:
112 | pass
--------------------------------------------------------------------------------
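
A minimal sketch of what a concrete subclass has to provide. The engine name, echo behavior, and request shape are assumptions for illustration; real engines such as EmbeddingEngine below follow the same pattern but batch requests by query_id. Like the other engines, it is meant to run inside a Ray actor with a live event loop.

```python
# Hypothetical minimal BaseEngine subclass, for illustration only.
import asyncio
import ray
from Ayo.engines.base_engine import BaseEngine, BaseRequest

class EchoEngine(BaseEngine):
    def _load_model(self):
        return None  # nothing to load for this toy engine

    async def submit_request(self, request_id, query_id, **kwargs):
        # Enqueue a bare request; real engines attach payloads (texts, queries, ...)
        req = BaseRequest(request_id=request_id, query_id=query_id, callback_ref=None)
        await self.request_queue.put(req)

    async def _get_next_batch(self):
        try:
            return [await asyncio.wait_for(self.request_queue.get(), timeout=0.01)]
        except asyncio.TimeoutError:
            return []

    async def _process_batch(self, batch_data):
        for req in batch_data:
            result_ref = ray.put(f"echo:{req.request_id}")  # placeholder result
            if self.scheduler_ref is not None:
                await self.scheduler_ref.on_result.remote(req.request_id, req.query_id, result_ref)
```
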
/Ayo/engines/embedder.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Optional, Any, Tuple
2 | import ray
3 | import torch
4 | import asyncio
5 | import time
6 | from dataclasses import dataclass, field
7 | import numpy as np
8 | from collections import deque
9 | from Ayo.logger import get_logger, GLOBAL_INFO_LEVEL
10 |
11 | logger = get_logger(__name__, level=GLOBAL_INFO_LEVEL)
12 |
13 | @dataclass
14 | class EmbeddingRequest:
15 | """Data class for embedding requests"""
16 | request_id: str
17 | query_id: str # Group requests from same query
18 | texts: List[str]
19 | callback_ref: Any # Ray ObjectRef for result
20 | timestamp: float = field(default_factory=time.time)
21 |
22 |
23 | @ray.remote(num_gpus=1)
24 | class EmbeddingEngine:
25 | """Ray Actor for serving embedding requests with async processing
26 |
27 | Features:
28 | - Async request handling
29 | - Batches requests for efficient processing
30 | - Groups requests from same query
31 | """
32 |
33 | def __init__(self,
34 | model_name: str = "BAAI/bge-large-en-v1.5",
35 | max_batch_size: int = 512,
36 | max_queue_size: int = 1000,
37 | device: str = "cuda" if torch.cuda.is_available() else "cpu",
38 | scheduler_ref: Optional[ray.actor.ActorHandle] = None,
39 | **kwargs):
40 |
41 | print(f"CUDA is available: {torch.cuda.is_available()}")
42 | if torch.cuda.is_available():
43 | print(f"Number of available GPUs: {torch.cuda.device_count()}")
44 | print(f"Current GPU device: {torch.cuda.current_device()}")
45 | print(f"GPU name: {torch.cuda.get_device_name()}")
46 |
47 | self.model_name = model_name
48 | self.max_batch_size = max_batch_size
49 | self.max_queue_size = max_queue_size
50 | self.device = device
51 |
52 | self.name = kwargs.get("name", None)
53 |
54 | # Initialize model
55 | self.model = self._load_model()
56 |
57 | # Async queues
58 | self.request_queue = asyncio.Queue(maxsize=max_queue_size)
59 | self.batch_queue = asyncio.Queue(maxsize=max_queue_size)
60 |
61 | # Track requests by query_id
62 | self.query_requests: Dict[str, List[EmbeddingRequest]] = {}
63 |
64 | # Create event loop
65 | self.loop = asyncio.get_event_loop()
66 |
67 | # Start processing tasks
68 | self.running = True
69 | self.tasks = [
70 | self.loop.create_task(self._batch_requests()),
71 | self.loop.create_task(self._process_batches())
72 | ]
73 |
74 | self.scheduler_ref = scheduler_ref
75 |
76 |
77 | def is_ready(self):
78 | """Check if the engine is ready"""
79 | return True
80 |
81 | def _load_model(self):
82 | """Load the embedding model"""
83 | from sentence_transformers import SentenceTransformer
84 |
85 | #self.model.encode(["hello","world"])
86 |
87 |
88 | model=SentenceTransformer(model_name_or_path=self.model_name,device=self.device)
89 | model.half()
90 | model.eval()
91 |
92 | #warm up
93 | warm_up_embeddings=model.encode(["hello","world"])
94 | logger.debug(f"warm_up_embeddings: {warm_up_embeddings}")
95 |
96 | logger.debug(f"Load and warm up embedding model:{self.model_name} successfully on {self.device}")
97 |
98 | return model
99 |
100 | async def submit_request(self,
101 | request_id: str,
102 | query_id: str,
103 | texts: List[str]) -> ray.ObjectRef:
104 | """Submit a new embedding request"""
105 |
106 |
107 | request = EmbeddingRequest(
108 | request_id=request_id,
109 | query_id=query_id,
110 | texts=texts,
111 | callback_ref=None
112 | )
113 |
114 | if self.request_queue.qsize() >= self.max_queue_size:
115 | raise RuntimeError("Request queue is full")
116 |
117 | await self.request_queue.put(request)
118 |
119 | if query_id not in self.query_requests:
120 | self.query_requests[query_id] = []
121 |
122 | self.query_requests[query_id].append(request)
123 |
124 |
125 | async def _batch_requests(self):
126 | """Async task for batching requests"""
127 | while self.running:
128 | try:
129 | batch_requests, batch_texts = await self._get_next_batch()
130 | if batch_requests:
131 | await self.batch_queue.put((batch_requests, batch_texts))
132 | else:
133 | await asyncio.sleep(0.01) # Avoid busy waiting
134 | except Exception as e:
135 | print(f"Error in batching task: {e}")
136 | continue
137 |
138 | async def _process_batches(self):
139 | """Async task for processing batches"""
140 | while self.running:
141 | try:
142 | try:
143 | batch_requests, batch_texts = await asyncio.wait_for(
144 | self.batch_queue.get(),
145 | timeout=0.1
146 | )
147 | except asyncio.TimeoutError:
148 | continue
149 |
150 | #print(f"Processing batch requests: {batch_requests}")
151 |
152 | try:
153 | embeddings = await self.loop.run_in_executor(
154 | None,
155 | self._compute_embeddings,
156 | batch_texts
157 | )
158 |
159 | start_idx = 0
160 | for request in batch_requests:
161 | try:
162 | end_idx = start_idx + len(request.texts)
163 | request_embeddings = embeddings[start_idx:end_idx]
164 |
165 | # create ObjectRef for result
166 | print(f"request_embeddings: {request_embeddings.shape}")
167 |
168 | result_ref = ray.put(request_embeddings)
169 |
170 | # If scheduler is set, send result to scheduler
171 |
172 | if self.scheduler_ref is not None:
173 | await self.scheduler_ref.on_result.remote(
174 | request.request_id,
175 | request.query_id,
176 | result_ref
177 | )
178 | else:
179 | # If no scheduler is set, use the original callback
180 | ray.get(ray.put(request_embeddings, _owner=request.callback_ref))
181 |
182 | # clean up request records
183 | # if request.query_id in self.query_requests:
184 | # self.query_requests[request.query_id].remove(request)
185 | # if not self.query_requests[request.query_id]:
186 | # del self.query_requests[request.query_id]
187 |
188 | start_idx = end_idx
189 | except Exception as e:
190 | print(f"Error processing individual request: {e}")
191 | continue
192 |
193 | except Exception as e:
194 | print(f"Error computing embeddings: {e}")
195 | continue
196 |
197 | except Exception as e:
198 | print(f"Error in inference task: {e}")
199 | continue
200 |
201 | async def _get_next_batch(self) -> Tuple[List[EmbeddingRequest], List[str]]:
202 | """Get next batch of requests to process"""
203 | batch_requests = []
204 | batch_texts = []
205 | processed_queries = set()
206 |
207 |
208 |
209 | #while len(batch_texts) < self.max_batch_size:
210 | while len(batch_texts) ==0:
211 | try:
212 | request = await asyncio.wait_for(
213 | self.request_queue.get(),
214 | timeout=0.01
215 | )
216 |
217 | if request.query_id in processed_queries:
218 | pending_requests = self.query_requests[request.query_id]
219 | for pending_req in pending_requests:
220 | if len(batch_texts) + len(pending_req.texts) <= self.max_batch_size:
221 | batch_requests.append(pending_req)
222 | batch_texts.extend(pending_req.texts)
223 | else:
224 | if len(batch_texts) + len(request.texts) <= self.max_batch_size:
225 | batch_requests.append(request)
226 | batch_texts.extend(request.texts)
227 | processed_queries.add(request.query_id)
228 | else:
229 | await self.request_queue.put(request)
230 | break
231 |
232 | except asyncio.TimeoutError:
233 | break
234 |
235 | return batch_requests, batch_texts
236 |
237 | def _compute_embeddings(self, texts: List[str]) -> np.ndarray:
238 | """Compute embeddings for a batch of texts"""
239 | #self.model.encode(["hello","world"])
240 | with torch.no_grad():
241 |
242 | assert isinstance(texts,list) or isinstance(texts,str)
243 |
244 | begin = time.time()
245 |
246 | batch_size = len(texts) if isinstance(texts,list) else 1
247 | embeddings = self.model.encode(texts,batch_size=batch_size,show_progress_bar=False)
248 | #embeddings = embeddings[:,:768]
249 |
250 | logger.debug(f"texts' type: {type(texts)}, len: {len(texts)}, embeddings shape: {embeddings.shape}")
251 | end = time.time()
252 | logger.debug(f"embedding time for {len(texts)} texts: {end - begin}")
253 | return embeddings
254 |
255 | async def shutdown(self):
256 | """Shutdown the service"""
257 | self.running = False
258 | for task in self.tasks:
259 | task.cancel()
260 | try:
261 | await task
262 | except asyncio.CancelledError:
263 | pass
264 |
265 |
266 |
267 |
268 |
--------------------------------------------------------------------------------
/Ayo/engines/engine_types.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | from typing import Dict, Type, Optional, Any, Union
3 | from dataclasses import dataclass
4 | from Ayo.engines.base_engine import BaseEngine
5 |
6 | class EngineType(str, Enum):
7 | """Supported engine types"""
8 | INPUT = "input"
9 | OUTPUT = "output"
10 | EMBEDDER = "embedder"
11 | VECTOR_DB = "vector_db"
12 | RERANKER = "reranker"
13 | LLM = "llm"
14 | AGGREGATOR = "aggregator"
15 | DUMMY = "dummy"
16 |
17 | @classmethod
18 | def list(cls) -> list:
19 | """Get list of all engine types"""
20 | return list(cls)
21 |
22 | @classmethod
23 | def validate(cls, engine_type: str) -> bool:
24 | """Validate if engine type is supported"""
25 | return engine_type in cls.__members__.values()
26 |
27 | @dataclass
28 | class EngineSpec:
29 | """Engine specifications"""
30 | engine_class: Any # use Any instead of the concrete engine class to avoid circular imports
31 | default_config: Dict
32 | description: str
33 |
34 | class EngineRegistry:
35 | """Central registry for engine types and specifications"""
36 |
37 | def __init__(self):
38 | self._registry: Dict[str, EngineSpec] = {}
39 | self._register_default_engines()
40 |
41 | def _register_default_engines(self):
42 | """Register built-in engines"""
43 | # Lazy import to avoid circular imports
44 | from Ayo.engines.base_engine import BaseEngine
45 | from Ayo.engines.embedder import EmbeddingEngine
46 | from Ayo.engines.vector_db import VectorDBEngine
47 | from Ayo.engines.reranker import RerankerEngine
48 | from Ayo.engines.llm import LLMEngine
49 | from Ayo.engines.aggregator import AggregateEngine
50 |
51 | self.register(
52 | EngineType.INPUT,
53 | EngineSpec(
54 | engine_class=BaseEngine,
55 | default_config={},
56 | description="The dummy engine for input node"
57 | )
58 | )
59 |
60 | self.register(
61 | EngineType.OUTPUT,
62 | EngineSpec(
63 | engine_class=BaseEngine,
64 | default_config={},
65 | description="The dummy engine for output node"
66 | )
67 | )
68 |
69 | self.register(
70 | EngineType.EMBEDDER,
71 | EngineSpec(
72 | engine_class=EmbeddingEngine,
73 | default_config={
74 | "model_name": "BAAI/bge-large-en-v1.5",
75 | "max_batch_size": 1024,
76 | 'vector_dim' :1024
77 | },
78 | description="Text embedding engine using BGE model"
79 | )
80 | )
81 |
82 | self.register(
83 | EngineType.VECTOR_DB,
84 | EngineSpec(
85 | engine_class=VectorDBEngine,
86 | default_config={
87 | "host": "localhost",
88 | "user": "asplos25",
89 | "password": "123456",
90 | "database": "database_asplos",
91 | "port": 5432,
92 | "max_batch_size": 1000
93 | },
94 | description="Vector database engine using pgvector"
95 | )
96 | )
97 |
98 | self.register(
99 | EngineType.RERANKER,
100 | EngineSpec(
101 | engine_class=RerankerEngine,
102 | default_config={
103 | "model_name": "BAAI/bge-reranker-large",
104 | "max_batch_size": 512
105 | },
106 | description="Cross-encoder reranking engine"
107 | )
108 | )
109 |
110 | self.register(
111 | EngineType.LLM,
112 | EngineSpec(
113 | engine_class=LLMEngine,
114 | default_config={
115 | "model_name": "meta-llama/Llama-2-7b-chat-hf",
116 | "tensor_parallel_size": 1,
117 | "max_num_seqs": 256,
118 | "max_queue_size": 1000,
119 | "trust_remote_code": False,
120 | "dtype": "auto"
121 | },
122 | description="Large language model engine"
123 | )
124 | )
125 |
126 | self.register(
127 | EngineType.AGGREGATOR,
128 | EngineSpec(
129 | engine_class=AggregateEngine,
130 | default_config={
131 | "max_batch_size": 32,
132 | "max_queue_size": 1000
133 | },
134 | description="Data Aggregation Engine, supports multiple aggregation modes"
135 | )
136 | )
137 |
138 | def register(self, engine_type: str, spec: EngineSpec) -> None:
139 | """Register a new engine type"""
140 | if engine_type in self._registry:
141 | raise ValueError(f"Engine type {engine_type} already registered")
142 | self._registry[engine_type] = spec
143 |
144 | def unregister(self, engine_type: str) -> None:
145 | """Unregister an engine type"""
146 | if engine_type not in self._registry:
147 | raise ValueError(f"Engine type {engine_type} not registered")
148 | del self._registry[engine_type]
149 |
150 | def get_spec(self, engine_type: str) -> Optional[EngineSpec]:
151 | """Get engine specifications"""
152 | return self._registry.get(engine_type)
153 |
154 | def get_engine_class(self, engine_type: str) -> Optional[Type[BaseEngine]]:
155 | """Get engine class for given type"""
156 | spec = self.get_spec(engine_type)
157 | return spec.engine_class if spec else None
158 |
159 | def get_default_config(self, engine_type: str) -> Optional[Dict]:
160 | """Get default configuration for engine type"""
161 | spec = self.get_spec(engine_type)
162 | return spec.default_config if spec else None
163 |
164 | def list_engines(self) -> Dict[str, str]:
165 | """List all registered engines and their descriptions"""
166 | return {
167 | engine_type: spec.description
168 | for engine_type, spec in self._registry.items()
169 | }
170 |
171 | # Global engine registry instance
172 | ENGINE_REGISTRY = EngineRegistry()
--------------------------------------------------------------------------------
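
A short sketch of how the global registry can be queried; the printed values come straight from the defaults registered above.

```python
# Query the global registry (illustrative usage of the API shown above).
from Ayo.engines.engine_types import ENGINE_REGISTRY, EngineType

print(ENGINE_REGISTRY.list_engines())                      # engine type -> description
print(ENGINE_REGISTRY.get_default_config(EngineType.LLM))  # default LLM config dict
engine_cls = ENGINE_REGISTRY.get_engine_class(EngineType.EMBEDDER)
print(engine_cls)                                          # EmbeddingEngine actor class
```
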
/Ayo/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import sys
4 | import time
5 | from datetime import datetime
6 | from typing import Optional, Dict, Any, Union
7 | import threading
8 |
9 | class AyoLogger:
10 |
11 | # ANSI color codes
12 | COLORS = {
13 | 'DEBUG': '\033[36m', # cyan
14 | 'INFO': '\033[32m', # green
15 | 'WARNING': '\033[33m', # yellow
16 | 'ERROR': '\033[31m', # red
17 | 'CRITICAL': '\033[35m', # purple
18 | 'RESET': '\033[0m' # reset
19 | }
20 |
21 | # singleton instance
22 | _instance = None
23 | _lock = threading.Lock()
24 |
25 | def __new__(cls, *args, **kwargs):
26 | with cls._lock:
27 | if cls._instance is None:
28 | cls._instance = super(AyoLogger, cls).__new__(cls)
29 | cls._instance._initialized = False
30 | return cls._instance
31 |
32 | def __init__(self,
33 | name: str = "ayo",
34 | level: str = "INFO",
35 | log_file: Optional[str] = None,
36 | use_colors: bool = True,
37 | log_format: Optional[str] = None):
38 | """
39 | initialize the logger
40 |
41 | Args:
42 | name: logger name
43 | level: logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
44 | log_file: log file path, if None, only output to console
45 | use_colors: whether to use colors in console output
46 | log_format: custom log format, if None, use the default format
47 | """
48 | if self._initialized:
49 | return
50 |
51 | self.name = name
52 | self.level = getattr(logging, level.upper())
53 | self.log_file = log_file
54 | self.use_colors = use_colors
55 |
56 | # create the logger
57 | self.logger = logging.getLogger(name)
58 | self.logger.setLevel(self.level)
59 | self.logger.handlers = [] # clear the existing handlers
60 |
61 | # set the default log format
62 | if log_format is None:
63 | self.log_format = "%(asctime)s - %(levelname)s - %(name)s - [%(filename)s:%(lineno)d] - %(message)s"
64 | else:
65 | self.log_format = log_format
66 |
67 | # create the formatter
68 | self.formatter = logging.Formatter(self.log_format, datefmt="%Y-%m-%d %H:%M:%S")
69 |
70 | # add the console handler
71 | self._add_console_handler()
72 |
73 | # if the log file is specified, add the file handler
74 | if self.log_file:
75 | self._add_file_handler()
76 |
77 | self._initialized = True
78 |
79 | def _add_console_handler(self):
80 | """add the console handler"""
81 | console_handler = logging.StreamHandler(sys.stdout)
82 | console_handler.setLevel(self.level)
83 |
84 | if self.use_colors:
85 | # use the colored formatter
86 | colored_formatter = self._get_colored_formatter()
87 | console_handler.setFormatter(colored_formatter)
88 | else:
89 | console_handler.setFormatter(self.formatter)
90 |
91 | self.logger.addHandler(console_handler)
92 |
93 | def _add_file_handler(self):
94 | """add the file handler"""
95 | # ensure the log directory exists
96 | log_dir = os.path.dirname(self.log_file)
97 | if log_dir and not os.path.exists(log_dir):
98 | os.makedirs(log_dir)
99 |
100 | file_handler = logging.FileHandler(self.log_file)
101 | file_handler.setLevel(self.level)
102 | file_handler.setFormatter(self.formatter)
103 | self.logger.addHandler(file_handler)
104 |
105 | def _get_colored_formatter(self):
106 | """create the colored formatter"""
107 | class ColoredFormatter(logging.Formatter):
108 | def __init__(self, fmt, datefmt=None):
109 | super().__init__(fmt, datefmt)
110 | self.colors = AyoLogger.COLORS
111 |
112 | def format(self, record):
113 | levelname = record.levelname
114 | if levelname in self.colors:
115 | record.levelname = f"{self.colors[levelname]}{levelname}{self.colors['RESET']}"
116 | record.msg = f"{self.colors[levelname.upper()]}{record.msg}{self.colors['RESET']}"
117 | return super().format(record)
118 |
119 | return ColoredFormatter(self.log_format, datefmt="%Y-%m-%d %H:%M:%S")
120 |
121 | def set_level(self, level: str):
122 | """set the logging level"""
123 | level_upper = level.upper()
124 | if hasattr(logging, level_upper):
125 | self.level = getattr(logging, level_upper)
126 | self.logger.setLevel(self.level)
127 | for handler in self.logger.handlers:
128 | handler.setLevel(self.level)
129 |
130 | def debug(self, msg: str, *args, **kwargs):
131 | """record the DEBUG level log"""
132 | kwargs.setdefault('stacklevel', 2)
133 | self.logger.debug(msg, *args, **kwargs)
134 |
135 | def info(self, msg: str, *args, **kwargs):
136 | """record the INFO level log"""
137 | kwargs.setdefault('stacklevel', 2)
138 | self.logger.info(msg, *args, **kwargs)
139 |
140 | def warning(self, msg: str, *args, **kwargs):
141 | """record the WARNING level log"""
142 | kwargs.setdefault('stacklevel', 2)
143 | self.logger.warning(msg, *args, **kwargs)
144 |
145 | def error(self, msg: str, *args, **kwargs):
146 | """record the ERROR level log"""
147 | kwargs.setdefault('stacklevel', 2)
148 | self.logger.error(msg, *args, **kwargs)
149 |
150 | def critical(self, msg: str, *args, **kwargs):
151 | """record the CRITICAL level log"""
152 | kwargs.setdefault('stacklevel', 2)
153 | self.logger.critical(msg, *args, **kwargs)
154 |
155 | def exception(self, msg: str, *args, **kwargs):
156 | """record the exception information"""
157 | kwargs.setdefault('stacklevel', 2)
158 | self.logger.exception(msg, *args, **kwargs)
159 |
160 | def log_dict(self, level: str, data: Dict[str, Any], prefix: str = ""):
161 | """record the dictionary data"""
162 | level_method = getattr(self.logger, level.lower())
163 | for key, value in data.items():
164 | if prefix:
165 | key = f"{prefix}.{key}"
166 | if isinstance(value, dict):
167 | self.log_dict(level, value, key)
168 | else:
169 | level_method(f"{key}: {value}")
170 |
171 |
172 |
173 |
174 |
175 | def get_logger(name: str = "ayo",
176 | level: str = "INFO",
177 | log_file: Optional[str] = None,
178 | use_colors: bool = True) -> AyoLogger:
179 | """
180 | get the Ayo logger instance
181 |
182 | Args:
183 | name: logger name
184 | level: logging level
185 | log_file: log file path
186 | use_colors: whether to use colored output
187 |
188 | Returns:
189 | AyoLogger instance
190 | """
191 | return AyoLogger(name=name, level=level, log_file=log_file, use_colors=use_colors)
192 |
193 |
194 | GLOBAL_INFO_LEVEL = os.environ.get("AYO_INFO_LEVEL", "INFO")
195 |
196 | default_logger = AyoLogger(level=GLOBAL_INFO_LEVEL)
197 |
198 | debug = default_logger.debug
199 | info = default_logger.info
200 | warning = default_logger.warning
201 | error = default_logger.error
202 | critical = default_logger.critical
203 | exception = default_logger.exception
204 |
--------------------------------------------------------------------------------
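
A brief sketch of the intended usage, mirroring how embedder.py imports these helpers; the log message and dictionary contents are placeholders. Note that AyoLogger is a process-wide singleton, so once the module-level default logger has been created, later `name`/`level` arguments to `get_logger` are ignored.

```python
# Illustrative usage of the logging helpers.
from Ayo.logger import get_logger, GLOBAL_INFO_LEVEL

logger = get_logger(__name__, level=GLOBAL_INFO_LEVEL)
logger.info("engine started")
logger.log_dict("debug", {"engine": {"name": "embedder", "gpus": 1}}, prefix="cfg")
```
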
/Ayo/modules/base_module.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Any
2 | from Ayo.dags.node import Node
3 | from Ayo.dags.node_commons import NodeIOSchema
4 |
5 | class BaseModule:
6 | """
7 | Base class for all modules
8 | """
9 |
10 | def __init__(self,
11 | input_format: Dict[str, Any] = None,
12 | output_format: Dict[str, Any] = None,
13 | config: Dict[str, Any] = None):
14 | """
15 | Initialize the module
16 |
17 | Args:
18 | input_format: Input format definition
19 | output_format: Output format definition
20 | config: Module configuration parameters
21 | """
22 | self.input_format = input_format or {}
23 | self.output_format = output_format or {}
24 | self.config = config or {}
25 |
26 | self.pre_dependencies = []
27 | self.post_dependencies = []
28 |
29 |
30 | def __rshift__(self, other):
31 | self.post_dependencies.append(other)
32 | other.pre_dependencies.append(self)
33 | return other  # return the downstream module so chains like a >> b >> c link consecutively
34 |
35 |
36 | def to_primitive_nodes(self) -> List[Node]:
37 | """
38 | Convert the module to a list of primitive nodes
39 |
40 | Returns:
41 | List[Node]: List of primitive nodes
42 | """
43 | raise NotImplementedError("Subclasses must implement the to_primitive_nodes method")
44 |
45 | def validate_io_schema(self) -> bool:
46 | """
47 | Validate the input and output format
48 |
49 | Returns:
50 | bool: Validation result
51 | """
52 | # Default implementation, subclasses can override this method for more detailed validation
53 | return len(self.input_format) > 0 and len(self.output_format) > 0
54 |
55 |
56 | def __str__(self) -> str:
57 | """Return the string representation of the module"""
58 | return f"{self.__class__.__name__}(input={self.input_format}, output={self.output_format})"
59 |
60 | def __repr__(self) -> str:
61 | """Return the detailed string representation of the module"""
62 | return f"{self.__class__.__name__}(input={self.input_format}, output={self.output_format}, config={self.config})"
--------------------------------------------------------------------------------
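
A compact sketch (hypothetical module name, placeholder op type and schema) of the contract a subclass fulfills and how modules are chained with `>>`; the Node constructor arguments mirror how the concrete modules below use them.

```python
# Hypothetical BaseModule subclass, for illustration only.
from typing import List
from Ayo.dags.node import Node
from Ayo.dags.node_commons import NodeType, NodeOps, NodeIOSchema
from Ayo.engines.engine_types import EngineType
from Ayo.modules.base_module import BaseModule

class PassThroughModule(BaseModule):
    def to_primitive_nodes(self) -> List[Node]:
        return [Node(
            name="PassThrough",
            io_schema=NodeIOSchema(input_format=self.input_format,
                                   output_format=self.output_format),
            op_type=NodeOps.EMBEDDING,        # placeholder op type
            engine_type=EngineType.EMBEDDER,
            node_type=NodeType.COMPUTE,
            config=self.config,
        )]

# Chaining records the pre/post dependencies that transform_mod_to_prim later uses.
a = PassThroughModule(input_format={"text": List[str]}, output_format={"embeddings": list})
b = PassThroughModule(input_format={"embeddings": list}, output_format={"results": list})
a >> b
```
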
/Ayo/modules/embedding.py:
--------------------------------------------------------------------------------
1 | from Ayo.dags.node import Node
2 | from Ayo.dags.node_commons import NodeType, NodeOps, NodeIOSchema
3 | from Ayo.engines.engine_types import EngineType
4 | from Ayo.modules.base_module import BaseModule
5 |
6 |
7 | class EmbeddingModule(BaseModule):
8 | def __init__(self,
9 | input_format: dict={
10 | "text": list[str]
11 | },
12 | output_format: dict={
13 | "embeddings": list
14 | },
15 | config: dict=None):
16 | """Initialize the Embedding Module.
17 |
18 | This module is responsible for converting text into vector embeddings
19 | using an embedding model.
20 |
21 | Args:
22 | input_format (dict): Input format definition, defaults to:
23 | - text (list[str]): List of text strings to be embedded
24 | output_format (dict): Output format definition, defaults to:
25 | - embeddings (list): List of vector embeddings
26 | config (dict, optional): Configuration parameters for the embedding process
27 | """
28 | super().__init__(input_format, output_format, config)
29 |
30 | def to_primitive_nodes(self):
31 | return [
32 | Node(
33 | name="Embedding",
34 | io_schema=NodeIOSchema(
35 | input_format=self.input_format,
36 | output_format=self.output_format
37 | ),
38 | op_type=NodeOps.EMBEDDING,
39 | engine_type=EngineType.EMBEDDER,
40 | node_type=NodeType.COMPUTE,
41 | config=self.config
42 | )
43 | ]
44 |
--------------------------------------------------------------------------------
/Ayo/modules/indexing.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | from Ayo.dags.node import Node
3 | from Ayo.dags.node_commons import NodeType, NodeOps, NodeIOSchema
4 | from Ayo.engines.engine_types import EngineType
5 | from Ayo.modules.base_module import BaseModule
6 |
7 | class IndexingModule(BaseModule):
8 | def __init__(self,
9 | input_format: dict={
10 | "passages": List[str]
11 | },
12 | output_format: dict={
13 | "index_status": bool
14 | },
15 | config: dict=None):
16 | """Initialize the Indexing Module.
17 |
18 | This module is responsible for embedding passages and ingesting them into a vector database.
19 | It creates an index that can be used for vector similarity search.
20 |
21 | Args:
22 | input_format (dict): Input format definition, defaults to:
23 | - passages (List[str]): List of text passages to be indexed
24 | output_format (dict): Output format definition, defaults to:
25 | - index_status (bool): Status indicating whether indexing was successful
26 | config (dict, optional): Configuration parameters for the indexing process
27 | """
28 | super().__init__(input_format, output_format, config)
29 |
30 | def to_primitive_nodes(self):
31 | # create embedding node
32 | embedding_node = Node(
33 | name="EmbeddingForIndex",
34 | io_schema=NodeIOSchema(
35 | input_format={"passages": List[str]},
36 | output_format={"passages_embeddings": List[float]}
37 | ),
38 | op_type=NodeOps.EMBEDDING,
39 | engine_type=EngineType.EMBEDDER,
40 | node_type=NodeType.COMPUTE,
41 | config=self.config
42 | )
43 |
44 | # create ingestion node
45 | ingestion_node = Node(
46 | name="IngestionForIndex",
47 | io_schema=NodeIOSchema(
48 | input_format={"passages": List[str], "passages_embeddings": List[float]},
49 | output_format={"index_status": bool}
50 | ),
51 | op_type=NodeOps.VECTORDB_INGESTION,
52 | engine_type=EngineType.VECTOR_DB,
53 | node_type=NodeType.COMPUTE,
54 | config=self.config
55 | )
56 |
57 | # connect nodes
58 | embedding_node >> ingestion_node
59 |
60 | return [embedding_node, ingestion_node]
61 |
--------------------------------------------------------------------------------
/Ayo/modules/llm_syhthesizing.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | from Ayo.dags.node import Node
3 | from Ayo.dags.node_commons import NodeType, NodeOps, NodeIOSchema
4 | from Ayo.engines.engine_types import EngineType
5 | from Ayo.modules.base_module import BaseModule
6 | from enum import Enum
7 |
8 | class LLMGenerationMode(Enum):
9 | NORMAL = "normal"
10 | SUMMARIZATION = "summarization"
11 | REFINEMENT = "refinement"
12 |
13 | class LLMSynthesizingModule(BaseModule):
14 | RAG_PROMPT_TEMPLATE = """\
15 | You are an AI assistant specialized in Retrieval-Augmented Generation (RAG). Your responses
16 | must be based strictly on the retrieved documents provided to you. Follow these guidelines:
17 | 1. Use Retrieved Information Only - Your responses must rely solely on the retrieved documents.
18 | If the retrieved documents do not contain relevant information, explicitly state: 'Based on the
19 | available information, I cannot determine the answer.'\n
20 | 2. Response Formatting - Directly answer the question using the retrieved data. If multiple
21 | sources provide information, synthesize them in a coherent manner. If no relevant information
22 | is found, clearly state that.\n
23 | 3. Clarity and Precision - Avoid speculative language such as 'I think' or 'It might be.'
24 | Maintain a neutral and factual tone.\n
25 | 4. Information Transparency - Do not fabricate facts or sources. If needed, summarize the
26 | retrieved information concisely.\n
27 | 5. Handling Out-of-Scope Queries - If a question is outside the retrieved data (e.g., opinions,
28 | unverifiable claims), state: 'The retrieved documents do not provide information on this topic.'\n
29 | ---\n
30 | Example Interactions:\n
31 | User Question: Who founded Apple Inc.?\n
32 | Retrieved Context: 'Apple Inc. was co-founded in 1976 by Steve Jobs, Steve Wozniak, and Ronald Wayne.'\n
33 | Model Answer: 'Apple Inc. was co-founded in 1976 by Steve Jobs, Steve Wozniak, and Ronald Wayne.'\n
34 | ---\n
35 | User Question: When was the first iPhone released, and what were its key features?\n
36 | Retrieved Context: 'The first iPhone was announced by Steve Jobs on January 9, 2007, and released on June 29, 2007.'
37 | 'The original iPhone featured a 3.5-inch touchscreen display, a 2-megapixel camera, and ran on iOS.'\n
38 | Model Answer: 'The first iPhone was announced on January 9, 2007, and released on June 29, 2007.
39 | It featured a 3.5-inch touchscreen display, a 2-megapixel camera, and ran on iOS.'\
40 | This ensures accuracy, reliability, and transparency in all responses. You should answer the question directly based on the retrieved context and keep it as concise as possible.
41 | Here is the question: {question}?
42 | Here is the retrieved context: {context}
43 | Here is your answer:
44 | """
45 |
46 | SUMMARIZATION_PROMPT_TEMPLATE = """\
47 | You are an AI assistant specialized in Question Answering. You would be provided with a question and several candidate answers.
48 | Your task is to summarize the candidate answers into a single answer. You should keep the original meaning of the question and the candidate answers.
49 | You should select the most relevant answer from the candidate answers and summarize it. Always keep your answer concise and to the point.
50 | Here is the question: {question}?
51 | Here are the candidate answers: {answers}
52 | Here is your answer:
53 | """
54 |
55 | def __init__(self,
56 | input_format: dict={
57 | "question": str,
58 | "context": List[str]
59 | },
60 | output_format: dict={
61 | "answer": str},
62 | config: dict={
63 | "generation_mode":LLMGenerationMode.NORMAL,
64 | 'prompt_template': RAG_PROMPT_TEMPLATE,
65 | 'parse_json': True,
66 | 'prompt':RAG_PROMPT_TEMPLATE,
67 | 'partial_output': False,
68 | 'partial_prefilling': False,
69 | 'llm_partial_decoding_idx': -1
70 | }):
71 | """Initialize the LLM Synthesizing Module.
72 |
73 | This module is responsible for generating answers using a Large Language Model (LLM)
74 | based on retrieved context. It supports multiple generation modes: normal generation,
75 | summarization, and refinement.
76 |
77 | Args:
78 | input_format (dict): Input format definition, defaults to:
79 | - question (str): User question
80 | - context (List[str]): List of retrieved context documents
81 | output_format (dict): Output format definition, defaults to:
82 | - answer (str): Generated response
83 | config (dict): Configuration parameters, including:
84 | - generation_mode (LLMGenerationMode): Generation mode, can be NORMAL, SUMMARIZATION, or REFINEMENT
85 | - prompt_template (str): Prompt template string
86 | - parse_json (bool): Whether to parse JSON output
87 | - prompt (str): Complete prompt string
88 | - partial_output (bool): Whether to enable partial output
89 | - partial_prefilling (bool): Whether to enable partial prefilling
90 | - llm_partial_decoding_idx (int): Partial decoding index
91 |
92 | Notes:
93 | - Summarization mode (SUMMARIZATION) requires 'context_num' to be specified in config
94 | - Refinement mode (REFINEMENT) also requires 'context_num' to be specified in config
95 | """
96 |
97 | super().__init__(input_format, output_format, config)
98 |
99 | # TODO: support the below generation mode
100 | if config["generation_mode"]==LLMGenerationMode.SUMMARIZATION:
101 | assert 'context_num' in config, "context_num is required for summarization"
102 | elif config["generation_mode"]==LLMGenerationMode.REFINEMENT:
103 | assert 'context_num' in config, "context_num is required for refinement"
104 |
105 |
106 | def to_primitive_nodes(self):
107 | if self.config["generation_mode"]==LLMGenerationMode.NORMAL:
108 | return [
109 | Node(
110 | name="LLMSynthesizingPrefilling",
111 | io_schema=NodeIOSchema(input_format=self.input_format,
112 | output_format=self.output_format),
113 | node_type=NodeType.COMPUTE,
114 | engine_type=EngineType.LLM,
115 | op_type=NodeOps.LLM_PREFILLING,
116 | config=self.config
117 | ),
118 | Node(
119 | name="LLMSynthesizingDecoding",
120 | io_schema=NodeIOSchema(input_format=self.input_format,
121 | output_format=self.output_format),
122 | node_type=NodeType.COMPUTE,
123 | engine_type=EngineType.LLM,
124 | op_type=NodeOps.LLM_DECODING,
125 | config=self.config
126 | )
127 | ]
128 |
129 | elif self.config["generation_mode"]==LLMGenerationMode.SUMMARIZATION:
130 | return [
131 | Node(
132 | name="LLMSynthesizingPrefilling",
133 | )
134 | ]
135 |
136 | elif self.config["generation_mode"]==LLMGenerationMode.REFINEMENT:
137 | return [
138 | Node(
139 | name="LLMSynthesizingPrefilling",
140 | )
141 | ]
142 |
143 |
--------------------------------------------------------------------------------
/Ayo/modules/mod_to_prim.py:
--------------------------------------------------------------------------------
1 | from Ayo.modules.base_module import BaseModule
2 | from Ayo.modules.indexing import IndexingModule
3 | from Ayo.modules.query_expanding import QueryExpandingModule
4 | from Ayo.modules.searching import SearchingModule
5 | from Ayo.modules.reranking import RerankingModule
6 | from typing import List
7 |
8 | def transform_mod_to_prim(mods: List[BaseModule]):
9 | """
10 | Transform a chain of modules to a list of primitive nodes
11 | """
12 | mods_2_nodes={}
13 | for mod in mods:
14 | mods_2_nodes[mod]=mod.to_primitive_nodes()
15 |
16 | for mod in mods:
17 | for post_mod in mod.post_dependencies:
18 | mods_2_nodes[mod][-1]>>mods_2_nodes[post_mod][0]
19 |
20 | node_list=[]
21 | for mod in mods:
22 | node_list.extend(mods_2_nodes[mod])
23 | return node_list
24 |
25 |
26 | if __name__=="__main__":
27 | indexing_module = IndexingModule(input_format={"passages": List[str]}, output_format={"index_status": bool})
28 | query_expanding_module=QueryExpandingModule(input_format={"query": str}, output_format={"expanded_queries": List[str]},config={"expanded_query_num": 3})
29 | searching_module = SearchingModule(input_format={"index_status": bool, "expanded_queries": List[str]}, output_format={"searching_results": List[str]})
30 | reranking_module=RerankingModule(input_format={"searching_results": List[str]}, output_format={"reranking_results": List[str]})
31 |
32 |
33 | indexing_module>>query_expanding_module>>searching_module>>reranking_module
34 |
35 |
36 | node_list=transform_mod_to_prim([indexing_module,query_expanding_module,searching_module,reranking_module])
37 |
38 |
39 | print(node_list)
40 |
41 |
--------------------------------------------------------------------------------
/Ayo/modules/prompt_template.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | QUERY_EXPANDING_PROMPT_TEMPLATE_STRING= """\
4 | Please rewrite the following question into {refine_question_number} more refined ones. \
5 | You should keep the original meaning of the question, but make it more suitable and clear for context retrieval. \
6 | The original question is: {question}? \
7 | Please output your answer in json format. \
8 | It should contain {refine_question_number} new refined questions.\
9 | For example, if the expanded number is 3, the json output should be like this: \
10 | {{\
11 | "revised question1": "[refined question 1]",\
12 | "revised question2": "[refined question 2]",\
13 | "revised question3": "[refined question 3]"\
14 | }}\
15 | You just need to output the json string, do not output any other information or additional text!!! \
16 | The json output:"""
17 |
18 | RAG_QUESTION_ANSWERING_PROMPT_TEMPLATE_STRING="""\
19 | You are an AI assistant specialized in Retrieval-Augmented Generation (RAG). Your responses
20 | must be based strictly on the retrieved documents provided to you. Follow these guidelines:
21 | 1. Use Retrieved Information Only - Your responses must rely solely on the retrieved documents.
22 | If the retrieved documents do not contain relevant information, explicitly state: 'Based on the
23 | available information, I cannot determine the answer.'\n
24 | 2. Response Formatting - Directly answer the question using the retrieved data. If multiple
25 | sources provide information, synthesize them in a coherent manner. If no relevant information
26 | is found, clearly state that.\n
27 | 3. Clarity and Precision - Avoid speculative language such as 'I think' or 'It might be.'
28 | Maintain a neutral and factual tone.\n
29 | 4. Information Transparency - Do not fabricate facts or sources. If needed, summarize the
30 | retrieved information concisely.\n
31 | 5. Handling Out-of-Scope Queries - If a question is outside the retrieved data (e.g., opinions,
32 | unverifiable claims), state: 'The retrieved documents do not provide information on this topic.'\n
33 | ---\n
34 | Example Interactions:\n
35 | User Question: Who founded Apple Inc.?\n
36 | Retrieved Context: 'Apple Inc. was co-founded in 1976 by Steve Jobs, Steve Wozniak, and Ronald Wayne.'\n
37 | Model Answer: 'Apple Inc. was co-founded in 1976 by Steve Jobs, Steve Wozniak, and Ronald Wayne.'\n
38 | ---\n
39 | User Question: When was the first iPhone released, and what were its key features?\n"
40 | Retrieved Context: 'The first iPhone was announced by Steve Jobs on January 9, 2007, and released on June 29, 2007.' "
41 | "'The original iPhone featured a 3.5-inch touchscreen display, a 2-megapixel camera, and ran on iOS.'\n"
42 | Model Answer: 'The first iPhone was announced on January 9, 2007, and released on June 29, 2007. "
43 | "It featured a 3.5-inch touchscreen display, a 2-megapixel camera, and ran on iOS.'\
44 | This ensures accuracy, reliability, and transparency in all responses. And you should directly answer the question based on the retrieved context and keep it concise as possible.
45 | Here is the question: {question}?
46 | Here is the retrieved context: {context}
47 | Here is your answer, make sure it is concise:
48 | """
49 |
50 | def replace_placeholders(prompt_template: str, **kwargs):
51 | for key,value in kwargs.items():
52 | prompt_template = prompt_template.replace(f"{{{key}}}", f"{{{value}}}")
53 | print(prompt_template)
54 | return prompt_template
55 |
56 | #TODO: Currently, these classes are not actually used in the modules and the related payload transformations.
57 | # We should adopt them to make the prompt-template transformation more flexible and reusable.
58 |
59 | class PromptTemplate:
60 | def __init__(self, prompt_template: str):
61 | # the template is a string with placeholders like {key}
62 | self.prompt_template = prompt_template
63 |
64 | self.placeholders = re.findall(r'\{([^{}]+)\}', self.prompt_template)
65 |
66 | def fill_template(self, **kwargs):
67 | raise NotImplementedError("This method should be implemented by the subclass")
68 |
69 |
70 | class QueryExpandingPromptTemplate(PromptTemplate):
71 |
72 | default_template = """\
73 | Please rewrite the following question into {refine_question_number} more refined ones. \
74 | You should keep the original meaning of the question, but make it more suitable and clear for context retrieval. \
75 | The original question is: {question}? \
76 | Please output your answer in json format. \
77 | It should contain {refine_question_number} new refined questions.\
78 | For example, if the expanded number is 3, the json output should be like this: \
79 | {{\
80 | "revised question1": "[refined question 1]",\
81 | "revised question2": "[refined question 2]",\
82 | "revised question3": "[refined question 3]"\
83 | }}\
84 | You just need to output the json string, do not output any other information or additional text!!! \
85 | The json output:"""
86 |
87 |
88 | def __init__(self, prompt_template: str=None):
89 | if prompt_template is None:
90 | prompt_template = self.default_template
91 | super().__init__(prompt_template)
92 |
93 | # special case, here we do not need to check the placeholders
94 | self.placeholders = re.findall(r'\{([^{}]+)\}', self.prompt_template)
95 |
96 | def fill_template(self, **kwargs):
97 | refine_question_number = None
98 | question = None
99 | for key,value in kwargs.items():
100 | if 'num' in key.lower():
101 | refine_question_number = value
102 | elif key.lower() in ['query','question','question_']:
103 | question = value
104 |
105 |
106 | assert refine_question_number is not None, "refine_question_number is required"
107 | assert question is not None, "question is required"
108 |
109 | keys = ", ".join([f"question{i+1}" for i in range(refine_question_number)])
110 | json_example = "{\n " + "\n ".join([f"\"question{i+1}\": \"[refined version {i+1}]\"" + ("," if i < refine_question_number-1 else "") for i in range(refine_question_number)]) + "\n }"
111 |
112 | return self.prompt_template.format(
113 | refine_question_number=refine_question_number,
114 | question=question,
115 | keys=keys,
116 | json_example=json_example
117 | )
118 |
119 |
120 | class RAGQuestionAnsweringPromptTemplate(PromptTemplate):
121 |
122 | default_template="""\
123 | You are an AI assistant specialized in Retrieval-Augmented Generation (RAG). Your responses
124 | must be based strictly on the retrieved documents provided to you. Follow these guidelines:
125 | 1. Use Retrieved Information Only - Your responses must rely solely on the retrieved documents.
126 | If the retrieved documents do not contain relevant information, explicitly state: 'Based on the
127 | available information, I cannot determine the answer.'\n
128 | 2. Response Formatting - Directly answer the question using the retrieved data. If multiple
129 | sources provide information, synthesize them in a coherent manner. If no relevant information
130 | is found, clearly state that.\n
131 | 3. Clarity and Precision - Avoid speculative language such as 'I think' or 'It might be.'
132 | Maintain a neutral and factual tone.\n
133 | 4. Information Transparency - Do not fabricate facts or sources. If needed, summarize the
134 | retrieved information concisely.\n
135 | 5. Handling Out-of-Scope Queries - If a question is outside the retrieved data (e.g., opinions,
136 | unverifiable claims), state: 'The retrieved documents do not provide information on this topic.'\n
137 | ---\n
138 | Example Interactions:\n
139 | User Question: Who founded Apple Inc.?\n
140 | Retrieved Context: 'Apple Inc. was co-founded in 1976 by Steve Jobs, Steve Wozniak, and Ronald Wayne.'\n
141 | Model Answer: 'Apple Inc. was co-founded in 1976 by Steve Jobs, Steve Wozniak, and Ronald Wayne.'\n
142 | ---\n
143 | User Question: When was the first iPhone released, and what were its key features?\n
144 | Retrieved Context: 'The first iPhone was announced by Steve Jobs on January 9, 2007, and released on June 29, 2007.'
145 | 'The original iPhone featured a 3.5-inch touchscreen display, a 2-megapixel camera, and ran on iOS.'\n
146 | Model Answer: 'The first iPhone was announced on January 9, 2007, and released on June 29, 2007.
147 | It featured a 3.5-inch touchscreen display, a 2-megapixel camera, and ran on iOS.'\
148 | This ensures accuracy, reliability, and transparency in all responses. You should directly answer the question based on the retrieved context and keep it as concise as possible.
149 | Here is the question: {question}?
150 | Here is the retrieved context: {context}
151 | Here is your answer:
152 | """
153 | def __init__(self, prompt_template: str=None):
154 |
155 | if prompt_template is None:
156 | prompt_template = self.default_template
157 | super().__init__(prompt_template)
158 |
159 | #TODO: add more prompt templates
160 |
--------------------------------------------------------------------------------
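A minimal usage sketch for the helpers above, assuming the Ayo package from this repo is installed (`pip install -e .`).

```python
from Ayo.modules.prompt_template import (
    QueryExpandingPromptTemplate,
    RAG_QUESTION_ANSWERING_PROMPT_TEMPLATE_STRING,
    replace_placeholders,
)

tpl = QueryExpandingPromptTemplate()
# fill_template() matches kwargs loosely: any key containing "num" supplies the
# count, and "query"/"question" supplies the original question.
print(tpl.fill_template(expanded_query_num=3, question="What is the capital of France?"))

# replace_placeholders() renames placeholders rather than filling them, e.g.
# {question} -> {queries}, so a template can be re-bound to upstream node outputs.
renamed = replace_placeholders(
    RAG_QUESTION_ANSWERING_PROMPT_TEMPLATE_STRING,
    question="queries",
    context="reranked_results",
)
print("{queries}" in renamed)  # True
```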
/Ayo/modules/query_expanding.py:
--------------------------------------------------------------------------------
1 | import uuid
2 | from Ayo.dags.node import Node
3 | from Ayo.dags.node_commons import NodeType, NodeOps, NodeIOSchema
4 | from Ayo.engines.engine_types import EngineType
5 | from typing import List
6 | from Ayo.modules.base_module import BaseModule
7 |
8 | class QueryExpandingModule(BaseModule):
9 |
10 |     prompt_template="""Please rewrite the following question into {refine_question_number} more refined ones. \
11 | You should keep the original meaning of the question, but make it more suitable and clear for context retrieval. \
12 | The original question is: {question}? \
13 | Please output your answer in json format. \
14 | It should contain {refine_question_number} new refined questions.\
15 | For example, if the expanded number is 3, the json output should be like this: \
16 | {{\
17 | "revised question1": "[refined question 1]",\
18 | "revised question2": "[refined question 2]",\
19 | "revised question3": "[refined question 3]"\
20 | }}\
21 | You just need to output the json string, do not output any other information or additional text!!! \
22 | The json output:"""
23 |
24 |
25 | #temperature=0.9, top_p=0.95 for llama2-7b-chat-hf
26 |
27 | # prompt
28 | # support_partial_output = node.config.get('partial_output', False)
29 | # support_partial_prefilling = node.config.get('partial_prefilling', False)
30 |
31 | # llm_partial_decoding_idx = node.config.get("llm_partial_decoding_idx", -1)
32 |
33 | def __init__(self,
34 | input_format: dict={
35 | "query": str
36 | },
37 | output_format: dict={
38 | "expanded_queries": List[str]
39 | },
40 | config: dict={
41 | 'expanded_query_num': 3,
42 | 'prompt_template': prompt_template,
43 | 'parse_json': True,
44 | 'prompt':prompt_template,
45 | 'partial_output': False,
46 | 'partial_prefilling': False,
47 | 'llm_partial_decoding_idx': -1
48 | }):
49 | """Initialize the Query Expanding Module.
50 |
51 | This module is responsible for expanding a single query into multiple refined queries
52 | using a language model, which can improve retrieval performance by capturing different
53 | aspects of the original query.
54 |
55 | Args:
56 | input_format (dict): Input format definition, defaults to:
57 | - query (str): Original user query
58 | output_format (dict): Output format definition, defaults to:
59 | - expanded_queries (List[str]): List of expanded/refined queries
60 | config (dict): Configuration parameters, including:
61 | - expanded_query_num (int): Number of queries to generate
62 | - prompt_template (str): Template for query expansion prompt
63 | - parse_json (bool): Whether to parse JSON output
64 | - prompt (str): Complete prompt string
65 | - partial_output (bool): Whether to enable partial output
66 | - partial_prefilling (bool): Whether to enable partial prefilling
67 | - llm_partial_decoding_idx (int): Partial decoding index
68 | """
69 | super().__init__(input_format, output_format, config)
70 |
71 | def to_primitive_nodes(self):
72 | # create LLM prefilling node
73 |
74 | llm_internal_id = f"query_expanding_{uuid.uuid4()}"
75 |
76 | llm_prefilling_node = Node(
77 | name="QueryExpandingPrefilling",
78 | io_schema=NodeIOSchema(
79 | input_format={"query": str},
80 | output_format={"prefill_state": dict}
81 | ),
82 | op_type=NodeOps.LLM_PREFILLING,
83 | engine_type=EngineType.LLM,
84 | node_type=NodeType.COMPUTE,
85 | config={
86 | 'prompt_template': self.config.get('prompt_template', self.prompt_template),
87 | 'prompt': self.config.get('prompt', self.prompt_template),
88 | 'expanded_query_num': self.config.get('expanded_query_num', 3),
89 | 'parse_json': self.config.get('parse_json', True),
90 | 'partial_output': self.config.get('partial_output', False),
91 | 'partial_prefilling': self.config.get('partial_prefilling', False),
92 | 'llm_partial_decoding_idx': self.config.get('llm_partial_decoding_idx', -1),
93 | 'llm_internal_id': llm_internal_id
94 | }
95 | )
96 |
97 | # create LLM decoding node
98 | llm_decoding_node = Node(
99 | name="QueryExpandingDecoding",
100 | io_schema=NodeIOSchema(
101 | input_format={"query": str, "prefill_state": dict},
102 | output_format={"expanded_queries": List[str]}
103 | ),
104 | op_type=NodeOps.LLM_DECODING,
105 | engine_type=EngineType.LLM,
106 | node_type=NodeType.COMPUTE,
107 | config={
108 | 'prompt_template': self.config.get('prompt_template', self.prompt_template),
109 | 'prompt': self.config.get('prompt', self.prompt_template),
110 | 'expanded_query_num': self.config.get('expanded_query_num', 3),
111 | 'parse_json': self.config.get('parse_json', True),
112 | 'partial_output': self.config.get('partial_output', False),
113 | 'partial_prefilling': self.config.get('partial_prefilling', False),
114 | 'llm_partial_decoding_idx': self.config.get('llm_partial_decoding_idx', -1),
115 | 'llm_internal_id': llm_internal_id
116 | }
117 | )
118 |
119 | # Connect the nodes
120 | llm_prefilling_node >> llm_decoding_node
121 |
122 | return [llm_prefilling_node, llm_decoding_node]
123 |
124 | def format_prompt(self, question):
125 | refine_question_number = self.config.get('expanded_query_num', None)
126 |
127 | assert refine_question_number is not None, "expanded_query_num is not set"
128 | #keys = ", ".join([f"question{i+1}" for i in range(refine_question_number)])
129 | #json_example = "{\n " + "\n ".join([f"\"question{i+1}\": \"[refined version {i+1}]\"" + ("," if i < refine_question_number-1 else "") for i in range(refine_question_number)]) + "\n }"
130 |
131 | return self.prompt_template.format(
132 | refine_question_number=refine_question_number,
133 | question=question,
134 | )
135 |
136 |
137 | if __name__ == "__main__":
138 | query_expanding_module = QueryExpandingModule()
139 | question = "What is the capital of France?"
140 | print(query_expanding_module.format_prompt(question))
--------------------------------------------------------------------------------
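A short sketch of the decomposition above, assuming Ayo and its runtime dependencies are installed. The `name` and `config` attribute accesses mirror the `Node` constructor arguments and are an assumption here.

```python
from Ayo.modules.query_expanding import QueryExpandingModule

mod = QueryExpandingModule()
prefill_node, decode_node = mod.to_primitive_nodes()

# Attribute names below mirror the Node constructor arguments (assumption).
print(prefill_node.name, "->", decode_node.name)
# Both primitives carry the same llm_internal_id, which is how the decoding
# request is matched back to its prefilled state in the LLM engine.
print(prefill_node.config["llm_internal_id"] == decode_node.config["llm_internal_id"])
```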
/Ayo/modules/reranking.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | from Ayo.dags.node import Node
3 | from Ayo.dags.node_commons import NodeType, NodeOps,NodeIOSchema
4 | from Ayo.engines.engine_types import EngineType
5 | from Ayo.modules.base_module import BaseModule
6 |
7 | class RerankingModule(BaseModule):
8 | def __init__(self,
9 | input_format: dict={
10 | "query": str,
11 | "passages": List[str]
12 | },
13 | output_format: dict={
14 | "passages": List[str]
15 | },
16 | config: dict={
17 | 'top_k': 10
18 | }):
19 | """Initialize the Reranking Module.
20 |
21 | This module is responsible for reranking the passages based on the query.
22 |
23 | Args:
24 | input_format (dict): Input format definition, defaults to:
25 | - query (str): The query to rerank the passages
26 | - passages (List[str]): The passages to rerank
27 | output_format (dict): Output format definition, defaults to:
28 | - passages (List[str]): The reranked passages
29 | config (dict, optional): Configuration parameters for the reranking process
30 | """
31 |
32 | super().__init__(input_format, output_format, config)
33 |
34 | def to_primitive_nodes(self):
35 | return [
36 | Node(
37 | name="Reranking",
38 | io_schema=NodeIOSchema(
39 | input_format=self.input_format,
40 | output_format=self.output_format
41 | ),
42 | op_type=NodeOps.RERANKING,
43 | engine_type=EngineType.RERANKER,
44 | node_type=NodeType.COMPUTE,
45 | config=self.config
46 | )
47 | ]
48 |
49 |
50 |
--------------------------------------------------------------------------------
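A usage sketch, assuming Ayo is installed; the reranking module maps to a single reranking primitive node.

```python
from Ayo.modules.reranking import RerankingModule

mod = RerankingModule(config={"top_k": 5})
(rerank_node,) = mod.to_primitive_nodes()
# `name` mirrors the Node constructor argument (assumption).
print(rerank_node.name)  # Reranking
```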
/Ayo/modules/searching.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | from Ayo.dags.node import Node
3 | from Ayo.dags.node_commons import NodeType, NodeOps, NodeIOSchema
4 | from Ayo.engines.engine_types import EngineType
5 | from Ayo.modules.base_module import BaseModule
6 |
7 | class SearchingModule(BaseModule):
8 | '''
9 |     The searching module is used to calculate the embeddings for the queries and search the vector database for the most relevant documents.
10 | '''
11 | def __init__(self,
12 | input_format: dict={
13 | "queries": List[str],
14 | 'index_status': bool
15 | },
16 | output_format: dict={
17 | 'search_results': List[str]
18 | },
19 | config: dict={
20 | 'top_k': 3
21 | }):
22 |
23 | """Initialize the Searching Module.
24 |
25 | This module is responsible for calculating embeddings for queries and searching
26 | the vector database for the most relevant documents based on embedding similarity.
27 |
28 | Args:
29 | input_format (dict): Input format definition, defaults to:
30 | - queries (List[str]): List of queries to search for
31 | - index_status (bool): Status indicating whether the index is ready
32 | output_format (dict): Output format definition, defaults to:
33 | - search_results (List[str]): List of retrieved documents
34 | config (dict): Configuration parameters, including:
35 | - top_k (int): Number of top results to return per query
36 | """
37 | super().__init__(input_format, output_format, config)
38 |
39 | def to_primitive_nodes(self):
40 |
41 | query_embedd_input_key=None
42 | index_input_key=None
43 | for key in self.input_format.keys():
44 | if 'query' in key or 'queries' in key or 'expanded_queries' in key or 'passages' in key or 'expanded_passages' in key or 'expanded_query' in key:
45 | query_embedd_input_key=key
46 |
47 | elif 'index' in key or 'index_status' in key:
48 | index_input_key=key
49 |
50 |
51 | query_embedding_node = Node(
52 | name="QueryEmbedding",
53 | io_schema=NodeIOSchema(
54 | input_format={query_embedd_input_key: self.input_format[query_embedd_input_key]},
55 | output_format={"queries_embeddings": List[float]}
56 | ),
57 | op_type=NodeOps.EMBEDDING,
58 | engine_type=EngineType.EMBEDDER,
59 | node_type=NodeType.COMPUTE,
60 | config={}
61 | )
62 |
63 | vector_db_searching_node = Node(
64 | name="VectorDBSearching",
65 | io_schema=NodeIOSchema(
66 | input_format={
67 | "queries_embeddings": List[float],
68 | index_input_key: self.input_format[index_input_key]
69 | },
70 | output_format=self.output_format
71 | ),
72 | op_type=NodeOps.VECTORDB_SEARCHING,
73 | engine_type=EngineType.VECTOR_DB,
74 | node_type=NodeType.COMPUTE,
75 | config={
76 | "top_k": self.config["top_k"],
77 | }
78 | )
79 |
80 | query_embedding_node >> vector_db_searching_node
81 |
82 | return [query_embedding_node, vector_db_searching_node]
83 |
84 |
85 |
--------------------------------------------------------------------------------
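A sketch of the two-stage decomposition above (query embedding, then vector-DB search), assuming Ayo is installed; the `name` attribute access mirrors the `Node` constructor argument and is an assumption.

```python
from typing import List
from Ayo.modules.searching import SearchingModule

mod = SearchingModule(
    input_format={"expanded_queries": List[str], "index_status": bool},
    output_format={"searching_results": List[str]},
    config={"top_k": 3},
)
embed_node, search_node = mod.to_primitive_nodes()
print(embed_node.name, "->", search_node.name)  # QueryEmbedding -> VectorDBSearching
```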
/Ayo/opt_pass/base_pass.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Dict, List, Optional, Any, TYPE_CHECKING
3 | from Ayo.logger import get_logger, GLOBAL_INFO_LEVEL
4 |
5 | logger = get_logger(__name__, level=GLOBAL_INFO_LEVEL)
6 |
7 | if TYPE_CHECKING:
8 | from Ayo.dags.dag import DAG
9 | from Ayo.dags.node import Node
10 |
11 | class OPT_Pass(ABC):
12 | """Base class for all optimization passes
13 |
14 | An optimization pass takes a DAG as input, performs specific optimizations,
15 | and returns the optimized DAG. Each pass should focus on a specific type
16 | of optimization (e.g., pruning dependencies, batching, splitting).
17 | """
18 |
19 | def __init__(self, name: str):
20 | """Initialize the optimization pass
21 |
22 | Args:
23 | name: Unique identifier for this optimization pass
24 | """
25 | self.name = name
26 | self.enabled = True
27 | self.config: Dict[str, Any] = {}
28 |
29 | @abstractmethod
30 | def run(self, dag: 'DAG') -> 'DAG':
31 | """Execute the optimization pass on the given DAG
32 |
33 | Args:
34 | dag: Input DAG to optimize
35 |
36 | Returns:
37 | Optimized DAG
38 |
39 | This method must be implemented by all concrete optimization passes.
40 | """
41 | pass
42 |
43 | def configure(self, **kwargs) -> None:
44 | """Configure the optimization pass
45 |
46 | Args:
47 | **kwargs: Configuration parameters specific to this pass
48 | """
49 | self.config.update(kwargs)
50 |
51 | def enable(self) -> None:
52 | """Enable this optimization pass"""
53 | self.enabled = True
54 |
55 | def disable(self) -> None:
56 | """Disable this optimization pass"""
57 | self.enabled = False
58 |
59 | def is_enabled(self) -> bool:
60 | """Check if this pass is enabled"""
61 | return self.enabled
62 |
63 | def get_config(self, key: str, default: Any = None) -> Any:
64 | """Get configuration value
65 |
66 | Args:
67 | key: Configuration key
68 | default: Default value if key doesn't exist
69 | """
70 | return self.config.get(key, default)
71 |
72 | def validate_dag(self, dag: 'DAG') -> bool:
73 | """Validate DAG before optimization
74 |
75 | Args:
76 | dag: DAG to validate
77 |
78 | Returns:
79 | True if DAG is valid for this optimization
80 | """
81 | return True
82 |
83 | def get_applicable_nodes(self, dag: 'DAG') -> List['Node']:
84 | """Get nodes that this pass can optimize
85 |
86 | Args:
87 | dag: Input DAG
88 |
89 | Returns:
90 | List of nodes that can be optimized by this pass
91 | """
92 | return []
93 |
94 | def log_optimization(self, message: str) -> None:
95 | """Log optimization information
96 |
97 | Args:
98 | message: Message to log
99 | """
100 | logger.info(f"[{self.name}] {message}")
101 |
102 | def __str__(self) -> str:
103 | return f"OPT_Pass(name={self.name}, enabled={self.enabled})"
104 |
105 | def __repr__(self) -> str:
106 | return self.__str__()
107 |
--------------------------------------------------------------------------------
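A minimal custom pass built on the `OPT_Pass` interface above; the pass body is illustrative only and returns the DAG unchanged.

```python
from Ayo.opt_pass.base_pass import OPT_Pass

class NoOpPass(OPT_Pass):
    """Illustrative pass: logs a message and leaves the DAG as-is."""

    def __init__(self):
        super().__init__(name="NoOpPass")

    def run(self, dag):
        self.log_optimization(f"visited {len(dag.nodes)} nodes, no changes made")
        return dag

print(NoOpPass())  # OPT_Pass(name=NoOpPass, enabled=True)
```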
/Ayo/opt_pass/pass_manager.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Optional
2 | from Ayo.opt_pass.base_pass import OPT_Pass
3 | from Ayo.dags.node import Node
4 |
5 | class PassManager:
6 | """Manager for handling optimization passes
7 |
8 | Features:
9 | - Pass registration and management
10 | - Pass selection based on node properties
11 | """
12 |
13 | def __init__(self):
14 | self.passes: Dict[str, OPT_Pass] = {}
15 |
16 | def register_pass(self, opt_pass: OPT_Pass) -> None:
17 | """Register an optimization pass
18 |
19 | Args:
20 | opt_pass: Optimization pass to register
21 | """
22 | self.passes[opt_pass.name] = opt_pass
23 |
24 | def get_passes(self) -> List[OPT_Pass]:
25 | """Get all registered passes
26 |
27 | Returns:
28 | List of registered optimization passes
29 | """
30 | return list(self.passes.values())
31 |
32 | def get_enabled_passes(self) -> List[OPT_Pass]:
33 | """Get all enabled passes
34 |
35 | Returns:
36 | List of enabled optimization passes
37 | """
38 | return [p for p in self.passes.values() if p.is_enabled()]
39 |
40 | def get_pass(self, name: str) -> Optional[OPT_Pass]:
41 | """Get a specific pass by name
42 |
43 | Args:
44 | name: Name of the pass to retrieve
45 |
46 | Returns:
47 | The requested pass or None if not found
48 | """
49 | return self.passes.get(name)
50 |
51 | def enable_pass(self, name: str) -> None:
52 | """Enable a specific pass
53 |
54 | Args:
55 | name: Name of the pass to enable
56 | """
57 | if name in self.passes:
58 | self.passes[name].enable()
59 |
60 | def disable_pass(self, name: str) -> None:
61 | """Disable a specific pass
62 |
63 | Args:
64 | name: Name of the pass to disable
65 | """
66 | if name in self.passes:
67 | self.passes[name].disable()
--------------------------------------------------------------------------------
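A usage sketch for `PassManager`, reusing the pruning pass defined in `pruning_dependency.py` below.

```python
from Ayo.opt_pass.pass_manager import PassManager
from Ayo.opt_pass.pruning_dependency import PruningDependencyPass

pm = PassManager()
pm.register_pass(PruningDependencyPass())

pm.disable_pass("PruningDependencyPass")
print([str(p) for p in pm.get_passes()])          # enabled=False

pm.enable_pass("PruningDependencyPass")
print([p.name for p in pm.get_enabled_passes()])  # ['PruningDependencyPass']
```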
/Ayo/opt_pass/pruning_dependency.py:
--------------------------------------------------------------------------------
1 | from typing import List, TYPE_CHECKING
2 | from Ayo.dags.node_commons import NodeType
3 | from Ayo.opt_pass.base_pass import OPT_Pass
4 | from Ayo.logger import get_logger, GLOBAL_INFO_LEVEL
5 |
6 | logger = get_logger(__name__, level=GLOBAL_INFO_LEVEL)
7 |
8 | if TYPE_CHECKING:
9 | from Ayo.dags.dag import DAG
10 | from Ayo.dags.node import Node
11 |
12 | class PruningDependencyPass(OPT_Pass):
13 | """Optimization Pass: Clean up invalid dependencies in DAG"""
14 |
15 | def __init__(self):
16 | super().__init__(name="PruningDependencyPass")
17 |
18 | def run(self, dag: 'DAG') -> 'DAG':
19 | """Execute the optimization pass"""
20 | # save the reference of DAG
21 | self.dag = dag
22 |
23 | # Ensure topological sort is up to date
24 | dag._ensure_topo_sort()
25 |
26 | # Check and add missing parent node connections
27 | self._add_missing_connections(dag)
28 |
29 | # Identify essential parents for each node based on topological structure
30 | self._identify_essential_parents(dag)
31 |
32 | # Finally prune invalid connections
33 | for node in dag.nodes:
34 | self._prune_node_dependencies(node)
35 |
36 | # clear the reference of DAG, avoid memory leak
37 | self.dag = None
38 |
39 | return dag
40 |
41 | def get_applicable_nodes(self, dag: 'DAG') -> List['Node']:
42 | """Get nodes that can be pruned"""
43 | return [node for node in dag.nodes if len(node.parents) > 0]
44 |
45 | def validate_dag(self, dag: 'DAG') -> bool:
46 | """Validate if DAG can be pruned"""
47 | return len(dag.nodes) > 0
48 |
49 | def _prune_node_dependencies(self, node: 'Node') -> None:
50 | """Clean up invalid dependencies for a single node"""
51 | invalid_parents = []
52 |
53 | for parent in node.parents:
54 | if not self._has_valid_connection(parent, node):
55 | invalid_parents.append(parent)
56 |
57 | for parent in invalid_parents:
58 | self._remove_connection(parent, node)
59 |
60 | def _has_valid_connection(self, parent: 'Node', child: 'Node') -> bool:
61 | """Check if there is valid data flow between parent and child nodes"""
62 | if parent.name in child.input_key_from_parents:
63 | output_key = child.input_key_from_parents[parent.name]
64 | if (output_key in parent.output_names and
65 | output_key in child.input_names):
66 | return True
67 | return False
68 |
69 | def _remove_connection(self, parent: 'Node', child: 'Node') -> None:
70 | """Remove connection between two nodes"""
71 | if child in parent.children:
72 | parent.children.remove(child)
73 |
74 | if parent in child.parents:
75 | child.parents.remove(parent)
76 |
77 | if parent.name in child.input_key_from_parents:
78 | del child.input_key_from_parents[parent.name]
79 |
80 | # update the in_degree information of DAG
81 | if hasattr(self, 'dag'):
82 | if child in self.dag.in_degree:
83 | self.dag.in_degree[child] -= 1
84 | # mark the topological sort need to be recalculated
85 | self.dag._mark_topo_dirty()
86 |
87 | def _add_missing_connections(self, dag: 'DAG') -> None:
88 | """Check and add missing parent node connections"""
89 | # preprocess: create the mapping from output name to node
90 | output_providers = {}
91 | for node in dag.topo_list:
92 | for output_name in node.output_names:
93 | if output_name not in output_providers:
94 | output_providers[output_name] = []
95 | output_providers[output_name].append(node)
96 |
97 | # process the nodes in topological order
98 | for node in dag.topo_list:
99 | if node.node_type == NodeType.INPUT:
100 | continue
101 |
102 | # check each input needed by the node
103 | for input_name in node.input_names:
104 | # if this input has no provider
105 | # here we assume the output name from different nodes are unique
106 | logger.info(f"{node.name}, {input_name}, {node.input_key_from_parents.values()}")
107 | if (input_name not in node.input_key_from_parents.values()):
108 | logger.warning(f"input_name: {input_name} not in node.input_key_from_parents.values(): {node.input_key_from_parents.values()}")
109 | # find the node that can provide this input
110 | potential_parents = output_providers.get(input_name, [])
111 | for potential_parent in potential_parents:
112 | if (potential_parent != node and
113 | potential_parent not in node.parents):
114 | # add the connection
115 | node.add_parent(potential_parent)
116 | node.input_key_from_parents[potential_parent.name] = input_name
117 | break
118 |
119 | def _identify_essential_parents(self, dag: 'DAG') -> None:
120 | """Identify essential parents for each node based on topological structure"""
121 | # process the nodes in topological order
122 | for node in dag.topo_list:
123 | if node.node_type == NodeType.INPUT or len(node.parents) <= 1:
124 | continue
125 |
126 | self._find_essential_parents(node)
127 |
128 | def _find_essential_parents(self, node: 'Node') -> None:
129 | """Find essential parents for a node, removing redundant parents"""
130 | # Record each input provider
131 | input_providers = {} # input name -> best parent node
132 | redundant_parents = []
133 |
134 | # ensure we know each parent's input
135 | for parent in node.parents:
136 | if parent.name in node.input_key_from_parents:
137 | output_key = node.input_key_from_parents[parent.name]
138 |
139 | # if this input has no provider, record current parent
140 | if output_key not in input_providers:
141 | input_providers[output_key] = parent
142 | else:
143 | # there is already a node providing this input
144 | # here we can implement the logic to select the best parent node, for example:
145 | # 1. select the node with earlier topological order (may result in an earlier availability)
146 | # 2. select the node based on other criteria
147 | # default: keep the first encountered parent node
148 | redundant_parents.append(parent)
149 |
150 | # remove the redundant parent node connections
151 | for parent in redundant_parents:
152 | self._remove_connection(parent, node)
153 |
154 |
155 | if __name__ == "__main__":
156 | from Ayo.dags.dag import DAG
157 | from Ayo.dags.node import Node
158 | from Ayo.dags.node_commons import NodeType, NodeIOSchema, NodeAnnotation
159 | from Ayo.engines.engine_types import EngineType
160 | from typing import Any, Dict
161 |
162 | dag = DAG()
163 |
164 | dag.set_query_inputs({"query": "What is the capital of France?", "passages": ["Paris is the capital of France.", "Paris is the capital of France.", "Paris is the capital of France."]})
165 |
166 | embedding_node = Node(
167 | name="Embedding",
168 | node_type=NodeType.COMPUTE,
169 | engine_type=EngineType.EMBEDDER,
170 | io_schema=NodeIOSchema(
171 | input_format={"passages": List[str]},
172 | output_format={"embeddings_passages": List[Any]}
173 | ),
174 | anno=NodeAnnotation.BATCHABLE,
175 | config={
176 | }
177 | )
178 |
179 | reranker_node = Node(
180 | name="Reranker",
181 | node_type=NodeType.COMPUTE,
182 | engine_type=EngineType.RERANKER,
183 | io_schema=NodeIOSchema(
184 | input_format={
185 | "query": str,
186 | "passages": List[str]
187 | },
188 | output_format={
189 | "ranked_results": List[str]
190 | }
191 | ),
192 | anno=NodeAnnotation.BATCHABLE,
193 | config={
194 | }
195 | )
196 |
197 | llm_node = Node(
198 | name="LLM",
199 | node_type=NodeType.COMPUTE,
200 | engine_type=EngineType.LLM,
201 | io_schema=NodeIOSchema(
202 | input_format={"query": str, "ranked_results": List[str]},
203 | output_format={"answer": str}
204 | ),
205 | anno=NodeAnnotation.BATCHABLE,
206 | config={
207 | }
208 | )
209 |
210 | embedding_node>>reranker_node >> llm_node
211 |
212 | dag.register_nodes(embedding_node, reranker_node, llm_node)
213 |
214 | print(dag.get_full_dag_nodes_info())
215 |
216 | dag.optimize([PruningDependencyPass()])
217 |
218 |
219 | print(dag.get_full_dag_nodes_info())
220 |
221 |
--------------------------------------------------------------------------------
/Ayo/opt_pass/test_dag_llm_decoding_pipeling.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NetX-lab/Ayo/61607bcb31ed98e1101c5d60d8573105bf60d87c/Ayo/opt_pass/test_dag_llm_decoding_pipeling.pdf
--------------------------------------------------------------------------------
/Ayo/opt_pass/test_dag_prefilling_split.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NetX-lab/Ayo/61607bcb31ed98e1101c5d60d8573105bf60d87c/Ayo/opt_pass/test_dag_prefilling_split.pdf
--------------------------------------------------------------------------------
/Ayo/queries/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NetX-lab/Ayo/61607bcb31ed98e1101c5d60d8573105bf60d87c/Ayo/queries/__init__.py
--------------------------------------------------------------------------------
/Ayo/queries/query.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List, Optional
2 | from enum import Enum
3 | import time
4 | from dataclasses import dataclass, field
5 | import ray
6 | from Ayo.dags.dag import DAG
7 | from Ayo.queries.query_state import QueryStates, QueryStatus
8 |
9 |
10 | class Query:
11 | """Base class for all query types in the system"""
12 | def __init__(
13 | self,
14 | uuid: str,
15 | query_id: str,
16 | query_inputs: Dict[str, Any],
17 | DAG: DAG,
18 | context: Optional[Dict] = None,
19 | uploaded_file: Optional[Any] = None,
20 | timeout: float = 30.0
21 |         # some of these attributes may be unused; clean them up in the future
22 | ):
23 | # Basic query information
24 | self.uuid = uuid
25 | self.query_id = query_id
26 | self.query_inputs = query_inputs
27 | self.context = context or {}
28 | self.DAG = DAG
29 | self.uploaded_file = uploaded_file
30 |
31 | # Initialize DAG information
32 | self._init_dag()
33 |
34 | # Query state management
35 | self.status = QueryStatus.INIT
36 | self.query_state = QueryStates.remote()
37 | self.created_at = time.time()
38 | self.updated_at = self.created_at
39 | self.start_time: Optional[float] = None
40 | self.end_time: Optional[float] = None
41 | self.timeout = timeout
42 | self.error_message: Optional[str] = None
43 |
44 | # Results storage
45 | self.results: Dict[str, Any] = {}
46 | self.metadata: Dict[str, Any] = {}
47 |
48 | def _init_dag(self) -> None:
49 | """Initialize DAG query related information"""
50 | self.DAG.query_id = self.query_id
51 | self.DAG.set_query_inputs(self.query_inputs)
52 | self.DAG.create_input_nodes()
53 |
54 | def start(self):
55 | """Start query execution"""
56 | self.status = QueryStatus.RUNNING
57 | self.start_time = time.time()
58 | self.updated_at = self.start_time
59 |
60 | def complete(self):
61 | """Mark query as completed"""
62 | self.status = QueryStatus.COMPLETED
63 | self.end_time = time.time()
64 | self.updated_at = self.end_time
65 |
66 | def fail(self, error_message: str):
67 | """Mark query as failed"""
68 | self.status = QueryStatus.FAILED
69 | self.error_message = error_message
70 | self.end_time = time.time()
71 | self.updated_at = self.end_time
72 |
73 | def set_timeout(self):
74 | """Mark query as timed out"""
75 | self.status = QueryStatus.TIMEOUT
76 | self.error_message = "Query execution exceeded timeout"
77 | self.end_time = time.time()
78 | self.updated_at = self.end_time
79 |
80 | def is_timeout(self) -> bool:
81 | """Check if query has exceeded timeout"""
82 | if self.start_time is None:
83 | return False
84 | return (time.time() - self.start_time) > self.timeout
85 |
86 | def get_execution_time(self) -> Optional[float]:
87 | """Get query execution time in seconds"""
88 | if self.start_time is None:
89 | return None
90 | end = self.end_time or time.time()
91 | return end - self.start_time
92 |
93 | def _get_obj_name_recurse(self, name, obj):
94 | """Helper method for attribute access"""
95 | name = name.split(".", maxsplit=1)
96 | recurse = len(name) > 1
97 | next_name = name[1] if recurse else ""
98 | name = name[0]
99 | obj = self if obj is None else obj
100 | return obj, name, next_name, recurse
101 |
102 | def get_status(self) -> QueryStatus:
103 | """Get query status"""
104 | return self.status
105 |
106 | def get_remote_attr(self, __name: str, __obj: object = None):
107 | """Get remote attribute value"""
108 | obj, name, next_name, recurse = self._get_obj_name_recurse(__name, __obj)
109 | next_obj = getattr(obj, name)
110 | if recurse:
111 | next_obj = self.get_remote_attr(next_name, next_obj)
112 | return next_obj
113 |
114 | def set_remote_attr(self, __name: str, __value: Any, __obj: object = None):
115 | """Set remote attribute value"""
116 | obj, name, next_name, recurse = self._get_obj_name_recurse(__name, __obj)
117 | if recurse:
118 | next_obj = getattr(obj, name)
119 | self.set_remote_attr(next_name, __value, next_obj)
120 | else:
121 | if hasattr(obj, name):
122 | setattr(obj, name, __value)
123 |
124 | def to_dict(self) -> Dict[str, Any]:
125 | """Convert query to dictionary representation"""
126 | return {
127 | "uuid": self.uuid,
128 | "query_id": self.query_id,
129 | "query_inputs": self.query_inputs,
130 | "status": self.status.value,
131 | "created_at": self.created_at,
132 | "updated_at": self.updated_at,
133 | "execution_time": self.get_execution_time(),
134 | "error_message": self.error_message,
135 | "results": self.results,
136 | "context": self.context,
137 | "metadata": self.metadata
138 | }
139 |
140 | def __str__(self):
141 | return f"Query(uuid={self.uuid}, status={self.status.value}, query_id='{self.query_id}', query_inputs={self.query_inputs})"
142 |
143 | def __repr__(self):
144 | return self.__str__()
145 |
--------------------------------------------------------------------------------
/Ayo/queries/query_state.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Any, Optional
2 | import time
3 | import ray
4 | from enum import Enum
5 |
6 | class QueryStatus(Enum):
7 | """Query execution states"""
8 | INIT = "init" # Query is initialized
9 | PENDING = "pending" # Query is waiting to be processed
10 | RUNNING = "running" # Query is being processed
11 | COMPLETED = "completed"# Query completed successfully
12 | FAILED = "failed" # Query failed during execution
13 | TIMEOUT = "timeout" # Query exceeded time limit
14 |
15 | @ray.remote
16 | class QueryStates:
17 | """Ray Actor for managing query states and intermediate results
18 |
19 | Handles both:
20 | 1. Global variables and intermediate results for DAG nodes
21 | 2. Service-specific results and query states
22 | """
23 | def __init__(self):
24 | # For storing node-specific variables and results
25 | self.global_var_idx = {}
26 |
27 | # For storing query states and metadata
28 | self.states: Dict[str, Dict[str, Any]] = {}
29 |
30 | # For storing node-specific results
31 | self.node_results: Dict[str, Any] = {}
32 |
33 | def set_global_var(self, var, node_name):
34 | """Set global variable for a node"""
35 | self.global_var_idx[node_name] = var
36 |
37 | def get_global_var(self, node_name):
38 | """Get global variable for a node"""
39 | if node_name not in self.global_var_idx:
40 | return None
41 | return self.global_var_idx[node_name]
42 |
43 | def get_node_results(self):
44 | """Get all node results"""
45 | return self.node_results
46 |
47 | def set_node_result(self, node_name: str, result):
48 | """Set the result of a specific node"""
49 | if node_name not in self.node_results:
50 | self.node_results[node_name] = {}
51 | self.node_results[node_name] = result
52 |
53 | def get_node_result(self, node_name: str):
54 | """Get the result of a specific node"""
55 | if node_name not in self.node_results:
56 | return None
57 | return self.node_results[node_name]
58 |
59 | def clear_node_result(self, node_name: str):
60 | """Clear the result of a specific node"""
61 | if node_name in self.node_results:
62 | self.node_results.pop(node_name, None)
63 |
64 | def clear_query(self, query_id: str):
65 | """Clear all data related to a query"""
66 | # Clear query state
67 | if query_id in self.states:
68 | del self.states[query_id]
69 |
70 | # Clear any service results
71 |         for service_results in getattr(self, "service_results", {}).values():
72 |             service_results.pop(query_id, None)
--------------------------------------------------------------------------------
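A sketch of the `QueryStates` actor in use; it assumes a local Ray runtime is available.

```python
import ray
from Ayo.queries.query_state import QueryStates

ray.init(ignore_reinit_error=True)

states = QueryStates.remote()
# Store and fetch a per-node result through the actor.
ray.get(states.set_node_result.remote("Reranking", ["doc-3", "doc-1"]))
print(ray.get(states.get_node_result.remote("Reranking")))  # ['doc-3', 'doc-1']
print(ray.get(states.get_node_result.remote("Missing")))    # None
```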
/Ayo/schedulers/engine_scheduler.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Optional, Any
2 | import ray
3 | import asyncio
4 | from abc import ABC, abstractmethod
5 | from dataclasses import dataclass, field
6 | import time
7 | from collections import deque
8 | from Ayo.queries.query import Query
9 | from Ayo.configs.config import EngineConfig
10 | from Ayo.dags.node import NodeOps
11 | from Ayo.logger import get_logger, GLOBAL_INFO_LEVEL
12 |
13 | logger = get_logger(__name__, level=GLOBAL_INFO_LEVEL)
14 |
15 |
16 | @dataclass
17 | class EngineRequest:
18 | """Data class for engine requests"""
19 |     request_id: str  # unique id for each request, formatted as {query_id}::{node_name}
20 | query_id: str
21 | query: Query
22 | payload: Any # Request payload (e.g., texts for embedding), generated from the payload_transformer, could be different for different engines
23 | result_ref: Optional[ray.ObjectRef] = None
24 |     timestamp: float = field(default_factory=time.time)  # evaluated per request, not at class definition time
25 |
26 |
27 | class BaseEngineScheduler(ABC):
28 | """Abstract base class for engine schedulers"""
29 |
30 | @abstractmethod
31 | async def submit_request(self, request: EngineRequest):
32 | """Submit a request to the engine"""
33 | pass
34 |
35 | @abstractmethod
36 | async def shutdown(self):
37 | """Shutdown the scheduler and all engine instances"""
38 | pass
39 |
40 | class SchedulingStrategy(ABC):
41 | """Abstract base class for scheduling strategies"""
42 | @abstractmethod
43 | def get_next_engine(self, engines: List[Any], current_idx: int) -> 'tuple[Any, int]':
44 | """Get the next available engine
45 |
46 | Args:
47 | engines: list of engine instances
48 | current_idx: current engine index
49 |
50 | Returns:
51 | tuple[engine, new_idx]: selected engine and updated index
52 | """
53 | pass
54 |
55 | class RoundRobinStrategy(SchedulingStrategy):
56 | """Round-robin scheduling strategy"""
57 | def get_next_engine(self, engines: List[Any], current_idx: int) -> 'tuple[Any, int]':
58 | if not engines:
59 | raise ValueError("No engines available")
60 | engine = engines[current_idx]
61 | new_idx = (current_idx + 1) % len(engines)
62 | return engine, new_idx
63 |
64 | @ray.remote
65 | class EngineScheduler(BaseEngineScheduler):
66 | def __init__(self,
67 | engine_class,
68 | engine_config: EngineConfig, # use EngineConfig as config
69 | **engine_kwargs):
70 |
71 | self.engine_class = engine_class
72 | self.name = engine_config.name
73 | self.num_instances = engine_config.instances
74 | self.num_gpus = engine_config.num_gpus
75 | self.num_cpus = engine_config.num_cpus
76 | self.resources = engine_config.resources
77 | self.engine_kwargs = {
78 | **engine_kwargs,
79 | **(engine_config.model_config or {}) # merge model_config
80 | }
81 | self.engines = []
82 | self._create_engines()
83 |
84 | # set scheduling strategy, default using round robin
85 | self.scheduling_strategy = RoundRobinStrategy()
86 |
87 | # create request queue
88 | self.request_queue = asyncio.Queue(maxsize=1000)
89 |
90 |
91 | # initialize current engine index
92 | self.current_engine_idx = 0
93 |
94 |
95 | self.pending_requests: Dict[str, EngineRequest] = {}
96 |
97 | # Start processing tasks
98 | self.running = True
99 | self.loop = asyncio.get_event_loop()
100 | self.submit_task = self.loop.create_task(self._submit_requests())
101 | self.result_task = self.loop.create_task(self._process_results())
102 | self._is_ready = True # initialization complete flag
103 |
104 | # The queue for completed requests
105 | self.complete_queue = asyncio.Queue(maxsize=1000)
106 |
107 | def _create_engines(self):
108 | """create engine instances based on EngineConfig"""
109 | for i in range(self.num_instances):
110 | # create options dict, only include custom resources in resources
111 | options = {}
112 | if self.resources:
113 | custom_resources = {
114 | k: v for k, v in self.resources.items()
115 | if k.upper() not in ['CPU', 'GPU']
116 | }
117 | if custom_resources:
118 | options['resources'] = custom_resources
119 |
120 | if self.num_gpus:
121 | options['num_gpus'] = self.num_gpus
122 | if self.num_cpus:
123 | options['num_cpus'] = self.num_cpus
124 | else:
125 | options['num_cpus'] = 1
126 |
127 |
128 |
129 | # directly use the current actor's handle
130 | scheduler_handle = ray.runtime_context.get_runtime_context().current_actor
131 |
132 | # when creating engine instance, pass in scheduler handle
133 | logger.info(f"try to create engine instance with name: {self.name}_{i} for class: {self.engine_class} with options: {options}")
134 |
135 | engine = self.engine_class.options(**options).remote(
136 | name=f"{self.name}_{i}",
137 | scheduler_ref=scheduler_handle, # directly use actor handle
138 | **self.engine_kwargs
139 | )
140 |
141 | _= ray.get(engine.is_ready.remote())
142 | self.engines.append(engine)
143 | logger.info(f"created engine instance with name: {self.name}_{i} for class: {self.engine_class}")
144 |
145 | def _get_next_engine(self):
146 | """use scheduling strategy to get next available engine"""
147 | engine, new_idx = self.scheduling_strategy.get_next_engine(
148 | self.engines,
149 | self.current_engine_idx
150 | )
151 | self.current_engine_idx = new_idx
152 | return engine
153 |
154 |     async def add_engine(self):
155 |         """add a new engine instance dynamically"""
156 |         scheduler_handle = ray.runtime_context.get_runtime_context().current_actor
157 |         engine = self.engine_class.remote(
158 |             name=f"{self.name}_{self.num_instances}",
159 |             scheduler_ref=scheduler_handle,  # pass the scheduler actor handle, as in _create_engines
160 |             **self.engine_kwargs
161 |         )
162 |         self.engines.append(engine)
163 |         self.num_instances += 1
164 |         return len(self.engines)
165 |
166 | async def remove_engine(self, index: Optional[int] = None):
167 | """remove an engine instance
168 |
169 | Args:
170 | index: the index of the engine to remove, if None, remove the last one
171 | """
172 | if not self.engines or self.num_instances <= 1:
173 | raise ValueError("Cannot remove the last engine instance")
174 |
175 | if index is None:
176 | index = len(self.engines) - 1
177 |
178 | if 0 <= index < len(self.engines):
179 | engine = self.engines.pop(index)
180 | await engine.shutdown.remote()
181 | self.num_instances -= 1
182 |
183 | # adjust current engine index
184 | if self.current_engine_idx >= len(self.engines):
185 | self.current_engine_idx = 0
186 |
187 | return len(self.engines)
188 | else:
189 | raise ValueError(f"Invalid engine index: {index}")
190 |
191 | async def _submit_requests(self):
192 | """Task for submitting requests to engines"""
193 | while self.running:
194 | try:
195 |                 # short timeout (0.05s) to avoid busy waiting
196 | request = await asyncio.wait_for(self.request_queue.get(), timeout=0.05)
197 | logger.info(f"Processing request: {request.request_id}")
198 |
199 | # get available engine
200 | engine = self._get_next_engine()
201 |
202 | logger.info(f"submit request {request.request_id} with payload: {request.payload.keys()}")
203 | # submit request to engine
204 | await engine.submit_request.remote(
205 | request_id=request.request_id,
206 | query_id=request.query_id,
207 | # texts=request.payload
208 | **request.payload
209 | )
210 |
211 | # track result
212 | self.pending_requests[request.request_id] = request
213 | self.request_queue.task_done()
214 |
215 | except asyncio.TimeoutError:
216 | continue
217 | except asyncio.CancelledError:
218 | break
219 | except Exception as e:
220 | logger.error(f"Error submitting request for {request.request_id}: {e}")
221 | await asyncio.sleep(0.001)
222 |
223 | async def _process_results(self):
224 | """Task for processing completed results"""
225 | while self.running:
226 | try:
227 | request = await self.complete_queue.get()
228 |
229 | if request.result_ref is not None:
230 | try:
231 | result = await asyncio.get_event_loop().run_in_executor(
232 | None,
233 | ray.get,
234 | request.result_ref
235 | )
236 | logger.debug(f"result in engine scheduler: {result}")
237 |
238 | query_states = request.query.query_state
239 | # Use ray.get to wait for result completion
240 | await asyncio.get_event_loop().run_in_executor(
241 | None,
242 | ray.get,
243 | query_states.set_node_result.remote(
244 | request.request_id.split("::")[-1], # here we assume the request_id is in the format of {query_id}::{node_name} as in the graph scheduler submit_node method
245 | request.result_ref
246 | )
247 | )
248 |
249 | logger.info(f"Successfully set result for request {request.request_id}")
250 |
251 | except Exception as e:
252 | logger.error(f"Error processing result for request {request.request_id}: {e}")
253 |
254 | finally:
255 | # clean up completed request
256 | self.complete_queue.task_done()
257 | if request.request_id in self.pending_requests:
258 | del self.pending_requests[request.request_id]
259 |
260 | await asyncio.sleep(0.01)
261 |
262 | except Exception as e:
263 | logger.error(f"Error in result processing loop: {e}")
264 | await asyncio.sleep(0.01)
265 |
266 | async def submit_request(self, request: EngineRequest):
267 | """Submit a request to the engine"""
268 | #Enqueue request instead of submitting directly
269 | try:
270 |             logger.info(f"Enqueuing request: {request.request_id}")
271 | await self.request_queue.put(request)
272 | self.pending_requests[request.request_id] = request
273 | logger.info(f"Enqueued request: {request.request_id}")
274 |
275 | except asyncio.QueueFull:
276 | logger.error(f"Request queue full, cannot enqueue {request.request_id}")
277 | raise
278 |
279 | async def on_result(self, request_id: str, query_id: str, result_ref_from_engine: Any):
280 | """Handle result callback from engine"""
281 | if not isinstance(result_ref_from_engine, ray.ObjectRef):
282 | result_ref_from_engine = ray.put(result_ref_from_engine)
283 | try:
284 | if request_id in self.pending_requests:
285 | request = self.pending_requests[request_id]
286 | request.result_ref = result_ref_from_engine
287 |
288 | # put completed request into complete queue
289 | await self.complete_queue.put(request)
290 | logger.info(f"Moved request {request_id} to complete queue")
291 |
292 | # Clean up pending_requests
293 | del self.pending_requests[request_id]
294 | else:
295 | logger.warning(f"Received result for unknown request {request_id}")
296 |
297 | except Exception as e:
298 | logger.error(f"Error processing result callback: {e}")
299 |
300 | async def shutdown(self):
301 | """Shutdown the scheduler and all engine instances"""
302 | self.running = False
303 |
304 | # Cancel processing tasks
305 | for task in [self.submit_task, self.result_task]:
306 | if task:
307 | task.cancel()
308 | try:
309 | await task
310 | except asyncio.CancelledError:
311 | pass
312 |
313 | # Shutdown all engine instances
314 | for engine in self.engines:
315 | await engine.shutdown.remote()
316 |
317 | # Clear engine list
318 | self.engines.clear()
319 | self.num_instances = 0
320 |
321 | async def is_ready(self) -> bool:
322 | """Check if the scheduler is fully initialized
323 |
324 | Returns:
325 | bool: if the scheduler is fully initialized and ready to process requests
326 | """
327 | # check if all necessary components are initialized
328 | if not self.engines:
329 | return False
330 |
331 | # check if async tasks are running
332 | if not self.submit_task or self.submit_task.done():
333 | return False
334 | if not self.result_task or self.result_task.done():
335 | return False
336 |
337 | return self._is_ready
338 |
339 |
--------------------------------------------------------------------------------
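A sketch of how a request for this scheduler is shaped; the `{query_id}::{node_name}` request-id convention matches the split in `_process_results`. Passing `query=None` keeps the sketch self-contained, whereas a real request carries the live `Query` so results can be written back to its state actor.

```python
from Ayo.schedulers.engine_scheduler import EngineRequest

req = EngineRequest(
    request_id="q-42::QueryEmbedding",   # {query_id}::{node_name}
    query_id="q-42",
    query=None,                          # sketch only; normally the live Query object
    payload={"texts": ["What is the capital of France?"]},
)
print(req.request_id.split("::")[-1])    # QueryEmbedding
```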
/Ayo/utils.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | def print_warning(message):
4 | #yellow color
5 | print(f"\033[93m{message}\033[0m")
6 |
7 | def print_key_info(message):
8 | #green color
9 | print(f"\033[92m{message}\033[0m")
10 |
11 | def print_error(message):
12 | #red color
13 | print(f"\033[91m{message}\033[0m")
14 |
15 |
16 | def format_query_expanding_prompt(question, prompt_template, expanded_query_num=3):
17 | keys = ", ".join([f"question{i+1}" for i in range(expanded_query_num)])
18 | json_example = "{\n " + "\n ".join([f"\"question{i+1}\": \"[refined version {i+1}]\"" + ("," if i < expanded_query_num-1 else "") for i in range(expanded_query_num)]) + "\n }"
19 |
20 | return prompt_template.format(
21 | expanded_query_num=expanded_query_num,
22 | question=question,
23 | keys=keys,
24 | json_example=json_example
25 | )
26 |
27 | def rename_template_placeholders(template, placeholder_mapping):
28 | """
29 | Modify the placeholder names in the template string
30 |
31 | Args:
32 | template (str): The template string containing placeholders
33 | placeholder_mapping (dict): The placeholder mapping, format as {old_name: new_name}
34 |
35 | Returns:
36 | str: The modified template string
37 | """
38 | result = template
39 | for old_name, new_name in placeholder_mapping.items():
40 | result = result.replace(f"{{{old_name}}}", f"{{{new_name}}}")
41 | return result
42 |
43 | def fill_prompt_template_with_placeholdersname_approximations(prompt_template, input_kwargs):
44 | """
45 | Fill the prompt template, handle the mismatch between placeholder names and input parameter names
46 |
47 | Args:
48 | prompt_template (str): The prompt template containing placeholders
49 | input_kwargs (dict): The input parameters dictionary
50 |
51 | Returns:
52 | str: The filled prompt template
53 | """
54 | import re
55 | from difflib import get_close_matches
56 |
57 | # find all placeholders in the prompt template
58 | placeholders = re.findall(r'\{([^{}]+)\}', prompt_template)
59 | result = prompt_template
60 |
61 | # create a match cache to avoid duplicate calculations
62 | input_keys = list(input_kwargs.keys())
63 |
64 | # find the best match for each placeholder
65 | for placeholder in placeholders:
66 | if placeholder in input_kwargs:
67 | # if there is an exact match, use it
68 | value = str(input_kwargs[placeholder])
69 | result = result.replace(f"{{{placeholder}}}", value)
70 | else:
71 | # check if the match result is already in the cache
72 | close_matches = get_close_matches(placeholder, input_keys, n=1, cutoff=0.1)
73 |
74 | if close_matches:
75 | print_warning(f"Use approximate matching: Placeholder '{placeholder}' is matched to input parameter '{close_matches[0]}'")
76 | value = str(input_kwargs[close_matches[0]])
77 | result = result.replace(f"{{{placeholder}}}", value)
78 | else:
79 | print_error(f"No match found for placeholder '{placeholder}'")
80 |
81 | return result
82 |
83 |
84 | def check_unfilled_placeholders_in_prompt_template(prompt_template):
85 | """
86 | Check if the prompt template is complete by ensuring all placeholders are filled (no {placeholder} pattern in the prompt template)
87 | """
88 | placeholders = re.findall(r'\{([^{}]+)\}', prompt_template)
89 | if placeholders:
90 |         raise ValueError(f"Prompt template still contains unfilled placeholders: {placeholders}")
91 | else:
92 | return True
93 |
94 | def check_prompt_template_and_placeholders_match(prompt_template, input_kwargs):
95 | """
96 | Check if the prompt template is complete by ensuring all placeholders are filled
97 |
98 | Args:
99 | prompt_template (str): The prompt template to check
100 | input_kwargs (dict): The input kwargs to check
101 |
102 | Returns:
103 | bool: True if the prompt template is complete, False otherwise
104 | """
105 |
106 |
107 | # check if all placeholders in the prompt template are in the input_kwargs
108 | import re
109 | placeholders = re.findall(r'\{([^{}]+)\}', prompt_template)
110 | for placeholder in placeholders:
111 | if placeholder not in input_kwargs:
112 | raise ValueError(f"Placeholder {placeholder} not found in input_kwargs")
113 |
114 | return True
115 |
116 | def fill_prompt_template(prompt_template, input_kwargs):
117 | """
118 | Fill the placeholders in the prompt template with the input kwargs
119 | """
120 | for key, value in input_kwargs.items():
121 |         prompt_template = prompt_template.replace(f"{{{key}}}", str(value))
122 |
123 | #check if the prompt template is complete
124 | return prompt_template
125 |
126 |
127 |
128 |
129 |
--------------------------------------------------------------------------------
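A usage sketch for the placeholder helpers above, assuming Ayo is installed; the kwargs `query`/`contexts` deliberately only approximate the placeholder names in order to exercise the fuzzy matching.

```python
from Ayo.utils import (
    fill_prompt_template_with_placeholdersname_approximations,
    check_unfilled_placeholders_in_prompt_template,
)

template = "Question: {question}\nContext: {context}\nAnswer:"
filled = fill_prompt_template_with_placeholdersname_approximations(
    template,
    {"query": "What is the capital of France?", "contexts": ["Paris is the capital."]},
)
print(filled)
# Raises if any {placeholder} is left unfilled; returns True otherwise.
print(check_unfilled_placeholders_in_prompt_template(filled))
```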
/Ayo/vis/test_dag_node_types.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NetX-lab/Ayo/61607bcb31ed98e1101c5d60d8573105bf60d87c/Ayo/vis/test_dag_node_types.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 NetX
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Ayo
2 |
3 | This repository contains the prototype implementation for our ASPLOS'25 paper: [Towards End-to-End Optimization of LLM-based Applications with Ayo](https://dl.acm.org/doi/10.1145/3676641.3716278) ([arXiv preprint](https://arxiv.org/pdf/2407.00326)).
4 |
5 | Ayo is a fine-grained orchestration framework designed for building and optimizing AI-powered applications—such as Retrieval-Augmented Generation (RAG) workflows—in environments where inference engines are ***deployed locally*** rather than accessed via remote APIs.
6 |
7 | Unlike existing frameworks that usually treat workflows as coarse-grained, sequential module chains, Ayo introduces a task-primitive-based abstraction, enabling highly flexible and dynamic orchestration. With minimal user input, Ayo automatically optimizes workflows for performance, exploiting parallelism, pipelining, and inherent scheduling strategies.
8 |
9 | > **Note**: Some parts of the repo are still under construction, e.g., the unified multi-request scheduling for engine schedulers, the user-friendly interface, and the documentation. We will keep updating these.
10 |
11 | ## Key Features
12 | - Fine-grained task orchestration for LLM workflows
13 | - Dynamic optimization for performance (e.g., parallelism, pipelining)
14 | - Dependency Pruning
15 | - Stage decomposition parallelization
16 | - LLM Prefilling Splitting
17 | - LLM Decoding pipelining
18 | - Distributed Two-level Scheduling
19 | - A graph scheduler that handles task-primitive scheduling within each query graph
20 | - Several distinct engine schedulers that manage the different engine types and their operations
21 |
22 | ## Quick Start
23 |
24 | 1. Install dependencies:
25 |
26 | Install postgres and pgvector:
27 | ```bash
28 | sudo apt-get install postgresql postgresql-contrib libpq-dev # install postgresql
29 | git clone https://github.com/pgvector/pgvector.git # compile and install pgvector; you can also install it in other ways
30 | cd pgvector
31 | make
32 | sudo make install
33 | sudo -u postgres psql template1 -c "CREATE EXTENSION vector;" # test
34 | ```
35 |
36 | Install our modified vllm:
37 | ```bash
38 | git clone --recurse-submodules https://github.com/NetX-lab/Ayo.git # clone the repo and submodules
39 | cd vllm
40 | pip install -e .
41 | ```
42 |
43 | Install Ayo:
44 | ```bash
45 | cd ..
46 | pip install -r requirements.txt
47 | pip install -e .
48 | ```
49 |
50 | 2. Define the workflow with Nodes (task primitives) and optimize it with Ayo
51 |
52 |
53 | Click to expand the code
55 |
56 | ```python
57 | from Ayo.app import APP
58 | from Ayo.configs.config import EngineConfig
59 | from Ayo.engines.engine_types import EngineType
60 |
61 | app = APP.init() # initialize the app entry
62 |
63 | llm_config = EngineConfig(
64 | name="llm_service",
65 | engine_type=EngineType.LLM,
66 | resources={},
67 | num_gpus=1,
68 | num_cpus=1,
69 | instances=1,
70 | model_config={
71 | "model_name": "meta-llama/Llama-2-7b-chat-hf",
72 | "tensor_parallel_size": 1,
73 | #other config ...
74 | },
75 | latency_profile={
76 | "timeout": 300,
77 | }
78 | )
79 |
80 | app.register_engine(llm_config)
81 | #register other engines ...
82 |
83 |
84 | # define the primitive nodes
85 | llm_prefilling_node = Node(
86 | name="LLMPrefilling",
87 | node_type=NodeType.COMPUTE,
88 | engine_type=EngineType.LLM,
89 | io_schema=NodeIOSchema(
90 | input_format={"queries": List[str], "reranked_results": List[List[str]]},
91 | output_format={"prefill_state": bool}
92 | ),
93 | op_type=NodeOps.LLM_PREFILLING,
94 | config={
95 | 'prompt_template': replace_placeholders(RAG_QUESTION_ANSWERING_PROMPT_TEMPLATE_STRING, question="queries", context="reranked_results"),
96 | 'parse_json': True,
97 | #other config ...
98 | }
99 | )
100 |
101 | llm_decoding_node = Node(
102 | name="LLMDecoding",
103 | node_type=NodeType.COMPUTE,
104 | engine_type=EngineType.LLM,
105 | io_schema=NodeIOSchema(
106 | input_format={"prefill_state": bool},
107 | output_format={"result": str}
108 | ),
109 | op_type=NodeOps.LLM_DECODING,
110 | config={
111 | 'prompt_template': replace_placeholders(RAG_QUESTION_ANSWERING_PROMPT_TEMPLATE_STRING, question="queries", context="reranked_results"),
112 | 'parse_json': True,
113 | #other config ...
114 | }
115 | )
116 | #define other nodes ...
117 |
118 | # create the DAG
119 | dag = DAG(dag_id="rag_workflow")
120 | dag.register_nodes(llm_prefilling_node, llm_decoding_node, ...)
121 | # set the query inputs
122 | dag.set_query_inputs(
123 | {
124 | 'queries': ['What is the capital of France?'], ## set the query inputs
125 | }
126 | )
127 |
128 | from Ayo.opt_pass.pruning_dependency import PruningDependencyPass
129 | from Ayo.opt_pass.stage_decomposition import StageDecompositionPass
130 | from Ayo.opt_pass.prefilling_split import PrefillingSpiltPass
131 | from Ayo.opt_pass.decoding_pipeling import LLMDecodingPipeliningPass
132 |
133 | dag.optimize([PruningDependencyPass(), StageDecompositionPass(), PrefillingSpiltPass(), LLMDecodingPipeliningPass()])
134 |
135 | query=Query(
136 | uuid=f"random-test-{query_id}",
137 | query_id=f"random-test-{query_id}",
138 | DAG=deepcopy(dag)
139 | )
140 |
141 | future = await app.submit_query(
142 | query=query,
143 | timeout=300
144 | )
145 |
146 | result = await asyncio.wait_for(future, timeout=300)
147 |
148 |
149 | ```
150 |
151 | Click to expand the code
159 |
160 | ```python
161 | from Ayo.modules import IndexingModule, QueryExpandingModule, SearchingModule, RerankingModule
162 | from Ayo.modules.mod_to_prim import transform_mod_to_prim
163 | from typing import List
164 | indexing_module = IndexingModule(
165 | input_format={"passages": List[str]},
166 | output_format={"index_status": bool}
167 | )
168 |
169 | query_expanding_module = QueryExpandingModule(
170 | input_format={"query": str},
171 | output_format={"expanded_queries": List[str]},
172 | config={"expanded_query_num": 3}
173 | )
174 |
175 | searching_module = SearchingModule(
176 | input_format={"index_status": bool, "expanded_queries": List[str]},
177 | output_format={"searching_results": List[str]}
178 | )
179 |
180 | reranking_module = RerankingModule(
181 | input_format={"searching_results": List[str]},
182 | output_format={"reranking_results": List[str]}
183 | )
184 |
185 |
186 | indexing_module >> query_expanding_module >> searching_module >> reranking_module
187 | 
188 | 
189 | node_list = transform_mod_to_prim([indexing_module, query_expanding_module, searching_module, reranking_module])
190 |
191 | ### Then optimize the workflow with Ayo as above
192 |
193 | ```
194 |
195 |
196 |
197 |
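198 | The comment at the end of the snippet defers to the node-based example above for optimization and execution. A minimal sketch of that step is shown below, assuming the same `DAG`, optimization-pass, and query-submission APIs used in the node-based example; the DAG id and query inputs here are only illustrative.
199 | 
200 | ```python
201 | from Ayo.opt_pass.pruning_dependency import PruningDependencyPass
202 | from Ayo.opt_pass.stage_decomposition import StageDecompositionPass
203 | 
204 | # build a DAG from the primitive nodes produced by transform_mod_to_prim
205 | dag = DAG(dag_id="module_based_rag")
206 | dag.register_nodes(*node_list)
207 | 
208 | # provide the query-level inputs expected by the first modules (illustrative values)
209 | dag.set_query_inputs({"passages": ["..."], "query": "What is the capital of France?"})
210 | 
211 | # apply the optimization passes, then submit the query as in the node-based example
212 | dag.optimize([PruningDependencyPass(), StageDecompositionPass()])
213 | ```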