├── OpenAGI ├── __init__.py ├── config.py ├── openagi_utils.py ├── runner.py ├── agents.py ├── async_online_utils.py └── planner.py ├── requirements.txt ├── predictor.py ├── color.py ├── travelplanner_supplement ├── config.py ├── runner.py ├── async_online_utils.py └── planner.py ├── util.py ├── .gitignore ├── README.md └── data └── openagi_task_description.txt /OpenAGI/__init__.py: -------------------------------------------------------------------------------- 1 | from .async_online_utils import OnlineTrajectoryCollector, OnlineLearningExecutor, FiniteReplay, SharedState -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | openai==1.65.2 2 | tiktoken 3 | asyncio 4 | pyautogen==0.2.31 5 | colorama 6 | async-timeout==4.0.2 7 | config==0.5.1 8 | inputimeout 9 | torch==2.7.0 10 | transformers==4.51.3 -------------------------------------------------------------------------------- /predictor.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | class DistilBERTValueFunction(nn.Module): 4 | def __init__(self, bert_model): 5 | super().__init__() 6 | self.model = bert_model 7 | self.fc = nn.Linear(self.model.config.hidden_size, 1) 8 | 9 | def forward(self, input_ids, attention_mask): 10 | outputs = self.model(input_ids=input_ids, attention_mask=attention_mask) 11 | cls_embedding = outputs.last_hidden_state[:, 0, :] 12 | 13 | k = self.fc(cls_embedding) 14 | return k -------------------------------------------------------------------------------- /color.py: -------------------------------------------------------------------------------- 1 | from colorama import init, Fore, Style 2 | import sys,time,random 3 | import sys 4 | import regex as re 5 | import asyncio 6 | 7 | 8 | # code from https://gist.github.com/jeffskinnerbox/6663095 9 | colorCodes = { 10 | 'black': '0;30', 'bright 
gray': '0;37', 11 | 'blue': '0;34', 'white': '1;37', 12 | 'green': '0;32', 'bright blue': '1;34', 13 | 'cyan': '0;36', 'bright green': '1;32', 14 | 'red': '0;31', 'bright cyan': '1;36', 15 | 'purple': '0;35', 'bright red': '1;31', 16 | 'yellow': '0;33', 'bright purple': '1;35', 17 | 'dark gray': '1;30', 'bright yellow': '1;33', 18 | 'normal': '0' 19 | } 20 | 21 | ########## print thought, will not log ########## 22 | #typing_speed: wpm 23 | def slow_type_approximation(t): 24 | t = '[Approximation agent:]\n' + t + '\n[End Thought Process]' 25 | for l in t: 26 | sys.stdout.write("\033[" + colorCodes['bright cyan'] + "m" + l + "\033[0m") 27 | #sys.stdout.write(l) 28 | sys.stdout.flush() 29 | time.sleep(random.random()*10.0/300) 30 | print('') 31 | return '' 32 | 33 | def slow_type_target(t): 34 | t = '[Target agent:]\n' + t + '\n[End Thought Process]' 35 | for l in t: 36 | sys.stdout.write("\033[" + colorCodes['bright purple'] + "m" + l + "\033[0m") 37 | #sys.stdout.write(l) 38 | sys.stdout.flush() 39 | time.sleep(random.random()*10.0/300) 40 | print('') 41 | return '' 42 | 43 | if __name__ == '__main__': 44 | slow_type_approximation('This is a test') 45 | slow_type_target('This is a test') -------------------------------------------------------------------------------- /travelplanner_supplement/config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from collections import defaultdict 3 | 4 | @dataclass 5 | class Config: 6 | MAX_STEP: int = 5 7 | TRAIN_INTERVAL: int = 1 8 | WARMUP: int = 0 9 | 10 | ENABLE_TRAIN: bool = True 11 | ENABLE_PRED: bool = True 12 | 13 | MAX_CONCURRENT_CALLS: int = 0 14 | TOTAL_APPROXIMATION_CALLS: int = 0 15 | TOTAL_CORRECT_APPROXIMATION_CALLS: int = 0 16 | 17 | TOTAL_TOKEN_GENERATION: int = 0 18 | TOTAL_TOKEN_PROMPT: int = 0 19 | USERINPUT: bool = False 20 | 21 | TARGET_NORMAL_PROMPT = defaultdict(int) 22 | TARGET_NORMAL_GENERATION = defaultdict(int) 23 | 
TARGET_SP_PROMPT: int = 0 24 | TARGET_SP_GENERATION: int = 0 25 | 26 | APPROX_SP_PROMPT: int = 0 27 | APPROX_SP_GENERATION: int = 0 28 | APPROX_NORMAL_PROMPT = defaultdict(int) 29 | APPROX_NORMAL_GENERATION = defaultdict(int) 30 | 31 | TOTAL_SP_TIME: float = 0.0 32 | TARGET_NORMAL_TIME = defaultdict(float) 33 | APPROX_NROMAL_TIME = defaultdict(float) 34 | 35 | PREDICT_K =[] 36 | PREDICT_CORRECT: int = 0 37 | PREDICT_TOTAL: int = 0 38 | BUILD_TRAJ_TIMES: int = 0 39 | 40 | PENDING_BACKGROUND_TASKS = [] 41 | 42 | TARGET_TYPE: str = "" 43 | 44 | # Human-in-the-loop interaction 45 | HIL_INTERACTION: int = 0 46 | 47 | def reset_task_metrics(self): 48 | """Reset metrics for a new task.""" 49 | self.MAX_CONCURRENT_CALLS = 0 50 | self.TOTAL_APPROXIMATION_CALLS = 0 51 | self.TOTAL_CORRECT_APPROXIMATION_CALLS = 0 52 | 53 | self.TOTAL_TOKEN_GENERATION = 0 54 | self.TOTAL_TOKEN_PROMPT = 0 55 | self.USERINPUT = False 56 | 57 | self.TARGET_NORMAL_PROMPT.clear() 58 | self.TARGET_NORMAL_GENERATION.clear() 59 | self.TARGET_SP_PROMPT = 0 60 | self.TARGET_SP_GENERATION = 0 61 | 62 | self.APPROX_SP_PROMPT = 0 63 | self.APPROX_SP_GENERATION = 0 64 | self.APPROX_NORMAL_PROMPT.clear() 65 | self.APPROX_NORMAL_GENERATION.clear() 66 | 67 | self.TOTAL_SP_TIME = 0.0 68 | self.TARGET_NORMAL_TIME.clear() 69 | self.APPROX_NROMAL_TIME.clear() 70 | 71 | self.PREDICT_K.clear() 72 | self.PREDICT_CORRECT = 0 73 | self.PREDICT_TOTAL = 0 74 | self.BUILD_TRAJ_TIMES = 0 75 | self.PENDING_BACKGROUND_TASKS.clear() 76 | self.HIL_INTERACTION = 0 77 | 78 | def initialize_from_args(self, args): 79 | """Initialize configuration from command line arguments.""" 80 | self.TRAIN_INTERVAL = args.freq 81 | self.BUILD_TRAJ_TIMES = 0 82 | # self.MAX_STEP = 5 83 | self.ENABLE_TRAIN = args.pred 84 | self.ENABLE_PRED = args.pred 85 | self.WARMUP = 0 -------------------------------------------------------------------------------- /OpenAGI/config.py: 
-------------------------------------------------------------------------------- 1 | """Configuration for the speculative planning system.""" 2 | 3 | from dataclasses import dataclass, field 4 | from typing import Dict, List 5 | 6 | @dataclass 7 | class Config: 8 | """Global configuration and metrics.""" 9 | # Constants 10 | MAX_STEP: int = 5 11 | TRAIN_INTERVAL: int = 1 12 | WARMUP: int = 0 13 | 14 | # Flags 15 | ENABLE_TRAIN: bool = True 16 | ENABLE_PRED: bool = True 17 | USERINPUT: bool = False 18 | 19 | # Metrics 20 | MAX_CONCURRENT_CALLS: int = 0 21 | TOTAL_APPROXIMATION_CALLS: int = 0 22 | TOTAL_CORRECT_APPROXIMATION_CALLS: int = 0 23 | 24 | # Token tracking 25 | TOTAL_TOKEN_GENERATION: int = 0 26 | TOTAL_TOKEN_PROMPT: int = 0 27 | 28 | TARGET_NORMAL_PROMPT: Dict[int, int] = field(default_factory=dict) 29 | TARGET_NORMAL_GENERATION: Dict[int, int] = field(default_factory=dict) 30 | TARGET_SP_PROMPT: int = 0 31 | TARGET_SP_GENERATION: int = 0 32 | 33 | APPROX_SP_PROMPT: int = 0 34 | APPROX_SP_GENERATION: int = 0 35 | APPROX_NORMAL_PROMPT: Dict[int, int] = field(default_factory=dict) 36 | APPROX_NORMAL_GENERATION: Dict[int, int] = field(default_factory=dict) 37 | 38 | # Timing metrics 39 | TOTAL_SP_TIME: float = 0.0 40 | TARGET_NORMAL_TIME: Dict[int, float] = field(default_factory=dict) 41 | APPROX_NORMAL_TIME: Dict[int, float] = field(default_factory=dict) 42 | 43 | # Prediction metrics 44 | PREDICT_K: List[int] = field(default_factory=list) 45 | PREDICT_CORRECT: int = 0 46 | PREDICT_TOTAL: int = 0 47 | BUILD_TRAJ_TIMES: int = 0 48 | 49 | def reset_task_metrics(self): 50 | """Reset metrics for a new task.""" 51 | self.MAX_CONCURRENT_CALLS = 0 52 | self.TOTAL_APPROXIMATION_CALLS = 0 53 | self.TOTAL_CORRECT_APPROXIMATION_CALLS = 0 54 | 55 | self.TOTAL_TOKEN_GENERATION = 0 56 | self.TOTAL_TOKEN_PROMPT = 0 57 | self.USERINPUT = False 58 | 59 | self.TARGET_NORMAL_PROMPT.clear() 60 | self.TARGET_NORMAL_GENERATION.clear() 61 | self.TARGET_SP_PROMPT = 0 62 | 
self.TARGET_SP_GENERATION = 0 63 | 64 | self.APPROX_SP_PROMPT = 0 65 | self.APPROX_SP_GENERATION = 0 66 | self.APPROX_NORMAL_PROMPT.clear() 67 | self.APPROX_NORMAL_GENERATION.clear() 68 | 69 | self.TOTAL_SP_TIME = 0.0 70 | self.TARGET_NORMAL_TIME.clear() 71 | self.APPROX_NORMAL_TIME.clear() 72 | 73 | self.PREDICT_K.clear() 74 | self.PREDICT_CORRECT = 0 75 | self.PREDICT_TOTAL = 0 76 | self.BUILD_TRAJ_TIMES = 0 77 | 78 | def initialize_from_args(self, args): 79 | """Initialize configuration from command line arguments.""" 80 | self.TRAIN_INTERVAL = args.freq 81 | self.BUILD_TRAJ_TIMES = 0 82 | self.MAX_STEP = 5 83 | self.ENABLE_TRAIN = args.pred 84 | self.ENABLE_PRED = args.pred 85 | self.WARMUP = 0 -------------------------------------------------------------------------------- /OpenAGI/openagi_utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions for the speculative planning system.""" 2 | 3 | import asyncio 4 | from datetime import datetime 5 | import time 6 | 7 | def parse_response(response: str) -> str: 8 | """Parse the response from agents to extract the action.""" 9 | if '<' in response and '>' in response and '', '').replace('<', '').replace('>', '') 15 | elif '**' in response and response.count('**') >= 2: 16 | start = response.index('**') + len('**') 17 | response = response[start:] 18 | end = response.index('**') 19 | return response[:end].replace('<', '').replace('>', '') 20 | else: 21 | return '' 22 | 23 | def ordinal(n): 24 | """Convert a number to its ordinal representation.""" 25 | if 11 <= (n % 100) <= 13: 26 | suffix = 'th' 27 | else: 28 | suffix = {1: 'st', 2: 'nd', 3: 'rd'}.get(n % 10, 'th') 29 | return f"{n}{suffix}" 30 | 31 | def load_data(data_path): 32 | """Load task descriptions from a file.""" 33 | with open(data_path, "r") as f: 34 | data = f.read() 35 | return [t.strip() for t in data.split("\n")] 36 | 37 | def judge_to_be_true(s: str, t: str) -> bool: 38 | """Check if two 
actions are equivalent.""" 39 | return s == t 40 | 41 | def concurrent_calls(): 42 | """Get the number of concurrent tasks.""" 43 | tasks = asyncio.all_tasks() 44 | pending_tasks = [t for t in tasks if not t.done() and not t.cancelled()] 45 | return len(pending_tasks) 46 | 47 | def get_timestamp(): 48 | """Get current timestamp.""" 49 | return datetime.fromtimestamp(time.time()) 50 | 51 | def create_prompt(task_description): 52 | """Create the initial prompt for the agents.""" 53 | tools = """ 54 | Available tools are as follows: 55 | 56 | (1) Sentiment Analysis 57 | (2) Text Summarization 58 | (3) Machine Translation 59 | (4) Fill Mask 60 | (5) Question Answering 61 | (6) Image Classification 62 | (7) Object Detection 63 | (8) Colorization 64 | (9) Image Super-Resolution 65 | (10) Image Denoising 66 | (11) Image Deblurring 67 | (12) Visual Question Answering 68 | (13) Image Captioning 69 | (14) Text-to-Image Generation 70 | (15) TERMINATE 71 | 72 | For each step of the plan, please specify the tool you would like to use. But if you think the task is completed, please use TERMINATE to end the conversation. 73 | 74 | Please use xml tags to specify the tool when responsing. For example, Sentiment Analysis for Sentiment Analysis. 
75 | """ 76 | return f"## Problem: {task_description}\nPlease solve this problem using the following tools step by step:\n{tools}" 77 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import os 4 | from datetime import date, datetime 5 | import sys 6 | import asyncio 7 | from async_timeout import timeout 8 | import time 9 | from autogen import AssistantAgent 10 | import tiktoken 11 | import signal 12 | import time 13 | import config 14 | from functools import partial 15 | 16 | class Logger(object): 17 | def __init__(self, log_path, on=False): 18 | self.log_path = log_path 19 | self.on = on 20 | os.makedirs(os.path.dirname(self.log_path), exist_ok=True) 21 | if self.on: 22 | while os.path.isfile(self.log_path): 23 | self.log_path = self.log_path[:-4] + '+'+ self.log_path[-4:] 24 | with open(self.log_path, 'w') as f: 25 | f.write("") 26 | 27 | def log(self, string, newline=True): 28 | #if self.on: 29 | with open(self.log_path, "a") as logf: 30 | today = date.today() 31 | today_date = today.strftime("%m/%d/%Y") 32 | now = datetime.now() 33 | current_time = now.strftime("%H:%M:%S") 34 | string = today_date + ", " + current_time + ": " + string 35 | logf.write(string) 36 | if newline: 37 | logf.write("\n") 38 | 39 | sys.stdout.write(string) 40 | if newline: 41 | sys.stdout.write("\n") 42 | sys.stdout.flush() 43 | 44 | async def cancel(task): 45 | start = time.time() 46 | task.cancel() 47 | try: 48 | async with timeout(-1): 49 | await task 50 | end = time.time() 51 | except asyncio.CancelledError: 52 | end = time.time() 53 | return 'task cancelled' 54 | except asyncio.exceptions.TimeoutError: 55 | end = time.time() 56 | return 'time out for canceling task' 57 | except Exception as e: 58 | end = time.time() 59 | return 'exception:' + str(e) 60 | 61 | 62 | class SignalHandler: 63 | def __init__(self): 64 | 
pass 65 | 66 | # # old register handler, doesn't know why it doesn't work :( 67 | # def exit_handler(self, signum, frame, target_tasks): 68 | # to_halt_target_task = [target_task for target_task in target_tasks if target_task.get_name() == f'target_{config.HIL_INTERACTION}' and not target_task.done() and not target_task.cancelled()] 69 | # if to_halt_target_task != []: 70 | # config.USERINPUT=True 71 | # to_halt_target_task[0].cancel() 72 | 73 | async def exit_handler(self, target_tasks): 74 | to_halt_target_task = [target_task for target_task in target_tasks if target_task.get_name() == f'target_{config.HIL_INTERACTION}' and not target_task.done() and not target_task.cancelled()] 75 | if to_halt_target_task != []: 76 | config.USERINPUT=True 77 | to_halt_target_task[0].cancel() 78 | async with timeout(-1): 79 | await to_halt_target_task[0] 80 | 81 | def register_async_handler(target_tasks): 82 | loop = asyncio.get_event_loop() 83 | s = SignalHandler() 84 | #global sigterm_handler 85 | loop.add_signal_handler(getattr(signal, 'SIGINT'), lambda: asyncio.create_task(s.exit_handler(target_tasks=target_tasks))) 86 | 87 | 88 | def exit_handler(signum, frame, target_tasks): 89 | to_halt_target_task = [target_task for target_task in target_tasks if target_task.get_name() == f'target_{config.HIL_INTERACTION}' and not target_task.done() and not target_task.cancelled()] 90 | if to_halt_target_task != []: 91 | config.USERINPUT=True 92 | to_halt_target_task[0].cancel() 93 | 94 | # # old register handler, doesn't know why it doesn't work :( 95 | # def register_handler(target_tasks): 96 | # s = SignalHandler() 97 | # signal.signal(signal.SIGINT, partial(s.exit_handler, target_tasks=target_tasks)) 98 | 99 | def register_handler(target_tasks): 100 | signal.signal(signal.SIGINT, partial(exit_handler, target_tasks=target_tasks)) 101 | -------------------------------------------------------------------------------- /.gitignore: 
-------------------------------------------------------------------------------- 1 | # Created by https://www.toptal.com/developers/gitignore/api/python 2 | # Edit at https://www.toptal.com/developers/gitignore?templates=python 3 | .DS_Store 4 | 5 | ### Python ### 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | cover/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | .pybuilder/ 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | # For a library or package, you might want to ignore these files since the code is 92 | # intended to run in multiple environments; otherwise, check them in: 93 | # .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # poetry 103 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 104 | # This is especially recommended for binary packages to ensure reproducibility, and is more 105 | # commonly ignored for libraries. 106 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 107 | #poetry.lock 108 | 109 | # pdm 110 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 111 | #pdm.lock 112 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 113 | # in version control. 114 | # https://pdm.fming.dev/#use-with-ide 115 | .pdm.toml 116 | 117 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 118 | __pypackages__/ 119 | 120 | # Celery stuff 121 | celerybeat-schedule 122 | celerybeat.pid 123 | 124 | # SageMath parsed files 125 | *.sage.py 126 | 127 | # Environments 128 | .env 129 | .venv 130 | env/ 131 | venv/ 132 | ENV/ 133 | env.bak/ 134 | venv.bak/ 135 | 136 | # Spyder project settings 137 | .spyderproject 138 | .spyproject 139 | 140 | # Rope project settings 141 | .ropeproject 142 | 143 | # mkdocs documentation 144 | /site 145 | 146 | # mypy 147 | .mypy_cache/ 148 | .dmypy.json 149 | dmypy.json 150 | 151 | # Pyre type checker 152 | .pyre/ 153 | 154 | # pytype static type analyzer 155 | .pytype/ 156 | 157 | # Cython debug symbols 158 | cython_debug/ 159 | 160 | # PyCharm 161 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 162 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 163 | # and can be added to the global gitignore or merged 
into this file. For a more nuclear 164 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 165 | #.idea/ 166 | 167 | ### Python Patch ### 168 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration 169 | poetry.toml 170 | 171 | # ruff 172 | .ruff_cache/ 173 | 174 | # LSP config files 175 | pyrightconfig.json 176 | 177 | # End of https://www.toptal.com/developers/gitignore/api/python 178 | 179 | OpenAGI/openagi_dyn.py 180 | travelplanner_supplement/sp_travel_planner_dyn.py -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Dynamic Speculative Agent Planning 2 | Dynamic Speculative Planning (DSP) is a lightweight online reinforcement learning framework for accelerating LLM-based agents. This repository hosts the code and data for the paper [Dynamic Speculative Agent Planning](https://arxiv.org/abs/2509.01920). 3 | 4 | ## Table of Contents 5 | - [Experiment & Command](#experiment--command) 6 | - [OpenAGI Experiment](#openagi-experiment) 7 | - [TravelPlanner Experiment](#travelplanner-experiment) 8 | - [Citation](#citation) 9 | 10 | ## Experiment & Command 11 | We provide two environments for two separate experiments. Please follow the corresponding instructions. 12 | 13 | 14 | ### OpenAGI Experiment 15 | In the OpenAGI setting, the agent first generates a plan and then executes it. Here we focus only on the planning step, without execution.
16 | 17 | To set up the environment: 18 | ``` 19 | conda create -n specplan python=3.10 20 | conda activate specplan 21 | pip install -r requirements.txt 22 | ``` 23 | 24 | In openagi_dyn.py, set the OPENAI_API_KEY and DEEPSEEK_API_KEY: 25 | ``` 26 | os.environ['OPENAI_API_KEY'] = your_gpt_key 27 | os.environ['DEEPSEEK_API_KEY'] = your_dpsk_key 28 | ``` 29 | 30 | There are four settings that we employ in our experiments: 31 | 32 | We use the following shorthand: 33 | - `Direct` = direct-generation 34 | - `CoT` = chain-of-thought 35 | - `MAD` = multi-agent-debate 36 | 37 | | **Setting** | **Approximation Agent** | **Target Agent** | 38 | |-------------|--------------------------|------------------| 39 | | **1** | Direct (GPT-4.1-mini) | ReAct (GPT-4.1-mini) | 40 | | **2** | CoT (GPT-4.1-mini) | MAD (GPT-4.1-mini) | 41 | | **3** | Direct (deepseek-chat) | ReAct (deepseek-reasoner) | 42 | | **4** | CoT (deepseek-chat) | MAD (deepseek-reasoner) | 43 | 44 | You may also configure your own approximation–target combinations or plug in other APIs if interested. 45 | 46 | - Fix Mode: 47 | ``` 48 | python -m OpenAGI.runner --no-pred --k 2 # choose a fixed k value 49 | ``` 50 | 51 | - Dynamic Mode: 52 | ``` 53 | approx_type="direct" # could be "direct" (setting 1 & 3), "cot" (setting 2 & 4) 54 | target_type="react" # could be "react" (setting 1 & 3), "multi_agent" (setting 2 & 4) 55 | offset=0 # choose inference offset for k 56 | tau=0.5 # choose asymmetric hyperparameter for expectile regression 57 | model_type="gpt-4.1-mini" # could be "gpt-4.1-mini" or "deepseek" 58 | python -m OpenAGI.runner --pred --target_type "$target_type" --approx_type "$approx_type" --offset "$offset" --tau "$tau" --model_type "$model_type" 59 | ``` 60 | 61 | ### TravelPlanner Experiment 62 | The TravelPlanner experiment mainly adopts the code from [TravelPlanner](https://github.com/OSU-NLP-Group/TravelPlanner) and integrates the dynamic speculative planning code into it.
63 | 64 | To run speculative planning on TravelPlanner, first download the code and database by following the instructions in [TravelPlanner](https://github.com/OSU-NLP-Group/TravelPlanner). A separate virtual environment for TravelPlanner is also required. 65 | ``` 66 | git clone https://github.com/OSU-NLP-Group/TravelPlanner 67 | 68 | conda create -n travelplanner python=3.9 69 | conda activate travelplanner 70 | pip install -r requirements.txt 71 | pip install -r TravelPlanner/requirements.txt 72 | ``` 73 | 74 | Put tool_agents_sp.py from travelplanner_supplement/ into TravelPlanner/agents/. 75 | Then, put the other files from travelplanner_supplement/, together with predictor.py and util.py from Dynamic-Speculative-Planning/, into the TravelPlanner/ root directory. 76 | 77 | In tool_agents_sp.py and runner.py, set the OPENAI_API_KEY and DEEPSEEK_API_KEY: 78 | ``` 79 | os.environ['OPENAI_API_KEY'] = your_gpt_key 80 | os.environ['DEEPSEEK_API_KEY'] = your_dpsk_key 81 | ``` 82 | 83 | To run the experiment: 84 | - Fix Mode 85 | ``` 86 | cd TravelPlanner 87 | python runner.py --no-pred --k 2 # choose a fixed k value 88 | ``` 89 | 90 | - Dynamic Mode 91 | ``` 92 | cd TravelPlanner 93 | approx_type="direct" # could be "direct" (setting 1 & 3), "cot" (setting 2 & 4) 94 | target_type="react" # could be "react" (setting 1 & 3), "multi_agent" (setting 2 & 4) 95 | offset=0 # choose inference offset for k 96 | tau=0.5 # choose asymmetric hyperparameter for expectile regression 97 | model_type="gpt-4.1-mini" # could be "gpt-4.1-mini", "deepseek-chat" 98 | python runner.py --pred --target_type "$target_type" --approx_type "$approx_type" --offset "$offset" --tau "$tau" --model_type "$model_type" 99 | ``` 100 | 101 | ## Citation 102 | 103 | If you have any further questions, please feel free to contact us.
And if you find our work helpful, please cite our paper: 104 | 105 | ```bibtex 106 | @article{guan2025dynamic, 107 | title={Dynamic Speculative Agent Planning}, 108 | author={Guan, Yilin and Hua, Wenyue and Lan, Qingfeng and Fei, Sun and Ding, Dujian and Acharya, Devang and Wang, Chi and Wang, William Yang}, 109 | journal={arXiv preprint arXiv:2509.01920}, 110 | year={2025} 111 | } 112 | ``` 113 | -------------------------------------------------------------------------------- /OpenAGI/runner.py: -------------------------------------------------------------------------------- 1 | """Runner for the speculative planning system.""" 2 | 3 | import os 4 | import random 5 | import argparse 6 | import torch 7 | from datetime import datetime 8 | import asyncio 9 | 10 | from transformers import AutoModel, AutoTokenizer 11 | import tiktoken 12 | 13 | from .config import Config 14 | from .planner import SpeculativePlanner 15 | from .agents import setup_assistants 16 | from .agents import setup_multi_agent, create_agent 17 | from .openagi_utils import load_data, create_prompt 18 | from .predictor import DistilBERTValueFunction 19 | from .async_online_utils import OnlineLearningExecutor 20 | from util import Logger 21 | 22 | config = Config() # Create global config instance 23 | 24 | async def _run_planner(planner, args, prompt): 25 | """Async helper function to run the planner.""" 26 | return await planner.run(args, prompt) 27 | 28 | def parse_args(): 29 | parser = argparse.ArgumentParser(description='OpenAGI') 30 | parser.add_argument('--data', type=str, default='data/openagi_task_description.txt', 31 | help='data directory') 32 | parser.add_argument('--k', type=int, default=2, 33 | help='number of approximation steps to generate everytime') 34 | parser.add_argument('--approx_type', type=str, default='direct', 35 | help='cot, direct') 36 | parser.add_argument('--target_type', type=str, default='react', 37 | help='react, multi_agent') 38 | parser.add_argument('--model_type', 
type=str, default="gpt-4.1-mini", 39 | help='gpt-4.1-mini, deepseek') 40 | parser.add_argument('--pred', action='store_true', 41 | help='enable speculative planning with predictor') 42 | parser.add_argument('--no-pred', dest='pred', action='store_false', 43 | help='disable speculative planning with predictor') 44 | parser.set_defaults(pred=True) 45 | 46 | # Online learning parameters 47 | parser.add_argument('--lr', type=float, default=1e-5, 48 | help='online learning lr') 49 | parser.add_argument('--ep', type=int, default=3, 50 | help='online learning epoch per train') 51 | parser.add_argument('--bf', type=int, default=2500, 52 | help='online learning buffer size') 53 | parser.add_argument('--bs', type=int, default=16, 54 | help='online learning batch size') 55 | parser.add_argument('--gma', type=float, default=1, 56 | help='online learning gamma for lambda return calculation') 57 | parser.add_argument('--lmd', type=float, default=0.95, 58 | help='online learning lambda for lambda return calculation') 59 | parser.add_argument('--load', dest='load', action='store_true', 60 | help='load previous trajectory and model') 61 | parser.add_argument('--no-load', dest='load', action='store_false', 62 | help='do not load previous trajectory and model') 63 | parser.add_argument('--tau', type=float, default=0.5, 64 | help='expectile loss tau') 65 | parser.add_argument('--s_task', type=int, default=1, 66 | help='start task id') 67 | parser.add_argument('--freq', type=int, default=1, 68 | help='online learning training frequency') 69 | parser.add_argument('--offset', type=int, default=0, 70 | help='biased inference offset for k') 71 | 72 | parser.set_defaults(load=False) 73 | return parser.parse_args() 74 | 75 | def setup_executor(args, device): 76 | """Set up the online learning executor.""" 77 | # Set up paths 78 | if args.pred: # dyn k 79 | traj_dir = f"trajectory/online_traj/{args.approx_type}_{args.target_type}/{args.model_type}" 80 | os.makedirs(traj_dir, exist_ok=True) 81 
| traj_file = f"{traj_dir}/tau_{args.tau}_offset_{args.offset}.ndjson" 82 | else: # fix k 83 | traj_dir = f"trajectory/{args.approx_type}_{args.target_type}/{args.model_type}/fix_k_{args.k}" 84 | os.makedirs(traj_dir, exist_ok=True) 85 | traj_file = f"{traj_dir}/task_{{}}.json" 86 | 87 | ckpt_dir = f"ckpt/online/{args.approx_type}_{args.target_type}/{args.model_type}" 88 | os.makedirs(ckpt_dir, exist_ok=True) 89 | ckpt_path = f"{ckpt_dir}/tau_{args.tau}_offset_{args.offset}.pth" 90 | 91 | # Set up model 92 | model_path = "distilbert-base-uncased" 93 | bert_model = AutoModel.from_pretrained(model_path) 94 | tokenizer = AutoTokenizer.from_pretrained(model_path) 95 | model = DistilBERTValueFunction(bert_model).to(device) 96 | 97 | # Create executor 98 | return OnlineLearningExecutor( 99 | device=device, 100 | model_save_path=ckpt_path, 101 | model=model, 102 | tokenizer=tokenizer, 103 | buffer_size=args.bf, 104 | batch_size_steps=args.bs, 105 | lambda_=args.lmd, 106 | gamma=args.gma, 107 | lr=args.lr, 108 | epoch_per_train=args.ep, 109 | load=args.load, 110 | traj_file=traj_file, 111 | tau=args.tau 112 | ), traj_file 113 | 114 | def setup_agents(args): 115 | """Set up approximation and target agents with proper types.""" 116 | app_assistant, tar_assistant = setup_assistants(args.model_type) 117 | 118 | # Set up multi-agent if needed before creating agents 119 | if args.target_type == 'multi_agent': 120 | tar_assistants = setup_multi_agent(args.model_type, tar_assistant) 121 | else: 122 | tar_assistants = tar_assistant 123 | 124 | # Create agents with proper types - loggers will be set by planner 125 | app_agent = create_agent( 126 | args.approx_type, 127 | app_assistant, 128 | logger=None, 129 | encoding=tiktoken.get_encoding("cl100k_base"), 130 | config=config 131 | ) 132 | 133 | tar_agent = create_agent( 134 | args.target_type, 135 | tar_assistants, # This will be either single assistant or list for multi-agent 136 | logger=None, 137 | 
encoding=tiktoken.get_encoding("cl100k_base"), 138 | config=config 139 | ) 140 | 141 | return app_agent, tar_agent 142 | 143 | def setup_task_loggers(args, task_id): 144 | """Set up loggers for a specific task.""" 145 | pred_type = "dyn_k" if args.pred else "fix_k" 146 | log_dir = f"data/{args.approx_type}_{args.target_type}/{args.model_type}/{pred_type}" 147 | 148 | if args.pred: 149 | base_path = f'{log_dir}/tau_{args.tau}_offset_{args.offset}' 150 | else: 151 | base_path = f'{log_dir}/k_{args.k}' 152 | 153 | logger = Logger(f'{base_path}/simulation_datapoint{task_id}.log', on=True) 154 | target_logger = Logger(f'{base_path}/target_datapoint{task_id}.log', on=True) 155 | approximation_logger = Logger(f'{base_path}/approximation_datapoint{task_id}.log', on=True) 156 | 157 | return logger, target_logger, approximation_logger 158 | 159 | def log_task_metrics(logger, step_num, config): 160 | """Log metrics for the completed task.""" 161 | # Calculate basic metrics 162 | normal_plan_time = sum(config.TARGET_NORMAL_TIME[i] for i in range(1, step_num+1)) 163 | normal_app_time = sum(config.APPROX_NORMAL_TIME[i] for i in range(1, step_num+1)) 164 | normal_tar_generation = sum(config.TARGET_NORMAL_GENERATION[i] for i in range(1, step_num+1)) 165 | normal_app_generation = sum(config.APPROX_NORMAL_GENERATION[i] for i in range(1, step_num+1)) 166 | normal_plan_generation = normal_tar_generation + normal_app_generation 167 | normal_tar_prompt = sum(config.TARGET_NORMAL_PROMPT[i] for i in range(1, step_num+1)) 168 | normal_app_prompt = sum(config.APPROX_NORMAL_PROMPT[i] for i in range(1, step_num+1)) 169 | normal_plan_prompt = normal_app_prompt + normal_tar_prompt 170 | 171 | # Log token metrics 172 | logger.log('normal approx token prompt: ' + str(normal_app_prompt)) 173 | logger.log('normal approx token generation: ' + str(normal_app_generation)) 174 | logger.log('sp approx token prompt: ' + str(config.APPROX_SP_PROMPT)) 175 | logger.log('sp approx token generation: ' + 
str(config.APPROX_SP_GENERATION)) 176 | logger.log('normal target token prompt: ' + str(normal_tar_prompt)) 177 | logger.log('normal target token generation: ' + str(normal_tar_generation)) 178 | logger.log('sp target token prompt: ' + str(config.TOTAL_TOKEN_PROMPT - config.APPROX_SP_PROMPT)) 179 | logger.log('sp target token generation: ' + str(config.TOTAL_TOKEN_GENERATION - config.APPROX_SP_GENERATION)) 180 | logger.log('total sp token prompt: ' + str(config.TOTAL_TOKEN_PROMPT)) 181 | logger.log('total sp token generation: ' + str(config.TOTAL_TOKEN_GENERATION)) 182 | 183 | # Log timing metrics 184 | logger.log('normal target step time: ' + str(normal_plan_time)) 185 | logger.log('accuracy of approximation agent: ' + str(config.TOTAL_CORRECT_APPROXIMATION_CALLS/config.TOTAL_APPROXIMATION_CALLS)) 186 | 187 | # Calculate and log averages 188 | avg_sp_token = round(config.TOTAL_TOKEN_GENERATION/step_num, 2) 189 | avg_normal_token = round(normal_plan_generation/step_num, 2) 190 | logger.log('sp token generation/step: ' + str(avg_sp_token)) 191 | logger.log('normal token generation/step: ' + str(avg_normal_token)) 192 | logger.log('generation token cost increased: ' + str(round((avg_sp_token/avg_normal_token-1)*100, 2))+"%.") 193 | 194 | avg_sp_prompt = config.TOTAL_TOKEN_PROMPT / step_num 195 | avg_normal_prompt = round(normal_plan_prompt/step_num, 2) 196 | logger.log('sp prompt/step: ' + str(round(avg_sp_prompt, 2))) 197 | logger.log('normal prompt/step: ' + str(avg_normal_prompt)) 198 | logger.log('prompt token cost increased: ' + str(round((avg_sp_prompt/avg_normal_prompt-1)*100, 2))+"%.") 199 | 200 | avg_sp_time = round(config.TOTAL_SP_TIME/step_num, 2) 201 | avg_normal_time = round(normal_plan_time/step_num, 2) 202 | avg_approx_time = round(normal_app_time/step_num, 2) 203 | logger.log('sp time/step: ' + str(avg_sp_time)) 204 | logger.log('normal target time/step: ' + str(avg_normal_time)) 205 | logger.log('normal approx time/step: ' + str(avg_approx_time)) 206 
| logger.log('time decreased by: ' + str(round((1-avg_sp_time/avg_normal_time)*100, 2))+"%.") 207 | 208 | if config.ENABLE_PRED: 209 | logger.log(f'predictor acc: {round(config.PREDICT_CORRECT / config.PREDICT_TOTAL, 2)}') 210 | logger.log(f'step number: {step_num}') 211 | 212 | def run_task(args, task_id, executor, encoding, planner, traj_file): 213 | """Run a single task with the speculative planner.""" 214 | # Reset config metrics for new task 215 | config.reset_task_metrics() 216 | 217 | # Set up task-specific loggers 218 | logger, target_logger, approximation_logger = setup_task_loggers(args, task_id) 219 | 220 | # Update planner loggers - planner will handle passing loggers to agents 221 | planner.set_loggers(logger, target_logger, approximation_logger) 222 | 223 | # Load and set up task 224 | tasks = load_data(args.data) 225 | task_description = tasks[task_id] 226 | executor.set_initial_task_prompt(task_description) 227 | 228 | # Log task description through planner's loggers 229 | planner.log_task_description(task_description) 230 | 231 | # Create initial prompt and run planner 232 | prompt = create_prompt(task_description) 233 | steps = asyncio.run(_run_planner(planner, args, prompt)) 234 | 235 | # Log results and metrics through planner 236 | planner.log_results(steps) 237 | log_task_metrics(planner.logger, len(steps), config) 238 | 239 | # Save trajectory 240 | if not args.pred: 241 | task_traj_file = traj_file.format(task_id) 242 | planner.logger.log(f'k = {args.k}') 243 | else: 244 | task_traj_file = traj_file 245 | planner.logger.log('dynamic k') 246 | 247 | executor.save_trajectory(task_traj_file, config.ENABLE_TRAIN) 248 | 249 | def main(): 250 | """Main entry point.""" 251 | # Set up environment 252 | os.environ['DEEPSEEK_API_KEY'] = "" 253 | os.environ['OPENAI_API_KEY'] = "" 254 | encoding = tiktoken.get_encoding("cl100k_base") 255 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 256 | 257 | # Parse arguments and initialize 
config 258 | args = parse_args() 259 | config.initialize_from_args(args) 260 | 261 | # Set random seed 262 | random.seed(2) 263 | 264 | # Set up executor 265 | executor, traj_file = setup_executor(args, device) 266 | 267 | # Set up agents without loggers - planner will handle logger assignment 268 | app_agent, tar_agent = setup_agents(args) 269 | 270 | # Initialize the planner that will be reused across tasks 271 | planner = SpeculativePlanner( 272 | collector=executor.collector, 273 | encoding=encoding, 274 | executor=executor, 275 | app_agent=app_agent, 276 | tar_agent=tar_agent, 277 | config=config 278 | ) 279 | 280 | # Run tasks 281 | task_ids = list(range(args.s_task, 313)) 282 | warmup_task = 0 283 | for task_id in task_ids: 284 | if warmup_task >= config.WARMUP: 285 | config.ENABLE_PRED = args.pred 286 | run_task(args, task_id, executor, encoding, planner, traj_file) 287 | warmup_task += 1 288 | 289 | if __name__ == '__main__': 290 | main() -------------------------------------------------------------------------------- /travelplanner_supplement/runner.py: -------------------------------------------------------------------------------- 1 | """Runner for the travel planner speculative planning system.""" 2 | 3 | import os 4 | import argparse 5 | import torch 6 | import asyncio 7 | from datasets import load_dataset 8 | 9 | from transformers import AutoModel, AutoTokenizer 10 | import tiktoken 11 | 12 | from config import Config 13 | from planner import TravelPlannerSpeculativePlanner 14 | from predictor import DistilBERTValueFunction 15 | from async_online_utils import OnlineLearningExecutor 16 | from util import Logger 17 | from agents.tool_agents_sp import DirectAgent, ReactAgent, CoTAgent, MultiAgent 18 | 19 | config = Config() 20 | 21 | def parse_args(): 22 | parser = argparse.ArgumentParser(description='Travel Planner Speculative Planning') 23 | parser.add_argument('--set_type', type=str, default='validation', help='validation or test') 24 | 
parser.add_argument('--k', type=int, default=4, help='number of approximation steps to generate everytime') 25 | parser.add_argument('--approx_type', type=str, default='direct', help='cot, direct') 26 | parser.add_argument('--target_type', type=str, default='react', help='react, multi_agent') 27 | parser.add_argument('--model_type', type=str, default="gpt-4.1-mini", help='gpt-4.1-mini, deepseek-chat') 28 | parser.add_argument('--pred', action='store_true', help='enable speculative planning with predictor') 29 | parser.add_argument('--no-pred', dest='pred', action='store_false', help='disable speculative planning with predictor') 30 | parser.set_defaults(pred=True) 31 | 32 | # Online learning parameters 33 | parser.add_argument('--lr', type=float, default=1e-5, help='online learning lr') 34 | parser.add_argument('--ep', type=int, default=3, help='online learning epoch per train') 35 | parser.add_argument('--bf', type=int, default=2500, help='online learning buffer size') 36 | parser.add_argument('--bs', type=int, default=16, help='online learning batch size') 37 | parser.add_argument('--gma', type=float, default=1, help='online learning gamma for lambda return calculation') 38 | parser.add_argument('--lmd', type=float, default=0.95, help='online learning lambda for lambda return calculation') 39 | parser.add_argument('--load', dest='load', action='store_true', help='load previous trajectory and model') 40 | parser.add_argument('--no-load', dest='load', action='store_false', help='do not load previous trajectory and model') 41 | parser.add_argument('--tau', type=float, default=0.5, help='expectile loss tau') 42 | parser.add_argument('--s_task', type=int, default=1, help='start task id') 43 | parser.add_argument('--freq', type=int, default=1, help='online learning training frequency') 44 | parser.add_argument('--offset', type=int, default=0, help='biased inference offset for k') 45 | 46 | parser.set_defaults(load=False) 47 | return parser.parse_args() 48 | 49 | def 
setup_executor(args, device): 50 | """Set up the online learning executor.""" 51 | if args.pred: # dyn k 52 | traj_dir = f"../trajectory/travel_planner/online_traj/{args.approx_type}_{args.target_type}/{args.model_type}" 53 | os.makedirs(traj_dir, exist_ok=True) 54 | traj_file = f"{traj_dir}/tau_{args.tau}_offset_{args.offset}.ndjson" 55 | else: # fix k 56 | traj_dir = f"../trajectory/travel_planner/{args.approx_type}_{args.target_type}/{args.model_type}/fix_k_{args.k}" 57 | os.makedirs(traj_dir, exist_ok=True) 58 | traj_file = f"{traj_dir}/task_{{}}.json" 59 | 60 | ckpt_dir = f"../ckpt/travel_planner/online/{args.approx_type}_{args.target_type}/{args.model_type}" 61 | os.makedirs(ckpt_dir, exist_ok=True) 62 | ckpt_path = f"{ckpt_dir}/tau_{args.tau}_offset_{args.offset}.pth" 63 | 64 | # Set up model 65 | model_path = "distilbert-base-uncased" 66 | bert_model = AutoModel.from_pretrained(model_path) 67 | tokenizer = AutoTokenizer.from_pretrained(model_path) 68 | model = DistilBERTValueFunction(bert_model).to(device) 69 | 70 | # Create executor 71 | return OnlineLearningExecutor( 72 | device=device, 73 | model_save_path=ckpt_path, 74 | model=model, 75 | tokenizer=tokenizer, 76 | buffer_size=args.bf, 77 | batch_size_steps=args.bs, 78 | lambda_=args.lmd, 79 | gamma=args.gma, 80 | lr=args.lr, 81 | epoch_per_train=args.ep, 82 | load=args.load, 83 | traj_file=traj_file, 84 | tau=args.tau 85 | ), traj_file 86 | 87 | def setup_agents(args): 88 | """Set up approximation and target agents with proper types.""" 89 | tools_list = [ 90 | "notebook", 91 | "flights", 92 | "attractions", 93 | "accommodations", 94 | "restaurants", 95 | "googleDistanceMatrix", 96 | "planner", 97 | "cities", 98 | ] 99 | 100 | app_model_type = args.model_type 101 | tar_model_type = args.model_type if args.model_type == "gpt-4.1-mini" else "deepseek-reasoner" 102 | 103 | # setup target agent 104 | if args.target_type == "react": 105 | target_agent = ReactAgent( 106 | None, 107 | tools=tools_list, 108 | 
max_steps=config.MAX_STEP, 109 | react_llm_name=tar_model_type, 110 | planner_llm_name=tar_model_type, 111 | ) 112 | elif args.target_type == "multi_agent": 113 | target_agent = MultiAgent( 114 | None, 115 | tools=tools_list, 116 | max_steps=config.MAX_STEP, 117 | react_llm_name=tar_model_type, 118 | planner_llm_name=tar_model_type, 119 | ) 120 | else: 121 | target_agent = None 122 | config.TARGET_TYPE = args.target_type 123 | 124 | # setup approximation agent 125 | if args.approx_type == "direct": 126 | approximation_agent = DirectAgent( 127 | None, 128 | tools=tools_list, 129 | max_steps=config.MAX_STEP, 130 | react_llm_name=app_model_type, 131 | planner_llm_name=app_model_type, 132 | ) 133 | elif args.approx_type == "cot": 134 | approximation_agent = CoTAgent( 135 | None, 136 | tools=tools_list, 137 | max_steps=config.MAX_STEP, 138 | react_llm_name=app_model_type, 139 | planner_llm_name=app_model_type, 140 | ) 141 | else: 142 | approximation_agent = None 143 | 144 | return approximation_agent, target_agent 145 | 146 | def setup_task_loggers(args, task_id): 147 | """Set up loggers for a specific task.""" 148 | pred_type = "dyn_k" if args.pred else "fix_k" 149 | log_dir = f"../data/travel_planner/{args.approx_type}_{args.target_type}/{args.model_type}/{pred_type}" 150 | 151 | if args.pred: 152 | base_path = f'{log_dir}/tau_{args.tau}_offset_{args.offset}' 153 | else: 154 | base_path = f'{log_dir}/k_{args.k}' 155 | 156 | logger = Logger(f'{base_path}/simulation_datapoint{task_id}.log', on=True) 157 | target_logger = Logger(f'{base_path}/target_datapoint{task_id}.log', on=True) 158 | approximation_logger = Logger(f'{base_path}/approximation_datapoint{task_id}.log', on=True) 159 | 160 | return logger, target_logger, approximation_logger 161 | 162 | def log_task_metrics(logger, step_num, config): 163 | """Log metrics for the completed task.""" 164 | # Calculate basic metrics 165 | normal_plan_time = sum(config.TARGET_NORMAL_TIME[i] for i in range(1, step_num+1)) 166 | 
normal_app_time = sum(config.APPROX_NROMAL_TIME[i] for i in range(1, step_num+1)) 167 | normal_tar_generation = sum(config.TARGET_NORMAL_GENERATION[i] for i in range(1, step_num+1)) 168 | normal_app_generation = sum(config.APPROX_NORMAL_GENERATION[i] for i in range(1, step_num+1)) 169 | normal_plan_generation = normal_tar_generation + normal_app_generation 170 | normal_tar_prompt = sum(config.TARGET_NORMAL_PROMPT[i] for i in range(1, step_num+1)) 171 | normal_app_prompt = sum(config.APPROX_NORMAL_PROMPT[i] for i in range(1, step_num+1)) 172 | normal_plan_prompt = normal_app_prompt + normal_tar_prompt 173 | 174 | # Log token metrics 175 | logger.log('normal approx token prompt: ' + str(normal_app_prompt)) 176 | logger.log('normal approx token generation: ' + str(normal_app_generation)) 177 | logger.log('sp approx token prompt: ' + str(config.APPROX_SP_PROMPT)) 178 | logger.log('sp approx token generation: ' + str(config.APPROX_SP_GENERATION)) 179 | logger.log('normal target token prompt: ' + str(normal_tar_prompt)) 180 | logger.log('normal target token generation: ' + str(normal_tar_generation)) 181 | logger.log('sp target token prompt: ' + str(config.TOTAL_TOKEN_PROMPT - config.APPROX_SP_PROMPT)) 182 | logger.log('sp target token generation: ' + str(config.TOTAL_TOKEN_GENERATION - config.APPROX_SP_GENERATION)) 183 | logger.log('total sp token prompt: ' + str(config.TOTAL_TOKEN_PROMPT)) 184 | logger.log('total sp token generation: ' + str(config.TOTAL_TOKEN_GENERATION)) 185 | 186 | # Log timing metrics 187 | logger.log('normal target step time: ' + str(normal_plan_time)) 188 | logger.log('accuracy of approximation agent: ' + str(config.TOTAL_CORRECT_APPROXIMATION_CALLS/config.TOTAL_APPROXIMATION_CALLS)) 189 | 190 | # Calculate and log averages 191 | avg_sp_token = round(config.TOTAL_TOKEN_GENERATION/step_num, 2) 192 | avg_normal_token = round(normal_plan_generation/step_num, 2) 193 | logger.log('sp token generation/step: ' + str(avg_sp_token)) 194 | 
logger.log('normal token generation/step: ' + str(avg_normal_token)) 195 | logger.log('generation token cost increased: ' + str(round((avg_sp_token/avg_normal_token-1)*100, 2))+"%.") 196 | 197 | avg_sp_prompt = config.TOTAL_TOKEN_PROMPT / step_num 198 | avg_normal_prompt = round(normal_plan_prompt/step_num, 2) 199 | logger.log('sp prompt/step: ' + str(round(avg_sp_prompt, 2))) 200 | logger.log('normal prompt/step: ' + str(avg_normal_prompt)) 201 | logger.log('prompt token cost increased: ' + str(round((avg_sp_prompt/avg_normal_prompt-1)*100, 2))+"%.") 202 | 203 | avg_sp_time = round(config.TOTAL_SP_TIME/step_num, 2) 204 | logger.log('sp time/step: ' + str(avg_sp_time)) 205 | avg_normal_time = round(normal_plan_time/step_num, 2) 206 | avg_approx_time = round(normal_app_time/step_num, 2) 207 | logger.log('normal target time/step: ' + str(avg_normal_time)) 208 | logger.log('normal approx time/step: ' + str(avg_approx_time)) 209 | logger.log('time decreased by: ' + str(round((1-avg_sp_time/avg_normal_time)*100, 2))+"%.") 210 | 211 | def run_one_task(args, task_id, executor, encoding, planner, traj_file, query_data_list): 212 | query = query_data_list[task_id-1]["query"] 213 | 214 | # Set query for agents (consistent with original) 215 | planner.app_agent.query = query 216 | planner.tar_agent.query = query 217 | 218 | logger, target_logger, approximation_logger = setup_task_loggers(args, task_id) 219 | 220 | # Set up train_logger if training is enabled (consistent with original) 221 | train_logger = None 222 | if config.ENABLE_TRAIN: 223 | pred_type = "dyn_k" if args.pred else "fix_k" 224 | log_dir = f"../data/travel_planner/{args.approx_type}_{args.target_type}/{args.model_type}/{pred_type}" 225 | train_logger = Logger(f'{log_dir}/tau_{args.tau}_offset_{args.offset}/train_datapoint{task_id}.log', on=True) 226 | 227 | planner.set_loggers(logger, target_logger, approximation_logger, train_logger) 228 | planner.log_task_description(query) 229 | 230 | 
config.reset_task_metrics() 231 | 232 | # Initialize planner for new task 233 | planner.initialize_for_new_task(query) 234 | 235 | executor.set_initial_task_prompt(query) 236 | 237 | steps = asyncio.run(planner.run(args, query)) 238 | 239 | step_num = len(steps) 240 | log_task_metrics(logger, step_num, config) 241 | 242 | if args.pred: 243 | logger.log(f'predictor acc: {round(config.PREDICT_CORRECT / config.PREDICT_TOTAL, 2)}') 244 | logger.log(f'step number: {step_num}') 245 | 246 | if not args.pred: 247 | logger.log(f'k = {args.k}') 248 | traj_file = traj_file.format(task_id) 249 | else: 250 | logger.log('dynamic k') 251 | 252 | executor.save_trajectory(traj_file, config.ENABLE_TRAIN) 253 | 254 | return steps 255 | 256 | def main(): 257 | """Main function to run the travel planner speculative planning system.""" 258 | os.environ['DEEPSEEK_API_KEY'] = "" 259 | os.environ['OPENAI_API_KEY'] = "" 260 | os.environ['GOOGLE_API_KEY'] = "" 261 | 262 | args = parse_args() 263 | 264 | config.initialize_from_args(args) 265 | 266 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 267 | 268 | executor, traj_file = setup_executor(args, device) 269 | 270 | app_agent, tar_agent = setup_agents(args) 271 | 272 | encoding = tiktoken.get_encoding("cl100k_base") 273 | 274 | planner = TravelPlannerSpeculativePlanner( 275 | collector=executor.collector, 276 | encoding=encoding, 277 | executor=executor, 278 | app_agent=app_agent, 279 | tar_agent=tar_agent, 280 | config=config 281 | ) 282 | 283 | if args.set_type == "validation": 284 | query_data_list = load_dataset("osunlp/TravelPlanner", "validation")["validation"] 285 | elif args.set_type == "test": 286 | query_data_list = load_dataset("osunlp/TravelPlanner", "test")["test"] 287 | 288 | 289 | task_ids = list(range(args.s_task, 180)) 290 | 291 | warmup_task = 0 292 | for task_id in task_ids: 293 | if warmup_task >= config.WARMUP: 294 | config.ENABLE_PRED = args.pred 295 | run_one_task(args, task_id, executor, 
encoding, planner, traj_file, query_data_list) 296 | warmup_task += 1 297 | 298 | if __name__ == "__main__": 299 | main() -------------------------------------------------------------------------------- /OpenAGI/agents.py: -------------------------------------------------------------------------------- 1 | """Agent implementations for the speculative planning system.""" 2 | from typing import List, Union 3 | from autogen import AssistantAgent 4 | from .openagi_utils import parse_response 5 | import os 6 | 7 | class BaseAgent: 8 | """Base agent class with direct generation approach.""" 9 | def __init__(self, assistant: AssistantAgent, logger, encoding=None, config=None): 10 | self.assistant = assistant 11 | self.logger = logger 12 | self.encoding = encoding 13 | self.config = config 14 | 15 | async def generate(self, prompt: str, step_number: int) -> str: 16 | """Generate a response using direct approach.""" 17 | prompt += f"\n\nDirectly tell me what **the ONE NEXT action step** based on the current action trajectory should be. 
(Remember to use xml tag and for formatting.)\nWhat should be the action in Step {step_number}?\n\nStep {step_number}:" 18 | 19 | # Record prompt tokens 20 | prompt_token = len(self.encoding.encode(prompt)) 21 | self.config.TOTAL_TOKEN_PROMPT += prompt_token 22 | self.config.APPROX_SP_PROMPT += prompt_token 23 | self.config.APPROX_NORMAL_PROMPT[step_number] = prompt_token 24 | self.logger.log(f'Step {step_number} prompt token: {prompt_token}') 25 | 26 | # Generate response 27 | response = await self.assistant.a_generate_reply( 28 | messages=[{'content': prompt, 'role': 'user'}] 29 | ) 30 | 31 | # Record generation tokens 32 | response_token = len(self.encoding.encode(response)) 33 | self.config.TOTAL_TOKEN_GENERATION += response_token 34 | self.config.APPROX_SP_GENERATION += response_token 35 | self.config.APPROX_NORMAL_GENERATION[step_number] = response_token 36 | self.logger.log(f'Step {step_number} generation token: {response_token}') 37 | 38 | return parse_response(response) 39 | 40 | class ReActAgent(BaseAgent): 41 | """Agent that uses ReAct approach for generation.""" 42 | 43 | async def generate(self, prompt: str, step_number: int) -> str: 44 | """Generate response using ReAct approach.""" 45 | # First generate thought 46 | thought_prompt = prompt + "\n\nCarefully think about **the ONE NEXT action step** based on the current action trajectory." 
47 | thought_prompt += f"\nGenerate thought only.\nThought {step_number}:" 48 | 49 | self.logger.log('react launch api for thought.') 50 | 51 | # Record thought prompt tokens 52 | thought_prompt_token = len(self.encoding.encode(thought_prompt)) 53 | self.config.TARGET_NORMAL_PROMPT[step_number] = thought_prompt_token 54 | self.config.TOTAL_TOKEN_PROMPT += thought_prompt_token 55 | self.logger.log(f"Target step {step_number} thought prompt token: {thought_prompt_token}") 56 | 57 | # Generate thought 58 | thought = await self.assistant.a_generate_reply( 59 | messages=[{'content': thought_prompt, 'role': 'user'}] 60 | ) 61 | 62 | # Record thought generation tokens 63 | thought_token = len(self.encoding.encode(thought)) 64 | self.config.TOTAL_TOKEN_GENERATION += thought_token 65 | self.config.TARGET_NORMAL_GENERATION[step_number] = thought_token 66 | self.logger.log(f"Target step {step_number} thought generation token: {thought_token}") 67 | 68 | # Generate action based on thought 69 | action_prompt = thought_prompt + " " + thought 70 | action_prompt += f"\nGenerate Action only based on thoughts. Remember to use xml tag and for formatting. 
\nAction {step_number}:" 71 | 72 | self.logger.log('react launch api for response.') 73 | 74 | # Record action prompt tokens 75 | action_prompt_token = len(self.encoding.encode(action_prompt)) 76 | self.config.TARGET_NORMAL_PROMPT[step_number] += action_prompt_token 77 | self.config.TOTAL_TOKEN_PROMPT += action_prompt_token 78 | self.logger.log(f"Target step {step_number} response prompt token: {action_prompt_token}") 79 | 80 | # Generate action 81 | response = await self.assistant.a_generate_reply( 82 | messages=[{'content': action_prompt, 'role': 'user'}] 83 | ) 84 | 85 | # Record action generation tokens 86 | response_token = len(self.encoding.encode(response)) 87 | self.config.TOTAL_TOKEN_GENERATION += response_token 88 | self.config.TARGET_NORMAL_GENERATION[step_number] += response_token 89 | self.logger.log(f"Target step {step_number} response generation token: {response_token}") 90 | 91 | return parse_response(response) 92 | 93 | class CoTAgent(BaseAgent): 94 | """Agent that uses Chain of Thought approach for generation.""" 95 | 96 | async def generate(self, prompt: str, step_number: int) -> str: 97 | """Generate response using Chain of Thought approach.""" 98 | prompt += "\n\nCarefully think about **the ONE NEXT action step** based on the current action trajectory, by first providing a clear reasoning chain. And then decide which tool to use for the current step." 99 | prompt += f"\nWhat should be the action in Step {step_number}? 
Remember to use xml tag and for formatting.\nStep {step_number}:" 100 | 101 | # Record prompt tokens 102 | prompt_token = len(self.encoding.encode(prompt)) 103 | self.config.TOTAL_TOKEN_PROMPT += prompt_token 104 | self.config.APPROX_SP_PROMPT += prompt_token 105 | self.config.APPROX_NORMAL_PROMPT[step_number] = prompt_token 106 | self.logger.log(f'Step {step_number} prompt token: {prompt_token}') 107 | 108 | # Generate response 109 | response = await self.assistant.a_generate_reply( 110 | messages=[{'content': prompt, 'role': 'user'}] 111 | ) 112 | 113 | # Record generation tokens 114 | response_token = len(self.encoding.encode(response)) 115 | self.config.TOTAL_TOKEN_GENERATION += response_token 116 | self.config.APPROX_SP_GENERATION += response_token 117 | self.config.APPROX_NORMAL_GENERATION[step_number] = response_token 118 | self.logger.log(f'Step {step_number} generation token: {response_token}') 119 | 120 | return parse_response(response) 121 | 122 | class MultiAgent(BaseAgent): 123 | """Agent that uses multi-agent discussion for generation.""" 124 | 125 | def __init__(self, assistants: List[AssistantAgent], logger, encoding=None, config=None): 126 | if not isinstance(assistants, list) or len(assistants) != 2: 127 | raise ValueError("MultiAgent requires exactly 2 assistants") 128 | super().__init__(assistants[0], logger, encoding, config) 129 | self.assistant_b = assistants[1] 130 | 131 | async def generate(self, prompt: str, step_number: int) -> str: 132 | """Generate response using multi-agent discussion.""" 133 | # First round of discussion 134 | prompt4A = prompt + f"\n\nYou will discuss with another agent about **the ONE NEXT action step** based on the current action trajectory. 
Please provide your thought and answer first.\nAction {step_number}:" 135 | 136 | # Record A1 prompt tokens 137 | prompt4A_token = len(self.encoding.encode(prompt4A)) 138 | self.config.TARGET_NORMAL_PROMPT[step_number] = prompt4A_token 139 | self.config.TOTAL_TOKEN_PROMPT += prompt4A_token 140 | self.logger.log(f"Target step {step_number} thought A prompt token: {prompt4A_token}") 141 | 142 | # Generate A1 response 143 | thoughtA = await self.assistant.a_generate_reply( 144 | messages=[{'content': prompt4A, 'role': 'user'}] 145 | ) 146 | if isinstance(thoughtA, dict): 147 | thoughtA = thoughtA['content'] 148 | 149 | # Record A1 generation tokens 150 | thoughtA_token = len(self.encoding.encode(thoughtA)) 151 | self.config.TOTAL_TOKEN_GENERATION += thoughtA_token 152 | self.config.TARGET_NORMAL_GENERATION[step_number] = thoughtA_token 153 | self.logger.log(f"Target step {step_number} thought A generation token: {thoughtA_token}") 154 | 155 | # B1's turn 156 | prompt4B = prompt + f"\n\nYou are discussing with another agent about **the ONE NEXT action step** based on the current action trajectory.\nThe other agent's idea about this step is {thoughtA}.\nPlease think about whether the other agent's thought and idea is useful, and then provide your thought and answer now.\nAction {step_number}:" 157 | 158 | # Record B1 prompt tokens 159 | prompt4B_token = len(self.encoding.encode(prompt4B)) 160 | self.config.TARGET_NORMAL_PROMPT[step_number] += prompt4B_token 161 | self.config.TOTAL_TOKEN_PROMPT += prompt4B_token 162 | self.logger.log(f"Target step {step_number} thought B prompt token: {prompt4B_token}") 163 | 164 | # Generate B1 response 165 | thoughtB = await self.assistant_b.a_generate_reply( 166 | messages=[{'content': prompt4B, 'role': 'user'}] 167 | ) 168 | if isinstance(thoughtB, dict): 169 | thoughtB = thoughtB['content'] 170 | 171 | # Record B1 generation tokens 172 | thoughtB_token = len(self.encoding.encode(thoughtB)) 173 | self.config.TOTAL_TOKEN_GENERATION += 
thoughtB_token 174 | self.config.TARGET_NORMAL_GENERATION[step_number] += thoughtB_token 175 | self.logger.log(f"Target step {step_number} thought B generation token: {thoughtB_token}") 176 | 177 | # Second round - A2's turn 178 | prompt4A = prompt + f"\n\nYou are discussing with another agent about **the ONE NEXT action step** based on the current action trajectory.\nYour original thought and idea about this step is {thoughtA}.\nThe other agent's thought and idea about this step is {thoughtB}.\nPlease summarize and reflect, and then update your thought on what this step should be after updating.\nAction {step_number}:" 179 | 180 | # Record A2 prompt tokens 181 | prompt4A_token = len(self.encoding.encode(prompt4A)) 182 | self.config.TARGET_NORMAL_PROMPT[step_number] += prompt4A_token 183 | self.config.TOTAL_TOKEN_PROMPT += prompt4A_token 184 | self.logger.log(f"Target step {step_number} round 2 thought A prompt token: {prompt4A_token}") 185 | 186 | # Generate A2 response 187 | thoughtA = await self.assistant.a_generate_reply( 188 | messages=[{'content': prompt4A, 'role': 'user'}] 189 | ) 190 | if isinstance(thoughtA, dict): 191 | thoughtA = thoughtA['content'] 192 | 193 | # Record A2 generation tokens 194 | thoughtA_token = len(self.encoding.encode(thoughtA)) 195 | self.config.TOTAL_TOKEN_GENERATION += thoughtA_token 196 | self.config.TARGET_NORMAL_GENERATION[step_number] += thoughtA_token 197 | self.logger.log(f"Target step {step_number} round 2 thought A generation token: {thoughtA_token}") 198 | 199 | # B2's turn 200 | prompt4B = prompt + f"\n\nYou are discussing with another agent about **the ONE NEXT action step** based on the current action trajectory.\nYour original thought and idea about this step is {thoughtB}.\nThe other agent's thought and idea about this step is {thoughtA}.\nPlease summarize and reflect, and then update your thought on what this step should be after updating.\nAction {step_number}:" 201 | 202 | # Record B2 prompt tokens 203 | 
prompt4B_token = len(self.encoding.encode(prompt4B)) 204 | self.config.TARGET_NORMAL_PROMPT[step_number] += prompt4B_token 205 | self.config.TOTAL_TOKEN_PROMPT += prompt4B_token 206 | self.logger.log(f"Target step {step_number} round 2 thought B prompt token: {prompt4B_token}") 207 | 208 | # Generate B2 response 209 | thoughtB = await self.assistant_b.a_generate_reply( 210 | messages=[{'content': prompt4B, 'role': 'user'}] 211 | ) 212 | if isinstance(thoughtB, dict): 213 | thoughtB = thoughtB['content'] 214 | 215 | # Record B2 generation tokens 216 | thoughtB_token = len(self.encoding.encode(thoughtB)) 217 | self.config.TOTAL_TOKEN_GENERATION += thoughtB_token 218 | self.config.TARGET_NORMAL_GENERATION[step_number] += thoughtB_token 219 | self.logger.log(f"Target step {step_number} round 2 thought B generation token: {thoughtB_token}") 220 | 221 | # Final round - Generate action 222 | final_prompt = prompt + f"\n\nYou are discussing with another agent about **the ONE NEXT action step** based on the current action trajectory.\nYour original thought and idea about this step is {thoughtA}.\nThe other agent's thought and idea about this step is {thoughtB}.\nPlease summarize and reflect, and then provide your thought and answer for what this step should be.\nAction {step_number}:" 223 | 224 | # Record final prompt tokens 225 | final_prompt_token = len(self.encoding.encode(final_prompt)) 226 | self.config.TARGET_NORMAL_PROMPT[step_number] += final_prompt_token 227 | self.config.TOTAL_TOKEN_PROMPT += final_prompt_token 228 | self.logger.log(f"Target step {step_number} final prompt token: {final_prompt_token}") 229 | 230 | # Generate final response 231 | response = await self.assistant.a_generate_reply( 232 | messages=[{'content': final_prompt, 'role': 'user'}] 233 | ) 234 | if isinstance(response, dict): 235 | response = response['content'] 236 | 237 | # Record final generation tokens 238 | response_token = len(self.encoding.encode(response)) 239 | 
self.config.TOTAL_TOKEN_GENERATION += response_token 240 | self.config.TARGET_NORMAL_GENERATION[step_number] += response_token 241 | self.logger.log(f"Target step {step_number} final generation token: {response_token}") 242 | 243 | return parse_response(response) 244 | 245 | def create_agent(agent_type: str, assistant: Union[AssistantAgent, List[AssistantAgent]], logger, encoding=None, config=None) -> BaseAgent: 246 | """Factory function to create appropriate agent type.""" 247 | if agent_type == 'react': 248 | return ReActAgent(assistant, logger, encoding, config) 249 | elif agent_type == 'multi_agent': 250 | if not isinstance(assistant, list): 251 | raise ValueError("MultiAgent requires a list of assistants") 252 | return MultiAgent(assistant, logger, encoding, config) 253 | elif agent_type == 'cot': 254 | return CoTAgent(assistant, logger, encoding, config) 255 | else: # direct 256 | return BaseAgent(assistant, logger, encoding, config) 257 | 258 | def create_agent_config(model_name, api_key, api_type="openai", base_url=None): 259 | """Create agent configuration.""" 260 | config = { 261 | "model": model_name, 262 | "api_key": api_key, 263 | "api_type": api_type, 264 | "cache_seed": None, 265 | "temperature": 0, 266 | "top_p": 1.0, 267 | "seed": 0 268 | } 269 | 270 | if base_url: 271 | config["base_url"] = base_url 272 | 273 | return [config] 274 | 275 | def setup_assistants(model_type): 276 | """Set up base approximation and target agents.""" 277 | if model_type == "deepseek": 278 | app_config = create_agent_config( 279 | "deepseek-chat", 280 | os.environ['DEEPSEEK_API_KEY'], 281 | api_type="deepseek", 282 | base_url="https://api.deepseek.com/" 283 | ) 284 | tar_config = create_agent_config( 285 | "deepseek-reasoner", 286 | os.environ['DEEPSEEK_API_KEY'], 287 | api_type="deepseek", 288 | base_url="https://api.deepseek.com/" 289 | ) 290 | else: 291 | # For OpenAI models 292 | app_config = create_agent_config(model_type, os.environ['OPENAI_API_KEY']) 293 | 
tar_config = create_agent_config(model_type, os.environ['OPENAI_API_KEY']) 294 | 295 | app_assistant = AssistantAgent( 296 | "assistant", 297 | llm_config={"config_list": app_config}, 298 | human_input_mode='NEVER' 299 | ) 300 | 301 | tar_assistant = AssistantAgent( 302 | "assistant", 303 | llm_config={"config_list": tar_config}, 304 | human_input_mode='NEVER' 305 | ) 306 | 307 | return app_assistant, tar_assistant 308 | 309 | def setup_multi_agent(model_type, tar_assistant): 310 | """Set up multi-agent configuration for target agent.""" 311 | if model_type == "deepseek": 312 | tar_config = create_agent_config( 313 | "deepseek-reasoner", 314 | os.environ['DEEPSEEK_API_KEY'], 315 | api_type="deepseek", 316 | base_url="https://api.deepseek.com" 317 | ) 318 | else: 319 | # For OpenAI models 320 | tar_config = create_agent_config(model_type, os.environ['OPENAI_API_KEY']) 321 | 322 | tar_assistantB = AssistantAgent( 323 | "assistant", 324 | llm_config={"config_list": tar_config}, 325 | human_input_mode='NEVER' 326 | ) 327 | 328 | return [tar_assistant, tar_assistantB] -------------------------------------------------------------------------------- /OpenAGI/async_online_utils.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from torch.utils.data import Dataset, DataLoader 3 | from collections import deque 4 | import torch 5 | import random 6 | from typing import List, Tuple 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | from concurrent.futures import ThreadPoolExecutor 10 | import asyncio 11 | import copy 12 | import threading 13 | from threading import Event 14 | import time 15 | import json 16 | import hashlib 17 | import os 18 | 19 | class ExpectileLoss(nn.Module): 20 | def __init__(self, tau=0.5): 21 | super().__init__() 22 | self.tau = tau 23 | 24 | def forward(self, G_t, pred_k): 25 | diff = G_t - pred_k 26 | weight = torch.where(diff > 0, self.tau, (1 - self.tau)) 27 | return 
torch.mean(weight * diff ** 2) 28 | 29 | class OnlineLearningExecutor: 30 | def __init__(self, 31 | device, 32 | model_save_path, 33 | model: nn.Module, 34 | tokenizer, 35 | buffer_size=2500, 36 | batch_size_steps=32, 37 | max_length=512, 38 | lambda_=0.95, 39 | gamma=1, 40 | lr=1e-5, 41 | epoch_per_train=3, 42 | load=False, 43 | traj_file=None, 44 | tau = 0.5 45 | ): 46 | self.device = device 47 | self.tokenizer = tokenizer 48 | self.buffer = FiniteReplay(tokenizer=tokenizer, model=model, max_length=max_length, replay_size=buffer_size) 49 | self.collector = OnlineTrajectoryCollector(self.buffer) 50 | self.batch_size_steps = batch_size_steps 51 | if tau == 0.5: 52 | self.criterion = nn.MSELoss() 53 | else: self.criterion = ExpectileLoss(tau) 54 | 55 | self.max_length = max_length 56 | self.lambda_ = lambda_ 57 | self.gamma = gamma 58 | self.epoch_per_train = epoch_per_train 59 | 60 | self.train_model = model.to(device) 61 | self.predict_model = copy.deepcopy(model).to(device) 62 | self.train_executor = ThreadPoolExecutor(max_workers=1) 63 | self.optimizer = optim.AdamW(self.train_model.parameters(), lr=lr) 64 | self.train_lock = asyncio.Lock() 65 | self.model_lock = threading.Lock() 66 | self.model_save_path = model_save_path 67 | if load: 68 | self.load_checkpoint(model_save_path) 69 | self.load_from_file(traj_file) 70 | 71 | def load_checkpoint(self, file_path): 72 | checkpoint = torch.load(file_path, map_location=self.device) 73 | self.train_model.load_state_dict(checkpoint["model_state_dict"]) 74 | self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) 75 | self.predict_model.load_state_dict(self.train_model.state_dict()) 76 | print("load weights from last breakingpoint") 77 | 78 | def _compute_lambda_return(self, input_ids, attention_masks, rewards): 79 | self.train_model.eval() 80 | T = len(input_ids) 81 | G_lambda = torch.zeros(T) 82 | 83 | G_t = 0 84 | for t in reversed(range(T)): 85 | input_id = input_ids[t].unsqueeze(0) 86 | attention_mask = 
attention_masks[t].unsqueeze(0) 87 | v_pred = self.train_model(input_id, attention_mask).item() 88 | mask = rewards[t] 89 | G_t = rewards[t] + self.gamma * (1 - self.lambda_) * v_pred * mask + self.gamma * self.lambda_ * G_t * mask 90 | G_lambda[t] = G_t 91 | self.train_model.train() 92 | return G_lambda 93 | 94 | def _flat_batch_data(self, input_ids_batch, attention_mask_batch, rewards_batch, gt_k_batch): 95 | all_G_lambda = [] 96 | all_input_ids = [] 97 | all_attention_mask = [] 98 | all_gt_k = [] 99 | # compute current G_lambda for batch data and flatten steps 100 | for i in range(len(input_ids_batch)): 101 | input_ids = input_ids_batch[i] 102 | attention_mask = attention_mask_batch[i] 103 | rewards = rewards_batch[i] 104 | 105 | with torch.no_grad(): 106 | # G_lambda for each trajectory 107 | G_lambda = self._compute_lambda_return(input_ids.to(self.device), attention_mask.to(self.device), rewards.to(self.device)) 108 | # flatten inputs and targets 109 | all_G_lambda.extend(G_lambda) 110 | all_input_ids.extend(input_ids) 111 | all_attention_mask.extend(attention_mask) 112 | all_gt_k.extend(gt_k_batch[i]) 113 | 114 | 115 | flat_input_ids = torch.stack(all_input_ids, dim=0).to(self.device) 116 | flat_attention_mask = torch.stack(all_attention_mask, dim=0).to(self.device) 117 | flat_G_lambda = torch.tensor(all_G_lambda).to(self.device) 118 | flat_gt_k = torch.stack(all_gt_k, dim=0).to(self.device) 119 | return flat_input_ids, flat_attention_mask, flat_G_lambda, flat_gt_k 120 | 121 | async def async_train(self): 122 | async with self.train_lock: 123 | loop = asyncio.get_running_loop() 124 | current_train_task = loop.run_in_executor( 125 | self.train_executor, 126 | self._train 127 | ) 128 | try: 129 | await current_train_task 130 | except Exception as e: 131 | pass 132 | 133 | def _train(self): 134 | batch = self.buffer.sample(self.batch_size_steps) 135 | for epoch in range(self.epoch_per_train): 136 | self.train_model.train() 137 | 138 | if batch is None: 139 | 
return 140 | input_ids_batch, attention_mask_batch, rewards_batch, gt_k_batch = batch 141 | 142 | flat_input_ids, flat_attention_mask, flat_G_lambda, flat_gt_k = self._flat_batch_data( 143 | input_ids_batch, attention_mask_batch, rewards_batch, gt_k_batch) 144 | 145 | self.optimizer.zero_grad() 146 | k_pred = self.train_model(flat_input_ids, flat_attention_mask).squeeze() 147 | loss = self.criterion(flat_G_lambda, k_pred) 148 | 149 | loss.backward() 150 | self.optimizer.step() 151 | 152 | with self.model_lock: 153 | self.predict_model.load_state_dict(self.train_model.state_dict()) 154 | checkpoint = { 155 | "model_state_dict": self.train_model.state_dict(), 156 | "optimizer_state_dict": self.optimizer.state_dict(), 157 | } 158 | torch.save(checkpoint, self.model_save_path) 159 | 160 | 161 | async def async_predict(self, logger, k_offset): 162 | loop = asyncio.get_running_loop() 163 | return await loop.run_in_executor( 164 | None, 165 | self._predict, 166 | logger, 167 | k_offset 168 | ) 169 | 170 | def _predict(self, logger, k_offset): 171 | self.predict_model.eval() 172 | 173 | state = self.collector.get_current_trajectory() 174 | with self.model_lock: 175 | 176 | inputs = self.tokenizer( 177 | state, 178 | return_tensors="pt", 179 | truncation=True, 180 | padding='max_length', 181 | max_length=self.max_length 182 | ).to(self.device) 183 | k_pred = self.predict_model(**inputs).squeeze().cpu() 184 | k = int(torch.round(k_pred).item()) + k_offset 185 | logger.log(f'Predict K: {k}') 186 | return k 187 | 188 | def set_initial_task_prompt(self, initial_task): 189 | self.collector.set_initial_task_prompt(initial_task) 190 | 191 | def save_trajectory(self, file_path, append): 192 | self.collector.save_trajectory(file_path, append) 193 | 194 | def load_from_file(self, file_path): 195 | self.buffer.load_from_file(file_path) 196 | 197 | 198 | class OnlineTrajectoryCollector: 199 | def __init__(self, replay_buffer: "FiniteReplay"): 200 | self.traj_list = [] 201 | 
self.reset() 202 | self.replay_buffer = replay_buffer 203 | 204 | def reset(self): 205 | self.approx_logs = [] # approximation_steps in current breakingpoint 206 | self.target_logs = [] # target steps in current breakingpoint 207 | 208 | def set_initial_task_prompt(self, initial_task): 209 | self.initial_task = initial_task 210 | self.prefix = initial_task 211 | 212 | def _reset_trajectory(self): 213 | self.approx_logs = [] 214 | self.target_logs = [] 215 | 216 | def record_step(self, 217 | timestamp: datetime, 218 | source: str, # "approximation" or "target" 219 | step: int, 220 | description: str = None): 221 | 222 | if source.strip() == "Approximation": 223 | self.approx_logs.append((timestamp, source.strip(), step, description)) 224 | else: 225 | self.target_logs.append((timestamp, source.strip(), step, description)) 226 | 227 | def build_trajectory(self, logger, predict_ks): 228 | # when reach a breakingpoint, call build_trajectory 229 | if len(self.approx_logs) == 0: 230 | return 0 231 | # combined_logs = sorted(self.approx_logs + self.target_logs, key=lambda x: (x[0], x[2], 0 if x[1] == "Approximation" else 1)) 232 | target_logs_sorted_by_step = sorted(self.target_logs, key=lambda x: x[2]) 233 | i, j = 0, 0 234 | trajectory = [] 235 | gt_k = 0 236 | 237 | while j < len(target_logs_sorted_by_step) and i < len(self.approx_logs): 238 | target_timestamp, _, target_step, target_desc = target_logs_sorted_by_step[j] 239 | approx_timestamp, _, approx_step, approx_desc = self.approx_logs[i] 240 | if approx_desc == target_desc and j != len(target_logs_sorted_by_step) - 1: 241 | i += 1 242 | j += 1 243 | else: # mismatch, breakingpoint 244 | # iterate thru all approx tasks in this breakingpoint 245 | # generate traj from bp start to cur_approx_task(not included) 246 | bp_state = [ 247 | f"Step {log[2]}: {log[3]}" 248 | for log in self.approx_logs[:i] 249 | ] + [f"Step {target_step}: {target_desc}"] 250 | approx_index = 0 251 | while approx_index < 
len(self.approx_logs): 252 | cur_approx_timestamp, _, cur_approx_step, cur_approx_desc = self.approx_logs[approx_index] 253 | current_k = max(0, target_step - (cur_approx_step - 1)) 254 | gt_k = max(gt_k, current_k) 255 | state = [] 256 | 257 | for log in self.approx_logs[:approx_index]: 258 | state.append(f"Step {log[2]}: {log[3]}") 259 | 260 | reward = 1 if current_k > 0 else 0 261 | # construct prompt 262 | if state: 263 | if self.prefix == self.initial_task: 264 | self.prefix += "\nPrevious History:\n" 265 | prompt = self.prefix + "\n".join(state) + "\n" 266 | else: 267 | prompt = self.prefix 268 | trajectory.append({ 269 | "state": prompt, 270 | "reward": reward, 271 | "k": current_k 272 | }) 273 | approx_index += 1 274 | break 275 | # update prefix 276 | self.prefix += "\n".join(bp_state) + "\n" 277 | # add trajectory to replay buffer 278 | self.traj_list.append({"trajectory": trajectory}) 279 | self.replay_buffer.add_trajectory(trajectory) 280 | self._reset_trajectory() 281 | 282 | if len(predict_ks) == 1: 283 | return 1 if (predict_ks[0] == gt_k or predict_ks[0] == gt_k+1) else 0 284 | else: 285 | acc = 0 286 | for i in range(len(predict_ks)): 287 | if predict_ks[i] == gt_k or predict_ks[i] == gt_k+1: 288 | acc += 1 289 | gt_k -= predict_ks[i] 290 | return acc 291 | 292 | def get_current_trajectory(self): 293 | bp_state = [f"Step {log[2]}: {log[3]}" for log in self.approx_logs] 294 | if self.prefix == self.initial_task and bp_state: 295 | prefix = self.prefix + "\nPrevious History:\n" 296 | else: prefix = self.prefix 297 | prefix += "\n".join(bp_state) + "\n" 298 | 299 | return prefix 300 | 301 | def save_trajectory(self, file_path, append): 302 | if append: 303 | with open(file_path, 'a', encoding='utf-8') as f: 304 | for traj in self.traj_list: 305 | f.write(json.dumps(traj) + '\n') 306 | else: 307 | with open(file_path, 'w', encoding='utf-8') as f: 308 | json.dump(self.traj_list, f, indent=2) 309 | self.traj_list = [] 310 | 311 | class FiniteReplay: 312 
| def __init__(self, tokenizer, model, max_length=512, replay_size: int = 100): 313 | self.tokenizer = tokenizer 314 | self.model = model 315 | self.max_length = max_length 316 | self.replay_size = replay_size 317 | 318 | self.replay_buffer: List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = [None] * replay_size 319 | self.trajectory_lengths: List[int] = [0] * replay_size 320 | self.pos = 0 321 | self.full = False 322 | 323 | def __len__(self) -> int: 324 | return self.replay_size if self.full else self.pos 325 | 326 | def load_from_file(self, file_path: str): 327 | if not os.path.exists(file_path): 328 | raise FileNotFoundError(f"{file_path} not found.") 329 | 330 | if file_path.endswith(".json"): 331 | with open(file_path, "r", encoding="utf-8") as f: 332 | data = json.load(f) 333 | if isinstance(data, list): 334 | for traj_entry in data: 335 | self.add_trajectory(traj_entry["trajectory"]) 336 | elif file_path.endswith(".ndjson"): 337 | with open(file_path, 'r') as f: 338 | for line in f: 339 | traj = json.loads(line) 340 | self.add_trajectory(traj['trajectory']) 341 | else: 342 | raise ValueError("Unsupported file format. 
Use .json or .ndjson") 343 | 344 | def add_trajectory(self, trajectory: List[dict]): 345 | processed = self._preprocess_trajectory(trajectory) 346 | self.replay_buffer[self.pos] = processed 347 | self.trajectory_lengths[self.pos] = len(trajectory) 348 | self.pos = (self.pos + 1) % self.replay_size 349 | if self.pos == 0: 350 | self.full = True 351 | 352 | def _preprocess_trajectory(self, trajectory: List[dict]): 353 | states = [step["state"] for step in trajectory] 354 | rewards = [step["reward"] for step in trajectory] 355 | gt_ks = [step["k"] for step in trajectory] 356 | 357 | inputs = self.tokenizer( 358 | states, 359 | return_tensors="pt", 360 | truncation=True, 361 | padding='max_length', 362 | max_length=self.max_length 363 | ) 364 | input_ids = inputs['input_ids'] 365 | attention_masks = inputs['attention_mask'] 366 | rewards = torch.tensor(rewards, dtype=torch.float32) 367 | gt_ks = torch.tensor(gt_ks, dtype=torch.float32) 368 | 369 | return (input_ids, attention_masks, rewards, gt_ks) 370 | 371 | def sample(self, batch_size_steps: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 372 | indices = [] 373 | total_steps = 0 374 | available_indices = list(range(len(self))) 375 | random.shuffle(available_indices) 376 | 377 | input_ids_list = [] 378 | attention_mask_list = [] 379 | rewards_list = [] 380 | gt_k_list = [] 381 | total_steps = 0 382 | 383 | chosen_idx = [] 384 | for idx in available_indices: 385 | input_ids, attention_mask, rewards, gt_ks = self.replay_buffer[idx] 386 | input_ids_list.append(input_ids) 387 | attention_mask_list.append(attention_mask) 388 | rewards_list.append(rewards) 389 | gt_k_list.append(gt_ks) 390 | total_steps += self.trajectory_lengths[idx] 391 | chosen_idx.append(idx) 392 | # use all datapoint to train 393 | if total_steps >= batch_size_steps: 394 | break 395 | if total_steps < batch_size_steps: 396 | return None 397 | 398 | return input_ids_list, attention_mask_list, rewards_list, gt_k_list 399 | 400 | class 
SharedState: 401 | def __init__(self): 402 | mismatch_detected = None 403 | mismatch_step_id = None 404 | 405 | async def initialize(self): 406 | self.mismatch_detected = asyncio.Event() -------------------------------------------------------------------------------- /travelplanner_supplement/async_online_utils.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | import torch 3 | import random 4 | from typing import List, Tuple 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | from concurrent.futures import ThreadPoolExecutor 8 | import asyncio 9 | import copy 10 | import threading 11 | from threading import Event 12 | import time 13 | import json 14 | import os 15 | import nltk 16 | 17 | class ExpectileLoss(nn.Module): 18 | def __init__(self, tau=0.5): 19 | super().__init__() 20 | self.tau = tau 21 | 22 | def forward(self, G_t, pred_k): 23 | diff = G_t - pred_k 24 | weight = torch.where(diff > 0, self.tau, (1 - self.tau)) 25 | return torch.mean(weight * diff ** 2) 26 | 27 | class OnlineLearningExecutor: 28 | def __init__(self, 29 | device, 30 | model_save_path, 31 | model: nn.Module, 32 | tokenizer, 33 | buffer_size=200, 34 | batch_size_steps=32, 35 | max_length=512, 36 | lambda_=0.9, 37 | gamma=1, 38 | lr=5e-5, 39 | epoch_per_train=5, 40 | load=False, 41 | traj_file=None, 42 | tau = 0.5 43 | ): 44 | self.device = device 45 | self.tokenizer = tokenizer 46 | self.buffer = FiniteReplay(tokenizer=tokenizer, model=model, max_length=max_length, replay_size=buffer_size) 47 | self.collector = OnlineTrajectoryCollector(self.buffer) 48 | self.batch_size_steps = batch_size_steps 49 | if tau == 0.5: 50 | self.criterion = nn.MSELoss() 51 | else: self.criterion = ExpectileLoss(tau) 52 | self.max_length = max_length 53 | self.lambda_ = lambda_ 54 | self.gamma = gamma 55 | self.epoch_per_train = epoch_per_train 56 | 57 | self.train_model = model.to(device) 58 | self.predict_model = 
copy.deepcopy(model).to(device) 59 | self.train_executor = ThreadPoolExecutor(max_workers=1) 60 | self.optimizer = optim.AdamW(self.train_model.parameters(), lr=lr) 61 | self.train_lock = asyncio.Lock() 62 | self.model_lock = threading.Lock() 63 | self.model_save_path = model_save_path 64 | if load: 65 | self.load_checkpoint(model_save_path) 66 | self.load_from_file(traj_file) 67 | 68 | def load_checkpoint(self, file_path): 69 | checkpoint = torch.load(file_path, map_location=self.device) 70 | self.train_model.load_state_dict(checkpoint["model_state_dict"]) 71 | self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) 72 | self.predict_model.load_state_dict(self.train_model.state_dict()) 73 | print("load weights from last breakingpoint") 74 | 75 | def _compute_lambda_return(self, input_ids, attention_masks, rewards): 76 | self.train_model.eval() 77 | T = len(input_ids) 78 | G_lambda = torch.zeros(T) 79 | 80 | G_t = 0 81 | for t in reversed(range(T)): 82 | input_id = input_ids[t].unsqueeze(0) 83 | attention_mask = attention_masks[t].unsqueeze(0) 84 | v_pred = self.train_model(input_id, attention_mask).item() 85 | mask = rewards[t] 86 | G_t = rewards[t] + self.gamma * (1 - self.lambda_) * v_pred * mask + self.gamma * self.lambda_ * G_t * mask 87 | G_lambda[t] = G_t 88 | self.train_model.train() 89 | return G_lambda 90 | 91 | def _flat_batch_data(self, input_ids_batch, attention_mask_batch, rewards_batch, gt_k_batch): 92 | all_G_lambda = [] 93 | all_input_ids = [] 94 | all_attention_mask = [] 95 | all_gt_k = [] 96 | # compute current G_lambda for batch data and flatten steps 97 | for i in range(len(input_ids_batch)): 98 | # [steps_num, hid_dim] 99 | input_ids = input_ids_batch[i] 100 | attention_mask = attention_mask_batch[i] 101 | rewards = rewards_batch[i] 102 | 103 | with torch.no_grad(): 104 | # G_lambda for each trajectory 105 | G_lambda = self._compute_lambda_return(input_ids.to(self.device), attention_mask.to(self.device), rewards.to(self.device)) 106 
| # flatten inputs and targets 107 | all_G_lambda.extend(G_lambda) 108 | all_input_ids.extend(input_ids) 109 | all_attention_mask.extend(attention_mask) 110 | all_gt_k.extend(gt_k_batch[i]) 111 | 112 | 113 | # [total_step_num, hid_dim]: total_step_num in this batch 114 | flat_input_ids = torch.stack(all_input_ids, dim=0).to(self.device) 115 | flat_attention_mask = torch.stack(all_attention_mask, dim=0).to(self.device) 116 | # [total_step_num] 117 | flat_G_lambda = torch.tensor(all_G_lambda).to(self.device) 118 | flat_gt_k = torch.stack(all_gt_k, dim=0).to(self.device) 119 | return flat_input_ids, flat_attention_mask, flat_G_lambda, flat_gt_k 120 | 121 | async def async_train(self, logger): 122 | async with self.train_lock: 123 | loop = asyncio.get_running_loop() 124 | current_train_task = loop.run_in_executor( 125 | self.train_executor, 126 | self._train, 127 | logger 128 | ) 129 | try: 130 | await current_train_task 131 | except Exception as e: 132 | logger.log(f"Training failed with exception: {e}") 133 | 134 | def _train(self, logger): 135 | batch = self.buffer.sample(self.batch_size_steps, logger) 136 | for epoch in range(self.epoch_per_train): 137 | self.train_model.train() 138 | 139 | if batch is None: 140 | logger.log("Batch is None, skipping training.") 141 | return 142 | input_ids_batch, attention_mask_batch, rewards_batch, gt_k_batch = batch 143 | 144 | flat_input_ids, flat_attention_mask, flat_G_lambda, flat_gt_k = self._flat_batch_data( 145 | input_ids_batch, attention_mask_batch, rewards_batch, gt_k_batch) 146 | 147 | self.optimizer.zero_grad() 148 | k_pred = self.train_model(flat_input_ids, flat_attention_mask).squeeze() 149 | loss = self.criterion(flat_G_lambda, k_pred) 150 | diff = torch.round(k_pred) - flat_gt_k 151 | acc = ((diff == 0) | (diff == 1)).float().mean().item() 152 | logger.log(f'Epoch {epoch + 1} - Loss: {loss.item()} - Accuracy: {round(acc * 100, 2)}%') 153 | 154 | loss.backward() 155 | self.optimizer.step() 156 | 157 | with 
self.model_lock: 158 | self.predict_model.load_state_dict(self.train_model.state_dict()) 159 | checkpoint = { 160 | "model_state_dict": self.train_model.state_dict(), 161 | "optimizer_state_dict": self.optimizer.state_dict(), 162 | } 163 | torch.save(checkpoint, self.model_save_path) 164 | 165 | 166 | async def async_predict(self, logger, k_offset): 167 | loop = asyncio.get_running_loop() 168 | return await loop.run_in_executor( 169 | None, 170 | self._predict, 171 | logger, 172 | k_offset 173 | ) 174 | 175 | def _predict(self, logger, k_offset): 176 | self.predict_model.eval() 177 | 178 | state = self.collector.get_current_trajectory() 179 | logger.log(f"Predict state: {state}") 180 | with self.model_lock: 181 | inputs = self.tokenizer( 182 | state, 183 | return_tensors="pt", 184 | truncation=True, 185 | padding='max_length', 186 | max_length=self.max_length 187 | ).to(self.device) 188 | k_pred = self.predict_model(**inputs).squeeze().cpu() 189 | k = int(torch.round(k_pred).item()) + k_offset 190 | logger.log(f'Predict K: {k}') 191 | return k 192 | 193 | def set_initial_task_prompt(self, initial_task): 194 | self.collector.set_initial_task_prompt(initial_task) 195 | 196 | def save_trajectory(self, file_path, append): 197 | self.collector.save_trajectory(file_path, append) 198 | 199 | def load_from_file(self, file_path): 200 | self.buffer.load_from_file(file_path) 201 | 202 | 203 | class OnlineTrajectoryCollector: 204 | def __init__(self, replay_buffer: "FiniteReplay"): 205 | self.traj_list = [] 206 | self.reset() 207 | self.replay_buffer = replay_buffer 208 | 209 | def reset(self): 210 | self.approx_logs = [] # approximation_steps in current breakingpoint 211 | self.target_logs = [] # target steps in current breakingpoint 212 | 213 | def set_initial_task_prompt(self, initial_task): 214 | self.initial_task = initial_task 215 | self.prefix = initial_task+"\n" 216 | 217 | def _reset_trajectory(self): 218 | self.approx_logs = [] 219 | self.target_logs = [] 220 | 221 | 
def record_step(self, 222 | timestamp: datetime, 223 | source: str, # "approximation" or "target" 224 | step: int, 225 | description: str = None): 226 | 227 | if source.strip() == "Approximation": 228 | self.approx_logs.append((timestamp, source.strip(), step, description)) 229 | else: 230 | self.target_logs.append((timestamp, source.strip(), step, description)) 231 | 232 | 233 | def judge_to_be_true(self, s, t): 234 | try: 235 | approximation_function_name = s.split("[")[0].strip() 236 | target_function_name = t.split("[")[0].strip() 237 | 238 | approximation_function_arg = s[s.index("[") : s.index("]")].strip() 239 | target_function_arg = t[t.index("[") : t.index("]")].strip() 240 | 241 | def token_edit_levenstein_similarity_normalized( 242 | text1: str, text2: str 243 | ) -> float: 244 | """ 245 | Compute the normalized levenstein distance between two texts. 246 | """ 247 | return 1 - nltk.edit_distance(text1, text2) / max(len(text1), len(text2)) 248 | 249 | if approximation_function_name == target_function_name: 250 | if ( 251 | token_edit_levenstein_similarity_normalized( 252 | approximation_function_arg, target_function_arg 253 | ) 254 | > 0.5 255 | ): 256 | return True 257 | 258 | return False 259 | except: 260 | if s == t: 261 | return True 262 | else: 263 | return False 264 | 265 | def build_trajectory(self, logger, predict_ks): 266 | # start = time.time() 267 | # when reach a breakingpoint, call build_trajectory 268 | # print("Build Trajectory") 269 | if len(self.approx_logs) == 0: 270 | return 0 271 | # combined_logs = sorted(self.approx_logs + self.target_logs, key=lambda x: (x[0], x[2], 0 if x[1] == "Approximation" else 1)) 272 | target_logs_sorted_by_step = sorted(self.target_logs, key=lambda x: x[2]) 273 | i, j = 0, 0 274 | trajectory = [] 275 | gt_k = 0 276 | # logger.log(f"\napp_logs: {self.approx_logs}\n") 277 | # logger.log(f"\ntar_logs: {self.target_logs}\n") 278 | # logger.log(f"\nprefix: {self.prefix}\n") 279 | # breakpoint() 280 | 281 | while 
j < len(target_logs_sorted_by_step) and i < len(self.approx_logs): 282 | target_timestamp, _, target_step, target_desc = target_logs_sorted_by_step[j] 283 | approx_timestamp, _, approx_step, approx_desc = self.approx_logs[i] 284 | if self.judge_to_be_true(approx_desc, target_desc) and j != len(target_logs_sorted_by_step) - 1: 285 | i += 1 286 | j += 1 287 | else: # mismatch, breakingpoint 288 | # iterate thru all approx tasks in this breakingpoint 289 | # generate traj from bp start to cur_approx_task(not included) 290 | bp_state = [ 291 | f"Step {log[2]}: {log[3]}" 292 | for log in self.approx_logs[:i] 293 | ] + [f"Step {target_step}: {target_desc}"] 294 | approx_index = 0 295 | while approx_index < len(self.approx_logs): 296 | cur_approx_timestamp, _, cur_approx_step, cur_approx_desc = self.approx_logs[approx_index] 297 | current_k = max(0, target_step - (cur_approx_step - 1)) 298 | gt_k = max(gt_k, current_k) 299 | state = [] 300 | 301 | for log in self.approx_logs[:approx_index]: 302 | state.append(f"Step {log[2]}: {log[3]}") 303 | 304 | reward = 1 if current_k > 0 else 0 305 | # construct prompt 306 | if state: 307 | # if self.prefix == self.initial_task: 308 | # self.prefix += "\nPrevious History:\n" 309 | prompt = self.prefix + "\n".join(state) + "\n" 310 | else: 311 | prompt = self.prefix 312 | trajectory.append({ 313 | "state": prompt, 314 | "reward": reward, 315 | "k": current_k 316 | }) 317 | approx_index += 1 318 | break 319 | # update prefix 320 | self.prefix += "\n".join(bp_state) + "\n" 321 | logger.log(f"trajectory: {trajectory}") 322 | # add trajectory to replay buffer 323 | self.traj_list.append({"trajectory": trajectory}) 324 | self.replay_buffer.add_trajectory(trajectory) 325 | self._reset_trajectory() 326 | 327 | if len(predict_ks) == 1: 328 | return 1 if (predict_ks[0] == gt_k or predict_ks[0] == gt_k+1) else 0 329 | else: 330 | acc = 0 331 | for i in range(len(predict_ks)): 332 | if predict_ks[i] == gt_k or predict_ks[i] == gt_k+1: 333 | acc 
+= 1 334 | gt_k -= predict_ks[i] 335 | return acc 336 | 337 | def get_current_trajectory(self): 338 | bp_state = [f"Step {log[2]}: {log[3]}" for log in self.approx_logs] 339 | # if self.prefix == self.initial_task and bp_state: 340 | # prefix = self.prefix + "\nPrevious History:\n" 341 | prefix = self.prefix 342 | prefix += "\n".join(bp_state) + "\n" 343 | 344 | return prefix 345 | 346 | def save_trajectory(self, file_path, append): 347 | if append: 348 | with open(file_path, 'a', encoding='utf-8') as f: 349 | for traj in self.traj_list: 350 | f.write(json.dumps(traj) + '\n') 351 | else: 352 | with open(file_path, 'w', encoding='utf-8') as f: 353 | json.dump(self.traj_list, f, indent=2) 354 | self.traj_list = [] 355 | 356 | class FiniteReplay: 357 | def __init__(self, tokenizer, model, max_length=512, replay_size: int = 100): 358 | self.tokenizer = tokenizer 359 | self.model = model 360 | self.max_length = max_length 361 | self.replay_size = replay_size 362 | 363 | self.replay_buffer: List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = [None] * replay_size 364 | self.trajectory_lengths: List[int] = [0] * replay_size 365 | self.pos = 0 366 | self.full = False 367 | 368 | def __len__(self) -> int: 369 | return self.replay_size if self.full else self.pos 370 | 371 | def load_from_file(self, file_path: str): 372 | if not os.path.exists(file_path): 373 | raise FileNotFoundError(f"{file_path} not found.") 374 | 375 | if file_path.endswith(".json"): 376 | with open(file_path, "r", encoding="utf-8") as f: 377 | data = json.load(f) 378 | if isinstance(data, list): 379 | for traj_entry in data: 380 | self.add_trajectory(traj_entry["trajectory"]) 381 | elif file_path.endswith(".ndjson"): 382 | with open(file_path, 'r') as f: 383 | for line in f: 384 | traj = json.loads(line) 385 | self.add_trajectory(traj['trajectory']) 386 | else: 387 | raise ValueError("Unsupported file format. 
Use .json or .ndjson") 388 | 389 | def add_trajectory(self, trajectory: List[dict]): 390 | processed = self._preprocess_trajectory(trajectory) 391 | self.replay_buffer[self.pos] = processed 392 | self.trajectory_lengths[self.pos] = len(trajectory) 393 | self.pos = (self.pos + 1) % self.replay_size 394 | if self.pos == 0: 395 | self.full = True 396 | 397 | def _preprocess_trajectory(self, trajectory: List[dict]): 398 | states = [step["state"] for step in trajectory] 399 | rewards = [step["reward"] for step in trajectory] 400 | gt_ks = [step["k"] for step in trajectory] 401 | 402 | inputs = self.tokenizer( 403 | states, 404 | return_tensors="pt", 405 | truncation=True, 406 | padding='max_length', 407 | max_length=self.max_length 408 | ) 409 | input_ids = inputs['input_ids'] 410 | attention_masks = inputs['attention_mask'] 411 | rewards = torch.tensor(rewards, dtype=torch.float32) 412 | gt_ks = torch.tensor(gt_ks, dtype=torch.float32) 413 | 414 | return (input_ids, attention_masks, rewards, gt_ks) 415 | 416 | def sample(self, batch_size_steps: int, logger) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 417 | indices = [] 418 | total_steps = 0 419 | available_indices = list(range(len(self))) 420 | random.shuffle(available_indices) 421 | 422 | input_ids_list = [] 423 | attention_mask_list = [] 424 | rewards_list = [] 425 | gt_k_list = [] 426 | total_steps = 0 427 | 428 | chosen_idx = [] 429 | for idx in available_indices: 430 | input_ids, attention_mask, rewards, gt_ks = self.replay_buffer[idx] 431 | input_ids_list.append(input_ids) 432 | attention_mask_list.append(attention_mask) 433 | rewards_list.append(rewards) 434 | gt_k_list.append(gt_ks) 435 | total_steps += self.trajectory_lengths[idx] 436 | chosen_idx.append(idx) 437 | # use all datapoint to train 438 | if total_steps >= batch_size_steps: 439 | break 440 | if total_steps < batch_size_steps: 441 | return None 442 | 443 | return input_ids_list, attention_mask_list, rewards_list, gt_k_list 444 | 445 | class 
SharedState: 446 | def __init__(self): 447 | mismatch_detected = None 448 | mismatch_step_id = None 449 | 450 | async def initialize(self): 451 | self.mismatch_detected = asyncio.Event() -------------------------------------------------------------------------------- /OpenAGI/planner.py: -------------------------------------------------------------------------------- 1 | """Dynamic Speculative Planning implementation.""" 2 | 3 | import asyncio 4 | import time 5 | from datetime import datetime 6 | from typing import List 7 | 8 | from .openagi_utils import judge_to_be_true, concurrent_calls 9 | from util import cancel, register_async_handler 10 | from .config import Config 11 | from .async_online_utils import SharedState 12 | 13 | 14 | class SpeculativePlanner: 15 | def __init__(self, collector, encoding, executor, app_agent, tar_agent, config: Config): 16 | self.collector = collector 17 | self.encoding = encoding 18 | self.executor = executor 19 | self.app_agent = app_agent 20 | self.tar_agent = tar_agent 21 | self.config = config 22 | self.mismatch_state = SharedState() 23 | self.steps = [] 24 | 25 | # Initialize loggers as None - will be set per task 26 | self.logger = None 27 | self.target_logger = None 28 | self.approximation_logger = None 29 | 30 | def set_loggers(self, logger, target_logger, approximation_logger): 31 | self.logger = logger 32 | self.target_logger = target_logger 33 | self.approximation_logger = approximation_logger 34 | 35 | # Update agent loggers 36 | self.app_agent.logger = approximation_logger 37 | self.tar_agent.logger = target_logger 38 | 39 | def log_task_description(self, task_description): 40 | self.logger.log('task description: ' + task_description) 41 | self.target_logger.log('task description: ' + task_description) 42 | self.approximation_logger.log('task description: ' + task_description) 43 | 44 | def log_results(self, steps): 45 | self.logger.log('final result for the speculative planning ' + str(steps)) 46 | self.logger.log('max 
concurrent calls: ' + str(self.config.MAX_CONCURRENT_CALLS-1)) 47 | 48 | async def initialize(self, args): 49 | await self.mismatch_state.initialize() 50 | 51 | async def process_target(self, sas, tas, to_print_id, prev_steps, target_tasks): 52 | """Process target agent responses and update collectors.""" 53 | s = sas[to_print_id] 54 | t = tas[to_print_id] 55 | self.target_logger.log(f'Target: Step {t[0][0] + len(prev_steps)+1} - {t[0][1]}') 56 | 57 | # Add to online trajectory collector 58 | cur_time = time.time() 59 | timestamp = datetime.fromtimestamp(cur_time) 60 | source = "Target" 61 | step = t[0][0] + len(prev_steps) + 1 62 | desc = t[0][1] 63 | self.collector.record_step(timestamp, source, step, desc) 64 | 65 | self.config.TOTAL_APPROXIMATION_CALLS += 1 66 | if judge_to_be_true(s, t[0][1]): 67 | self.config.TOTAL_CORRECT_APPROXIMATION_CALLS += 1 68 | self.logger.log(f'The target agent thinks step {len(prev_steps) + to_print_id+1} should be {t[0][1]}, which agrees with the approximation agent.') 69 | try: 70 | self.logger.log(f'The approximation agent thinks step {len(prev_steps) + to_print_id+2} should be {sas[to_print_id+1]}') 71 | self.config.HIL_INTERACTION = to_print_id+1 72 | register_async_handler(target_tasks=target_tasks) 73 | except Exception: 74 | pass 75 | else: 76 | self.logger.log(f'The target agent thinks step {len(prev_steps) + to_print_id+1} should be {t[0][1]}, correcting what the approximation agent thinks which is {s}.') 77 | 78 | async def process_on_time_interactions(self, sas, tas, flatten_tas, printed_ids, prev_steps, target_tasks): 79 | """Handle on-time interaction between approximation and target agents.""" 80 | tas_ids = [t[0] for t in flatten_tas] 81 | flatten_printed_ids = [i for ids in printed_ids if ids for i in ids] 82 | 83 | tas_length = len(tas_ids) 84 | 85 | for l in range(0, tas_length+1): 86 | if not(tas_ids[:l] == list(range(len(flatten_tas)))[:l] and len(sas) >= len(flatten_tas[:l])): 87 | tas_length = l-1 88 | break 
89 | 90 | contain_wrong_result = any( 91 | not judge_to_be_true(sas[pid], tas[pid][0][1]) for pid in flatten_printed_ids 92 | ) 93 | 94 | if tas_length > 0 and not contain_wrong_result: 95 | tas_ids = tas_ids[:tas_length] 96 | to_print_ids = list(set(tas_ids) - set(flatten_printed_ids)) 97 | for order_id, to_print_id in enumerate(to_print_ids): 98 | if order_id > 0 and not judge_to_be_true( 99 | sas[to_print_ids[order_id-1]], 100 | tas[to_print_ids[order_id-1]][0][1] 101 | ): 102 | break 103 | printed_ids[to_print_id].append(to_print_id) 104 | await self.process_target(sas, tas, to_print_id, prev_steps, target_tasks) 105 | 106 | return printed_ids, tas 107 | 108 | async def process_remaining_interactions(self, sas, tas, flatten_tas, printed_ids, prev_steps, target_tasks): 109 | """Handle remaining interactions and process corrections.""" 110 | flatten_printed_ids = [i for ids in printed_ids if ids for i in ids] 111 | 112 | 113 | print_leftover = ( 114 | len(sas) > len(flatten_printed_ids) 115 | and len(flatten_tas) > len(flatten_printed_ids) 116 | and all(judge_to_be_true(sas[pid], tas[pid][0][1]) for pid in flatten_printed_ids) 117 | ) 118 | 119 | contain_wrong_result = any( 120 | not judge_to_be_true(sas[pid], tas[pid][0][1]) for pid in flatten_printed_ids 121 | ) 122 | 123 | if print_leftover and not contain_wrong_result: 124 | for print_id in range(len(flatten_printed_ids), min(len(sas),len(tas))): 125 | if not tas[print_id]: 126 | break 127 | 128 | await self.process_target(sas, tas, print_id, prev_steps, target_tasks) 129 | if not judge_to_be_true(sas[print_id], tas[print_id][0][1]): 130 | break 131 | 132 | return printed_ids, tas 133 | 134 | async def approx_gen(self, args, collector, app_prompt, total_step_number, start): 135 | """Generate approximation agent responses.""" 136 | try: 137 | # Get the response from the agent (token tracking handled by agent) 138 | response = await self.app_agent.generate(app_prompt, total_step_number + 1) 139 | 140 | # Add 
approximation task to online trajectory collector 141 | end = time.time() 142 | timestamp = datetime.fromtimestamp(end) 143 | source = "Approximation" 144 | step = total_step_number+1 145 | desc = response 146 | collector.record_step(timestamp, source, step, desc) 147 | 148 | # Record timing 149 | a_time = round(end-start, 2) 150 | app_tokens = self.config.APPROX_NORMAL_GENERATION[total_step_number+1] 151 | self.approximation_logger.log( 152 | f'Approximation: Step {total_step_number+1} - {response} -time {str(a_time)} -token {app_tokens}' 153 | ) 154 | self.config.APPROX_NORMAL_TIME[total_step_number+1] = a_time 155 | 156 | return response 157 | except Exception: 158 | return '' 159 | 160 | async def target_gen(self, args, prediction_task, prompt, total_step_number, tas, sas, target_tasks, printed_ids, current_step, prev_steps, start): 161 | """Generate target agent responses with comprehensive error handling.""" 162 | result = None 163 | 164 | try: 165 | n = 0 166 | while True: 167 | try: 168 | n += 1 169 | if n >= 10: 170 | result = '' 171 | break 172 | 173 | # Get response from agent (token tracking handled by agent) 174 | response = await self.tar_agent.generate(prompt, total_step_number + 1) 175 | result = response 176 | break 177 | except Exception: 178 | await asyncio.sleep(0.1) 179 | continue 180 | 181 | tas, printed_ids = await self.verify_process( 182 | args, prediction_task, result, 183 | total_step_number, tas, sas, target_tasks, 184 | printed_ids=printed_ids, current_step=current_step, 185 | prev_steps=prev_steps 186 | ) 187 | 188 | except Exception: 189 | pass 190 | 191 | # Record timing 192 | end = time.time() 193 | t_time = round(end-start, 2) 194 | self.config.TARGET_NORMAL_TIME[total_step_number+1] = t_time 195 | 196 | self.target_logger.log( 197 | f"Intermediate Target Step {total_step_number+1} - {result} " 198 | f"-gen {self.config.TARGET_NORMAL_GENERATION[total_step_number+1]} " 199 | f"-prompt 
{self.config.TARGET_NORMAL_PROMPT[total_step_number+1]} " 200 | f"-time {t_time}" 201 | ) 202 | 203 | return tas, printed_ids 204 | 205 | async def verify_process(self, args, prediction_task, result, total_step_number, tas, sas, target_tasks, printed_ids=[[]], current_step=0, prev_steps=[]): 206 | """Post-process target generation results and handle mismatches.""" 207 | in_step_number = total_step_number - current_step 208 | tas[in_step_number].append((in_step_number,result)) 209 | 210 | flatten_tas = sorted([item for t in tas if t for item in t], key=lambda x: x[0], reverse=False) 211 | printed_ids, tas = await self.process_on_time_interactions( 212 | sas, tas, flatten_tas, printed_ids, 213 | prev_steps, target_tasks 214 | ) 215 | 216 | flatten_ids = [t[0] for t in flatten_tas] 217 | if flatten_ids == list(range(len(flatten_ids))): 218 | for step_number, (s, t) in enumerate(zip(sas, flatten_tas)): 219 | self.target_logger.log(f"step number:{step_number}, t[0]: {t[0]}, t[1]: {t[1]}") 220 | if t[1].lower() == 'terminate': 221 | if not self.mismatch_state.mismatch_detected.is_set(): 222 | self.mismatch_state.mismatch_step_id = t[0] 223 | self.mismatch_state.mismatch_detected.set() 224 | raise Exception('terminate the whole process!') 225 | 226 | flatten_tas = sorted([item for t in tas if t for item in t], key=lambda x: x[0], reverse=False) 227 | for ta in flatten_tas: 228 | if len(sas) > ta[0]: 229 | if not judge_to_be_true(sas[ta[0]], ta[1]) or ta[1].lower() == "terminate": 230 | if not self.mismatch_state.mismatch_detected.is_set(): 231 | self.mismatch_state.mismatch_step_id = ta[0] 232 | self.mismatch_state.mismatch_detected.set() 233 | 234 | pending_approximation_tasks = [ 235 | t for t in asyncio.all_tasks() 236 | if not t.cancelled() and not t.done() 237 | and t not in target_tasks 238 | and t.get_name().startswith('approximation') 239 | ] 240 | 241 | for pending_approximation_task in pending_approximation_tasks: 242 | await cancel(pending_approximation_task) 
243 | 244 | raise Exception(f'approximation error happen in step {total_step_number} for current step {current_step}, the target id is {ta[0]}') 245 | 246 | if self.config.ENABLE_PRED: 247 | k = await prediction_task 248 | else: 249 | k = args.k 250 | 251 | if not self.mismatch_state.mismatch_detected.is_set() and ta[0] >= k-1: 252 | self.mismatch_state.mismatch_step_id = ta[0] 253 | self.mismatch_state.mismatch_detected.set() 254 | 255 | return tas, printed_ids 256 | 257 | async def one_episode_sp(self, args, app_prompt, tar_prompt, current_step): 258 | """Run one breaking point of speculative planning.""" 259 | sas = [] # approximation 260 | tas = [] # target 261 | target_tasks = [] 262 | printed_ids = [] 263 | 264 | self.mismatch_state.mismatch_detected.clear() 265 | pred_k = 1 if self.config.ENABLE_PRED else args.k 266 | first = True 267 | 268 | i = 0 269 | prediction_task = None 270 | while i < pred_k: 271 | if self.mismatch_state.mismatch_detected.is_set(): 272 | break 273 | 274 | if self.config.ENABLE_PRED and first: 275 | prediction_task = asyncio.create_task( 276 | self.executor.async_predict(self.approximation_logger, args.offset) 277 | ) 278 | 279 | break_out_approximation = False 280 | tas.append([]) 281 | printed_ids.append([]) 282 | 283 | # Generate approximation 284 | a_start = time.time() 285 | approximation = asyncio.create_task( 286 | self.approx_gen( 287 | args, self.collector, app_prompt, current_step+i, a_start 288 | ), 289 | name=f"approximation_{current_step+i}" 290 | ) 291 | 292 | # Generate target 293 | t_start = time.time() 294 | target = asyncio.create_task( 295 | self.target_gen( 296 | args, prediction_task, tar_prompt, current_step+i, tas, sas, 297 | target_tasks, printed_ids, current_step, 298 | self.steps, t_start 299 | ), 300 | name=f"target_{i}" 301 | ) 302 | target_tasks.append(target) 303 | 304 | # Track concurrent calls 305 | concurrent_api_calls = concurrent_calls() 306 | if concurrent_api_calls >= 
self.config.MAX_CONCURRENT_CALLS: 307 | self.config.MAX_CONCURRENT_CALLS = concurrent_api_calls 308 | 309 | try: 310 | sa = await approximation 311 | sas.append(sa) 312 | if self.mismatch_state.mismatch_detected.is_set(): 313 | break 314 | 315 | # Check if we need to print approximation result 316 | flatten_tas = [] 317 | for t in tas[:len(sas)-1]: 318 | if t: 319 | flatten_tas += t 320 | flatten_tas = sorted(flatten_tas, key=lambda x: x[0], reverse=False) 321 | flatten_ids = [t[0] for t in flatten_tas] 322 | flatten_tas_action = [t[1] for t in flatten_tas] 323 | if (flatten_ids == list(range(len(flatten_ids))) and 324 | len(flatten_ids) == len(sas)-1 and 325 | all([judge_to_be_true(s, t) for s, t in zip(sas[:-1], flatten_tas_action)])): 326 | 327 | flattened_printed_ids = [printed_id[0] for printed_id in printed_ids if printed_id != []] 328 | 329 | if flattened_printed_ids: 330 | if len(sas) == len(flattened_printed_ids)+1: 331 | self.logger.log(f'The approximation agent thinks step {current_step+i+1} should be {sa}') 332 | self.config.HIL_INTERACTION = len(sas)-1 333 | register_async_handler(target_tasks=target_tasks) 334 | else: 335 | self.logger.log(f'The approximation agent thinks step {current_step+i+1} should be {sa}') 336 | self.config.HIL_INTERACTION = len(sas)-1 337 | register_async_handler(target_tasks=target_tasks) 338 | 339 | # Update prompts 340 | if '## Current Action Trajectory:' not in app_prompt: 341 | app_prompt += '\n\n## Current Action Trajectory:\n' 342 | app_prompt += f'\nAction {current_step+i+1}: {str(sa)}.' 343 | 344 | if '## Current Action Trajectory:' not in tar_prompt: 345 | tar_prompt += '\n\n## Current Action Trajectory:\n' 346 | tar_prompt += f'\nAction {current_step+i+1}: {str(sa)}.' 
347 | 348 | except asyncio.CancelledError: 349 | self.target_logger.log("Exception happens.") 350 | pass 351 | 352 | # Check for termination conditions 353 | if sa.lower() == 'terminate': 354 | break_out_approximation = True 355 | 356 | flatten_tas = [] 357 | for t in tas: 358 | if t: 359 | flatten_tas += t 360 | flatten_tas = sorted(flatten_tas, key=lambda x: x[0], reverse=False) 361 | 362 | for t in flatten_tas: 363 | if len(sas) > t[0]: 364 | if not judge_to_be_true(sas[t[0]], t[1]) or t[1].lower() == "terminate": 365 | break_out_approximation = True 366 | self.mismatch_state.mismatch_step_id = t[0] 367 | self.mismatch_state.mismatch_detected.set() 368 | 369 | for process_id, one_task in enumerate(target_tasks): 370 | if not one_task.cancelled() and not one_task.done() and process_id > t[0]: 371 | self.target_logger.log(f'Cancel Task {len(self.steps)+process_id+1}') 372 | await cancel(one_task) 373 | break 374 | 375 | if break_out_approximation: 376 | break 377 | 378 | if self.config.ENABLE_PRED and first: 379 | pred_k = await prediction_task 380 | self.config.PREDICT_K.append(pred_k) 381 | self.config.PREDICT_TOTAL += 1 382 | pred_k = max(pred_k, 1) 383 | first = False 384 | i += 1 385 | 386 | # Process results after breaking point 387 | await self.mismatch_state.mismatch_detected.wait() 388 | self.target_logger.log(f"Mismatch at {self.mismatch_state.mismatch_step_id}. 
Cancel starts.") 389 | 390 | for process_id, one_task in enumerate(target_tasks): 391 | if not one_task.cancelled() and not one_task.done() and process_id > self.mismatch_state.mismatch_step_id: 392 | self.target_logger.log(f'Cancel task {len(self.steps)+process_id+1}') 393 | await cancel(one_task) 394 | 395 | # Wait for pending tasks prior to mismatch task 396 | pending_tasks = [t for t in target_tasks if not t.cancelled()] 397 | 398 | for process_id, t in enumerate(target_tasks): 399 | if t.done() or t.cancelled(): 400 | continue 401 | try: 402 | await t 403 | flatten_tas = [] 404 | for t in tas: 405 | if t: 406 | flatten_tas += t 407 | flatten_tas = sorted(flatten_tas, key=lambda x: x[0], reverse=False) 408 | for t in flatten_tas: 409 | if len(sas) > t[0]: 410 | if not judge_to_be_true(sas[t[0]], t[1]): 411 | for process_id, one_task in enumerate(pending_tasks): 412 | if not one_task.cancelled() and not one_task.done() and process_id > t[0]: 413 | self.target_logger.log(f'Cancel task: {len(self.steps)+process_id+1}') 414 | await cancel(one_task) 415 | break 416 | except Exception as e: 417 | print(f"An error occurred with task {len(self.steps)+t[0]+1}: {e}") 418 | 419 | # Process leftover interactions 420 | flatten_tas = [] 421 | for t in tas: 422 | if t: 423 | flatten_tas += t 424 | flatten_tas = sorted(flatten_tas, key=lambda x: x[0], reverse=False) 425 | printed_ids, tas = await self.process_remaining_interactions( 426 | sas, tas, flatten_tas, printed_ids, 427 | self.steps, target_tasks 428 | ) 429 | 430 | # Get final results 431 | flatten_tas = [] 432 | for t in tas: 433 | if t: 434 | flatten_tas += t 435 | flatten_tas = sorted(flatten_tas, key=lambda x: x[0], reverse=False) 436 | mismatch = False 437 | origin_sa = None 438 | 439 | self.target_logger.log(f"sas tasks: {sas}") 440 | self.target_logger.log(f"flatten_tas: {flatten_tas}") 441 | 442 | for step_number, (s, t) in enumerate(zip(sas, flatten_tas)): 443 | if t[0] == step_number and (not 
judge_to_be_true(s, t[1]) or t[1].lower() == "terminate"): 444 | self.target_logger.log(f"step {step_number} app: {s}, tar: {t[1]}") 445 | self.config.PREDICT_CORRECT += self.collector.build_trajectory( 446 | self.target_logger, self.config.PREDICT_K 447 | ) 448 | self.config.PREDICT_K = [] 449 | 450 | if self.config.ENABLE_TRAIN: 451 | if self.config.BUILD_TRAJ_TIMES == 0: 452 | asyncio.create_task(self.executor.async_train()) 453 | self.config.BUILD_TRAJ_TIMES = ( 454 | self.config.BUILD_TRAJ_TIMES + 1 455 | ) % self.config.TRAIN_INTERVAL 456 | 457 | sas = sas[:step_number]+[flatten_tas[step_number][1]] 458 | origin_sa = s 459 | mismatch = True 460 | break 461 | 462 | return sas, origin_sa, mismatch 463 | 464 | async def run(self, args, prompt: str) -> List[str]: 465 | """Run the speculative planning process.""" 466 | # Initialize agents and state 467 | await self.initialize(args) 468 | 469 | begin_time = datetime.now() 470 | steps = [] 471 | breaking_points = 0 472 | i = 0 473 | 474 | app_prompt = prompt 475 | tar_prompt = prompt 476 | 477 | while True: 478 | result, origin_sa, mismatch = await self.one_episode_sp( 479 | args, app_prompt, tar_prompt, len(steps) 480 | ) 481 | 482 | # Update prompts based on results 483 | if '## Current Action Trajectory:' not in app_prompt: 484 | app_prompt += '\n\n## Current Action Trajectory:\n' 485 | previous_action_trajectory = [ 486 | f'\nAction {len(steps) + j+1}: {result[j]}. Your prediction aligned with Target.' 487 | for j in range(len(result)) 488 | ] 489 | 490 | if mismatch: 491 | app_prompt += ''.join(previous_action_trajectory[:-1]) 492 | app_prompt += f"\nAction {len(steps) + len(result)}: {result[-1]}. Target predicted {result[-1]} and corrected your prediction {origin_sa}." 
493 | else: 494 | app_prompt += ''.join(previous_action_trajectory) 495 | 496 | if '## Current Action Trajectory:' not in tar_prompt: 497 | tar_prompt += '\n\n## Current Action Trajectory:\n' 498 | previous_action_trajectory = [ 499 | f'\nAction {len(steps) + j+1}: {result[j]}.' 500 | for j in range(len(result)) 501 | ] 502 | tar_prompt += ''.join(previous_action_trajectory) 503 | 504 | steps += result 505 | breaking_points += 1 506 | i += len(result) 507 | 508 | if result[-1].lower() == 'terminate' or len(steps) >= self.config.MAX_STEP: 509 | break 510 | 511 | end_time = datetime.now() 512 | self.logger.log(f'{end_time} - {begin_time} = {end_time - begin_time}') 513 | self.config.TOTAL_SP_TIME = round((end_time - begin_time).total_seconds(), 2) 514 | 515 | # self.log_results(steps) 516 | 517 | return steps -------------------------------------------------------------------------------- /data/openagi_task_description.txt: -------------------------------------------------------------------------------- 1 | Given low-resolutioned noisy blurry grayscale image, how to return the regular image step by step? 2 | Given noisy blurry grayscale image, how to return the regular image step by step? 3 | Given low-resolutioned blurry grayscale image, how to return the regular image step by step? 4 | Given blurry grayscale image, how to return the regular image step by step? 5 | Given low-resolutioned noisy grayscale image, how to return the regular image step by step? 6 | Given noisy grayscale image, how to return the regular image step by step? 7 | Given low-resolutioned grayscale image, how to return the regular image step by step? 8 | Given grayscale image, how to return the regular image step by step? 9 | Given low-resolutioned noisy blurry image, how to return the regular image step by step? 10 | Given noisy blurry image, how to return the regular image step by step? 11 | Given low-resolutioned blurry image, how to return the regular image step by step? 
12 | Given blurry image, how to return the regular image step by step? 13 | Given low-resolutioned noisy image, how to return the regular image step by step? 14 | Given noisy image, how to return the regular image step by step? 15 | Given low-resolutioned image, how to return the regular image step by step? 16 | Given low-resolutioned noisy blurry grayscale image, how to return the caption in German step by step? 17 | Given low-resolutioned noisy blurry grayscale image, how to return the class label in German step by step? 18 | Given low-resolutioned noisy blurry grayscale image, how to return the object names in German step by step? 19 | Given low-resolutioned noisy blurry grayscale image, how to return the caption in English step by step? 20 | Given low-resolutioned noisy blurry grayscale image, how to return the class label in English step by step? 21 | Given low-resolutioned noisy blurry grayscale image, how to return the object names in English step by step? 22 | Given noisy blurry grayscale image, how to return the caption in German step by step? 23 | Given noisy blurry grayscale image, how to return the class label in German step by step? 24 | Given noisy blurry grayscale image, how to return the object names in German step by step? 25 | Given noisy blurry grayscale image, how to return the caption in English step by step? 26 | Given noisy blurry grayscale image, how to return the class label in English step by step? 27 | Given noisy blurry grayscale image, how to return the object names in English step by step? 28 | Given low-resolutioned blurry grayscale image, how to return the caption in German step by step? 29 | Given low-resolutioned blurry grayscale image, how to return the class label in German step by step? 30 | Given low-resolutioned blurry grayscale image, how to return the object names in German step by step? 31 | Given low-resolutioned blurry grayscale image, how to return the caption in English step by step? 
32 | Given low-resolutioned blurry grayscale image, how to return the class label in English step by step? 33 | Given low-resolutioned blurry grayscale image, how to return the object names in English step by step? 34 | Given blurry grayscale image, how to return the caption in German step by step? 35 | Given blurry grayscale image, how to return the class label in German step by step? 36 | Given blurry grayscale image, how to return the object names in German step by step? 37 | Given blurry grayscale image, how to return the caption in English step by step? 38 | Given blurry grayscale image, how to return the class label in English step by step? 39 | Given blurry grayscale image, how to return the object names in English step by step? 40 | Given low-resolutioned noisy grayscale image, how to return the caption in German step by step? 41 | Given low-resolutioned noisy grayscale image, how to return the class label in German step by step? 42 | Given low-resolutioned noisy grayscale image, how to return the object names in German step by step? 43 | Given low-resolutioned noisy grayscale image, how to return the caption in English step by step? 44 | Given low-resolutioned noisy grayscale image, how to return the class label in English step by step? 45 | Given low-resolutioned noisy grayscale image, how to return the object names in English step by step? 46 | Given noisy grayscale image, how to return the caption in German step by step? 47 | Given noisy grayscale image, how to return the class label in German step by step? 48 | Given noisy grayscale image, how to return the object names in German step by step? 49 | Given noisy grayscale image, how to return the caption in English step by step? 50 | Given noisy grayscale image, how to return the class label in English step by step? 51 | Given noisy grayscale image, how to return the object names in English step by step? 52 | Given low-resolutioned grayscale image, how to return the caption in German step by step? 
53 | Given low-resolutioned grayscale image, how to return the class label in German step by step? 54 | Given low-resolutioned grayscale image, how to return the object names in German step by step? 55 | Given low-resolutioned grayscale image, how to return the caption in English step by step? 56 | Given low-resolutioned grayscale image, how to return the class label in English step by step? 57 | Given low-resolutioned grayscale image, how to return the object names in English step by step? 58 | Given grayscale image, how to return the caption in German step by step? 59 | Given grayscale image, how to return the class label in German step by step? 60 | Given grayscale image, how to return the object names in German step by step? 61 | Given grayscale image, how to return the caption in English step by step? 62 | Given grayscale image, how to return the class label in English step by step? 63 | Given grayscale image, how to return the object names in English step by step? 64 | Given low-resolutioned noisy blurry image, how to return the caption in German step by step? 65 | Given low-resolutioned noisy blurry image, how to return the class label in German step by step? 66 | Given low-resolutioned noisy blurry image, how to return the object names in German step by step? 67 | Given low-resolutioned noisy blurry image, how to return the caption in English step by step? 68 | Given low-resolutioned noisy blurry image, how to return the class label in English step by step? 69 | Given low-resolutioned noisy blurry image, how to return the object names in English step by step? 70 | Given noisy blurry image, how to return the caption in German step by step? 71 | Given noisy blurry image, how to return the class label in German step by step? 72 | Given noisy blurry image, how to return the object names in German step by step? 73 | Given noisy blurry image, how to return the caption in English step by step? 
74 | Given noisy blurry image, how to return the class label in English step by step? 75 | Given noisy blurry image, how to return the object names in English step by step? 76 | Given low-resolutioned blurry image, how to return the caption in German step by step? 77 | Given low-resolutioned blurry image, how to return the class label in German step by step? 78 | Given low-resolutioned blurry image, how to return the object names in German step by step? 79 | Given low-resolutioned blurry image, how to return the caption in English step by step? 80 | Given low-resolutioned blurry image, how to return the class label in English step by step? 81 | Given low-resolutioned blurry image, how to return the object names in English step by step? 82 | Given blurry image, how to return the caption in German step by step? 83 | Given blurry image, how to return the class label in German step by step? 84 | Given blurry image, how to return the object names in German step by step? 85 | Given blurry image, how to return the caption in English step by step? 86 | Given blurry image, how to return the class label in English step by step? 87 | Given blurry image, how to return the object names in English step by step? 88 | Given low-resolutioned noisy image, how to return the caption in German step by step? 89 | Given low-resolutioned noisy image, how to return the class label in German step by step? 90 | Given low-resolutioned noisy image, how to return the object names in German step by step? 91 | Given low-resolutioned noisy image, how to return the caption in English step by step? 92 | Given low-resolutioned noisy image, how to return the class label in English step by step? 93 | Given low-resolutioned noisy image, how to return the object names in English step by step? 94 | Given noisy image, how to return the caption in German step by step? 95 | Given noisy image, how to return the class label in German step by step? 
96 | Given noisy image, how to return the object names in German step by step? 97 | Given noisy image, how to return the caption in English step by step? 98 | Given noisy image, how to return the class label in English step by step? 99 | Given noisy image, how to return the object names in English step by step? 100 | Given low-resolutioned image, how to return the caption in German step by step? 101 | Given low-resolutioned image, how to return the class label in German step by step? 102 | Given low-resolutioned image, how to return the object names in German step by step? 103 | Given low-resolutioned image, how to return the caption in English step by step? 104 | Given low-resolutioned image, how to return the class label in English step by step? 105 | Given low-resolutioned image, how to return the object names in English step by step? 106 | Given clozed English text, how to generate a image step by step? 107 | Given English text, how to generate a image step by step? 108 | Given clozed English text, how to return the summarization in German step by step? 109 | Given clozed English text, how to translate the text in German step by step? 110 | Given clozed English text, how to return the sentiment in German step by step? 111 | Given clozed English text, how to return the summarization in English step by step? 112 | Given clozed English text, how to return the sentiment in English step by step? 113 | Given English text, how to return the summarization in German step by step? 114 | Given English text, how to translate the text in German step by step? 115 | Given English text, how to return the sentiment in German step by step? 116 | Given English text, how to return the summarization in English step by step? 117 | Given English text, how to return the sentiment in English step by step? 118 | Given low-resolutioned noisy blurry grayscale image and clozed English query, how to answer the question in English step by step? 
119 | Given low-resolutioned noisy blurry grayscale image and clozed English query, how to answer the question in German step by step? 120 | Given low-resolutioned noisy blurry grayscale image and English query, how to answer the question in English step by step? 121 | Given low-resolutioned noisy blurry grayscale image and English query, how to answer the question in German step by step? 122 | Given noisy blurry grayscale image and clozed English query, how to answer the question in English step by step? 123 | Given noisy blurry grayscale image and clozed English query, how to answer the question in German step by step? 124 | Given noisy blurry grayscale image and English query, how to answer the question in English step by step? 125 | Given noisy blurry grayscale image and English query, how to answer the question in German step by step? 126 | Given low-resolutioned blurry grayscale image and clozed English query, how to answer the question in English step by step? 127 | Given low-resolutioned blurry grayscale image and clozed English query, how to answer the question in German step by step? 128 | Given low-resolutioned blurry grayscale image and English query, how to answer the question in English step by step? 129 | Given low-resolutioned blurry grayscale image and English query, how to answer the question in German step by step? 130 | Given blurry grayscale image and clozed English query, how to answer the question in English step by step? 131 | Given blurry grayscale image and clozed English query, how to answer the question in German step by step? 132 | Given blurry grayscale image and English query, how to answer the question in English step by step? 133 | Given blurry grayscale image and English query, how to answer the question in German step by step? 134 | Given low-resolutioned noisy grayscale image and clozed English query, how to answer the question in English step by step? 
135 | Given low-resolutioned noisy grayscale image and clozed English query, how to answer the question in German step by step? 136 | Given low-resolutioned noisy grayscale image and English query, how to answer the question in English step by step? 137 | Given low-resolutioned noisy grayscale image and English query, how to answer the question in German step by step? 138 | Given noisy grayscale image and clozed English query, how to answer the question in English step by step? 139 | Given noisy grayscale image and clozed English query, how to answer the question in German step by step? 140 | Given noisy grayscale image and English query, how to answer the question in English step by step? 141 | Given noisy grayscale image and English query, how to answer the question in German step by step? 142 | Given low-resolutioned grayscale image and clozed English query, how to answer the question in English step by step? 143 | Given low-resolutioned grayscale image and clozed English query, how to answer the question in German step by step? 144 | Given low-resolutioned grayscale image and English query, how to answer the question in English step by step? 145 | Given low-resolutioned grayscale image and English query, how to answer the question in German step by step? 146 | Given grayscale image and clozed English query, how to answer the question in English step by step? 147 | Given grayscale image and clozed English query, how to answer the question in German step by step? 148 | Given grayscale image and English query, how to answer the question in English step by step? 149 | Given grayscale image and English query, how to answer the question in German step by step? 150 | Given low-resolutioned noisy blurry image and clozed English query, how to answer the question in English step by step? 151 | Given low-resolutioned noisy blurry image and clozed English query, how to answer the question in German step by step? 
152 | Given low-resolutioned noisy blurry image and English query, how to answer the question in English step by step? 153 | Given low-resolutioned noisy blurry image and English query, how to answer the question in German step by step? 154 | Given noisy blurry image and clozed English query, how to answer the question in English step by step? 155 | Given noisy blurry image and clozed English query, how to answer the question in German step by step? 156 | Given noisy blurry image and English query, how to answer the question in English step by step? 157 | Given noisy blurry image and English query, how to answer the question in German step by step? 158 | Given low-resolutioned blurry image and clozed English query, how to answer the question in English step by step? 159 | Given low-resolutioned blurry image and clozed English query, how to answer the question in German step by step? 160 | Given low-resolutioned blurry image and English query, how to answer the question in English step by step? 161 | Given low-resolutioned blurry image and English query, how to answer the question in German step by step? 162 | Given blurry image and clozed English query, how to answer the question in English step by step? 163 | Given blurry image and clozed English query, how to answer the question in German step by step? 164 | Given blurry image and English query, how to answer the question in English step by step? 165 | Given blurry image and English query, how to answer the question in German step by step? 166 | Given low-resolutioned noisy image and clozed English query, how to answer the question in English step by step? 167 | Given low-resolutioned noisy image and clozed English query, how to answer the question in German step by step? 168 | Given low-resolutioned noisy image and English query, how to answer the question in English step by step? 169 | Given low-resolutioned noisy image and English query, how to answer the question in German step by step? 
170 | Given noisy image and clozed English query, how to answer the question in English step by step? 171 | Given noisy image and clozed English query, how to answer the question in German step by step? 172 | Given noisy image and English query, how to answer the question in English step by step? 173 | Given noisy image and English query, how to answer the question in German step by step? 174 | Given low-resolutioned image and clozed English query, how to answer the question in English step by step? 175 | Given low-resolutioned image and clozed English query, how to answer the question in German step by step? 176 | Given low-resolutioned image and English query, how to answer the question in English step by step? 177 | Given low-resolutioned image and English query, how to answer the question in German step by step? 178 | Given clozed English document and clozed English query, how to answer the question in German step by step? 179 | Given clozed English document and clozed English query, how to answer the question in English step by step? 180 | Given clozed English document and English query, how to answer the question in German step by step? 181 | Given clozed English document and English query, how to answer the question in English step by step? 182 | Given English document and clozed English query, how to answer the question in German step by step? 183 | Given English document and clozed English query, how to answer the question in English step by step? 184 | Given English document and English query, how to answer the question in German step by step? 185 | Given English document and English query, how to answer the question in English step by step? 186 | Given a low-resolution image of a landscape, how to transform it into a vibrant high-quality version? 187 | Given a noisy nighttime photograph, how to produce a clear image with identifiable objects? 188 | Given a blurry photo of a street scene, how to create a detailed English description of what it contains? 
189 | Given a grayscale historical photograph, how to modernize it and identify its historical period? 190 | Given a low-quality security camera image, how to improve it and identify all persons present? 191 | Given a blurry grayscale medical scan, how to transform it for diagnostic analysis? 192 | Given a degraded artwork image, how to restore it to its original appearance and explain its contents? 193 | Given a noisy satellite image, how to clean it up and count specific geographic features? 194 | Given an image of wildlife, how to describe it comprehensively in German? 195 | Given an image of a crowded scene, how to catalog all items present and explain their relationships? 196 | Given an image of a retail store, how to identify products and analyze customer experience impressions? 197 | Given a historical image, how to bring it to life and provide historical context? 198 | Given an image of a cityscape, how to identify architectural features and explain urban design principles? 199 | Given an image of a natural disaster, how to identify the event type and assess potential impacts? 200 | Given an image of a social gathering, how to identify attendees and determine the mood of the event? 201 | Given an image with multiple objects, how to complete a description with missing information about object relationships? 202 | Given a product description with missing information and a product image, how to complete the description? 203 | Given a news article summary, how to create a visual representation of the main story? 204 | Given customer feedback analysis, how to create a visual representation of customer satisfaction? 205 | Given a translated technical document, how to create visual aids for complex concepts? 206 | Given a question about an image, how to provide an answer and identify the key visual elements? 207 | Given a research paper summary, how to visualize the main findings? 
208 | Given an incomplete description and its corresponding image, how to fill the information gaps? 209 | Given movie review sentiment data, how to create a visual representation of the film's reception? 210 | Given a severely degraded grayscale photograph, how to transform it into a modern, clear, colorful image and identify its contents? 211 | Given a low-quality product image, how to enhance it and create compelling marketing text? 212 | Given a poor-quality surveillance image, how to improve it and answer security-related questions? 213 | Given a historical document image, how to restore it and translate the content to German? 214 | Given a low-resolution medical image, how to improve it and answer diagnostic questions? 215 | Given a blurry wildlife photo, how to clarify it and provide information about the species and habitat? 216 | Given a noisy underwater image, how to improve visibility and describe the marine ecosystem? 217 | Given a degraded aerial image, how to restore it and analyze urban development patterns? 218 | Given an image and related text with missing information, how to complete the text and answer questions about the content? 219 | Given an image of a landmark and a question about its history, how to provide comprehensive historical information? 220 | Given an image of a product and customer reviews, how to answer questions about features and customer satisfaction? 221 | Given an image with multiple objects and a question about their relationship, how to describe the scene accurately? 222 | Given an image of a scientific experiment and a technical question, how to provide an expert-level explanation? 223 | Given an image of an artwork and questions about style, how to analyze the artistic techniques? 224 | Given an image of food and nutrition questions, how to identify ingredients and answer health queries? 225 | Given an image of a natural phenomenon and a question about causes, how to explain the scientific principles involved? 
226 | Given a blurry image, how to restore it and create a detailed description? 227 | Given a noisy grayscale photograph, how to transform it into a vibrant image with a descriptive caption? 228 | Given a low-resolution product image, how to enhance it and create effective marketing content? 229 | Given a historical black and white portrait, how to modernize it and create a biographical sketch? 230 | Given a degraded landscape image, how to restore it and describe its geographical features? 231 | Given an image with poor lighting, how to optimize visibility and provide a scene analysis? 232 | Given a blurry architectural image, how to clarify it and analyze its architectural significance? 233 | Given a faded historical document, how to restore legibility and translate the content? 234 | Given a text description, how to visualize it and expand on the description based on the visualization? 235 | Given a basic image and an emotion, how to transform the image to evoke that emotional response? 236 | Given an image, how to create a story based on its elements and then visualize a key moment from that story? 237 | Given a vacation photo, how to enhance it and create a German-language travel description? 238 | Given a product image, how to create compelling marketing text with creative elements? 239 | Given an image of cultural significance, how to analyze its origin and create educational content? 240 | Given a nature image, how to identify species and create conservation-focused information? 241 | Given a historical event image, how to restore it and create content about its modern relevance? 242 | Given a scanned newspaper with faded text, how to make the article readable and provide its main points in German? 243 | Given a screenshot of social media comments, how to identify the overall public opinion and create a visual representation of it? 244 | Given an old family photo album page, how to restore the images and create captions for each photograph? 
245 | Given a film still from a classic movie, how to modernize its appearance and analyze its cinematic significance? 246 | Given a poorly captured whiteboard from a meeting, how to make the content clear and summarize the key discussion points? 247 | Given an image of ancient inscriptions, how to clarify the text and translate its meaning to English? 248 | Given a graph with pixelated text and lines, how to create a clean version and explain the data trends? 249 | Given a collection of product labels, how to extract the information and identify potentially misleading claims? 250 | Given an image of a busy intersection, how to identify safety hazards and recommend improvements? 251 | Given a photograph of a room interior, how to identify all furniture items and suggest design improvements? 252 | Given an image of a sports play, how to identify player positions and explain the strategy? 253 | Given an aerial view of farmland, how to identify crop types and assess agricultural health? 254 | Given an image of a construction site, how to identify equipment and assess safety compliance? 255 | Given a photograph of a geological formation, how to identify its type and explain its formation process? 256 | Given an image of a night sky, how to identify celestial objects and create an educational guide? 257 | Given a photograph of an archaeological site, how to identify artifacts and explain their historical context? 258 | Given a technical diagram, how to create a simplified version with explanatory text for non-experts? 259 | Given a children's drawing, how to transform it into a professional-looking illustration with a story? 260 | Given a weather map, how to convert it into a travel advisory for tourists? 261 | Given a candid photograph, how to transform it into a professional portrait with appropriate background? 262 | Given a rough sketch of a product idea, how to create a polished concept image with feature descriptions? 
263 | Given a concert photograph, how to enhance it and create an evocative review of the performance? 264 | Given a photograph of ingredients, how to create a recipe with preparation instructions? 265 | Given an image of a fashion outfit, how to identify the style elements and suggest complementary items? 266 | Given an image of a busy street in a foreign country, how to create a cultural guide for visitors? 267 | Given a photograph of an unusual animal, how to identify the species and create an informative fact sheet? 268 | Given a snapshot of a board game in progress, how to explain the current state and possible next moves? 269 | Given an image of a mechanical device, how to create an operation manual with labeled parts? 270 | Given a photograph of a restaurant dish, how to identify ingredients and estimate nutritional information? 271 | Given an image of a damaged vehicle, how to assess the extent of damage and estimate repair needs? 272 | Given a candid photo of a celebrity, how to analyze the context and create an appropriate news caption? 273 | Given an image of a protest or public gathering, how to analyze the event while respecting privacy concerns? 274 | Given an image of a biological specimen, how to create a labeled diagram with educational notes? 275 | Given a historical battlefield image, how to create a strategic analysis for history students? 276 | Given a photograph of a chemical reaction, how to explain the underlying scientific principles? 277 | Given an astronomical image, how to create an age-appropriate explanation for elementary school students? 278 | Given an image of a mathematical graph, how to explain the represented concept in simple terms? 279 | Given an engineering prototype photograph, how to create technical specifications and design rationale? 280 | Given a geographical landmark image, how to create a geological history and formation explanation? 
281 | Given an image of a physics experiment, how to create step-by-step instructions with expected outcomes? 282 | Given visual art, how to create a descriptive text that enables blind individuals to appreciate it? 283 | Given a complex infographic, how to convert it into a plain text explanation while preserving all information? 284 | Given an image-heavy instructional guide, how to create an audio-friendly version with the same information? 285 | Given a visual warning sign, how to create universal symbols and multi-language alternatives? 286 | Given a colorful data visualization, how to recreate it in a colorblind-friendly format? 287 | Given a visually busy website screenshot, how to suggest accessibility improvements? 288 | Given a map image, how to create turn-by-turn directions in text format? 289 | Given a visual timeline, how to convert it into a narrative history? 290 | Given before and after images of a renovation project, how to identify all changes and assess quality improvements? 291 | Given images of similar products from different brands, how to create an objective comparison guide? 292 | Given images of the same location in different seasons, how to analyze environmental changes? 293 | Given reference and test images of the same subject, how to identify and explain all differences? 294 | Given an image of a structure before and after a natural disaster, how to assess damage and suggest reinforcements? 295 | Given images of a child at different ages, how to create a developmental progress report? 296 | Given original and counterfeit product images, how to identify distinguishing features? 297 | Given healthy and diseased plant images, how to create a diagnostic guide with treatment recommendations? 298 | Given an image of a person in a specific environment, how to determine their profession and activities? 299 | Given an image of an unusual tool, how to identify its purpose and instructions for use? 
300 | Given a photograph with mixed languages on signs, how to translate all text to English and explain the location? 301 | Given an image of an interaction between people, how to determine their relationship and conversation topic? 302 | Given a photograph of an unusual natural phenomenon, how to explain what is happening scientifically? 303 | Given an image of a machine malfunction, how to diagnose the problem and suggest repairs? 304 | Given a photograph of an animal's tracks or signs, how to identify the species and its recent activities? 305 | Given an image of a person's expression, how to analyze their emotional state and potential causes? 306 | Given abstract art, how to create a poem that captures its essence? 307 | Given a crowded urban scene, how to transform it into a peaceful natural setting while preserving key elements? 308 | Given a photograph of a mundane object, how to create an advertisement that makes it appear luxurious? 309 | Given a daytime cityscape, how to transform it into a futuristic night scene with descriptive text? 310 | Given a portrait photograph, how to create different artistic interpretations in various historical art styles? 311 | Given a modern building, how to reimagine it in classical architectural style with appropriate descriptions? 312 | Given an image of ordinary food, how to transform it into gourmet presentation with a menu description? 313 | Given a simple landscape, how to create four seasonal variations with appropriate mood descriptions? 
-------------------------------------------------------------------------------- /travelplanner_supplement/planner.py: -------------------------------------------------------------------------------- 1 | """Dynamic Speculative Planning implementation for Travel Planner.""" 2 | 3 | import asyncio 4 | import time 5 | from datetime import datetime 6 | from typing import List 7 | import nltk 8 | 9 | from config import Config 10 | from async_online_utils import SharedState 11 | from util import cancel, register_async_handler 12 | 13 | 14 | def judge_to_be_true(s, t): 15 | """Judge if two actions are semantically equivalent.""" 16 | try: 17 | approximation_function_name = s.split("[")[0].strip() 18 | target_function_name = t.split("[")[0].strip() 19 | 20 | approximation_function_arg = s[s.index("[") : s.index("]")].strip() 21 | target_function_arg = t[t.index("[") : t.index("]")].strip() 22 | 23 | def token_edit_levenstein_similarity_normalized( 24 | text1: str, text2: str 25 | ) -> float: 26 | """Compute the normalized levenstein distance between two texts.""" 27 | return 1 - nltk.edit_distance(text1, text2) / max(len(text1), len(text2)) 28 | 29 | if approximation_function_name == target_function_name: 30 | if ( 31 | token_edit_levenstein_similarity_normalized( 32 | approximation_function_arg, target_function_arg 33 | ) 34 | > 0.5 35 | ): 36 | return True 37 | 38 | return False 39 | except Exception: 40 | if s == t: 41 | return True 42 | else: 43 | return False 44 | 45 | 46 | def concurrent_calls(): 47 | """Get the number of concurrent API calls.""" 48 | tasks = asyncio.all_tasks() 49 | pending_tasks = [t for t in tasks if not t.done() and not t.cancelled()] 50 | return len(pending_tasks) 51 | 52 | 53 | def interaction_function(s, t, logger, collector, previous_steps, config): 54 | """Handle interaction between approximation and target agents.""" 55 | cur_time = time.time() 56 | timestamp = datetime.fromtimestamp(cur_time) 57 | source = "Target" 58 | step = t[0][0] + 
len(previous_steps) + 1 59 | desc = t[0][1][0] 60 | 61 | collector.record_step(timestamp, source, step, desc) 62 | 63 | config.TOTAL_APPROXIMATION_CALLS += 1 64 | if judge_to_be_true(s[0], t[0][1][0]) and (not t[0][1][1] or t[0][1][0].lower() == 'terminate'): 65 | config.TOTAL_CORRECT_APPROXIMATION_CALLS += 1 66 | logger.log(f"Agree: Step {step}") 67 | logger.log(f"approximation: {s[0]}") 68 | logger.log(f"target: {t[0][1][0]}") 69 | else: 70 | logger.log(f"Correcting: Step {step}") 71 | logger.log(f"approximation: {s[0]}") 72 | logger.log(f"target: {t[0][1][0]}") 73 | 74 | return desc 75 | 76 | 77 | class TravelPlannerSpeculativePlanner: 78 | def __init__(self, collector, encoding, executor, app_agent, tar_agent, config: Config): 79 | self.collector = collector 80 | self.encoding = encoding 81 | self.executor = executor 82 | self.app_agent = app_agent 83 | self.tar_agent = tar_agent 84 | self.config = config 85 | self.mismatch_state = SharedState() 86 | self.steps = [] 87 | 88 | self.logger = None 89 | self.target_logger = None 90 | self.approximation_logger = None 91 | 92 | def set_loggers(self, logger, target_logger, approximation_logger, train_logger=None): 93 | """Set loggers for the planner.""" 94 | self.logger = logger 95 | self.target_logger = target_logger 96 | self.approximation_logger = approximation_logger 97 | self.train_logger = train_logger 98 | 99 | def initialize_for_new_task(self, query): 100 | """Initialize the planner for a new task.""" 101 | # Reset steps for new task 102 | self.steps = [] 103 | 104 | # Reset mismatch state 105 | self.mismatch_state.mismatch_detected.clear() 106 | self.mismatch_state.mismatch_step_id = None 107 | 108 | # Set query for both agents 109 | self.app_agent.query = query 110 | self.tar_agent.query = query 111 | 112 | # Reset task-specific configs using config's reset method 113 | self.config.reset_task_metrics() 114 | 115 | def log_task_description(self, task_description): 116 | """Log task description to all 
loggers.""" 117 | self.logger.log('task description: ' + task_description) 118 | self.target_logger.log('task description: ' + task_description) 119 | self.approximation_logger.log('task description: ' + task_description) 120 | 121 | def log_results(self, steps): 122 | """Log final results.""" 123 | self.logger.log('final result for the speculative planning ' + str([s[0] for s in steps])) 124 | self.logger.log('max concurrent calls: ' + str(self.config.MAX_CONCURRENT_CALLS-1)) 125 | 126 | async def initialize(self, args): 127 | """Initialize the planner.""" 128 | await self.mismatch_state.initialize() 129 | 130 | async def process_target(self, sas, tas, to_print_id, prev_steps, target_tasks): 131 | """Process target agent responses and update collectors.""" 132 | s = sas[to_print_id] 133 | t = tas[to_print_id] 134 | self.target_logger.log(f'Target: Step {t[0][0] + len(prev_steps)+1} - {t[0][1][0]}') 135 | 136 | # Add to online trajectory collector 137 | cur_time = time.time() 138 | timestamp = datetime.fromtimestamp(cur_time) 139 | source = "Target" 140 | step = t[0][0] + len(prev_steps) + 1 141 | desc = t[0][1][0] 142 | self.collector.record_step(timestamp, source, step, desc) 143 | 144 | self.config.TOTAL_APPROXIMATION_CALLS += 1 145 | if judge_to_be_true(s[0], t[0][1][0]): 146 | self.config.TOTAL_CORRECT_APPROXIMATION_CALLS += 1 147 | self.logger.log(f'The target agent thinks step {len(prev_steps) + to_print_id+1} should be {t[0][1][0]}, which agrees with the approximation agent.') 148 | try: 149 | self.logger.log(f'The approximation agent thinks step {len(prev_steps) + to_print_id+2} should be {sas[to_print_id+1][0]}') 150 | register_async_handler(target_tasks=target_tasks) 151 | except Exception: 152 | pass 153 | else: 154 | self.logger.log(f'The target agent thinks step {len(prev_steps) + to_print_id+1} should be {t[0][1][0]}, correcting what the approximation agent thinks which is {s[0]}.') 155 | 156 | def process_on_time_interactions(self, sas, tas, 
flatten_tas, printed_ids, prev_steps, target_tasks): 157 | """Handle on-time interaction between approximation and target agents.""" 158 | tas_ids = [t[0] for t in flatten_tas] 159 | flatten_printed_ids = [i for ids in printed_ids if ids for i in ids] 160 | tas_length = len(tas_ids) 161 | 162 | for l in range(tas_length+1): 163 | if not(tas_ids[:l] == list(range(len(flatten_tas)))[:l] and len(sas) >= len(flatten_tas[:l])): 164 | tas_length = l-1 165 | break 166 | 167 | contain_wrong_result = any( 168 | not judge_to_be_true(sas[pid][0], tas[pid][0][1][0]) for pid in flatten_printed_ids 169 | ) 170 | 171 | if tas_length > 0 and not contain_wrong_result: 172 | tas_ids = tas_ids[:tas_length] 173 | to_print_ids = list(set(tas_ids) - set(flatten_printed_ids)) 174 | for order_id, to_print_id in enumerate(to_print_ids): 175 | if order_id > 0 and not judge_to_be_true( 176 | sas[to_print_ids[order_id-1]][0], 177 | tas[to_print_ids[order_id-1]][0][1][0] 178 | ): 179 | break 180 | printed_ids[to_print_id].append(to_print_id) 181 | t_res = interaction_function(sas[to_print_id], tas[to_print_id], self.logger, self.collector, prev_steps, self.config) 182 | if str(t_res) == str(tas[to_print_id][0][1][0]): 183 | if t_res.lower() == "terminate": 184 | return printed_ids, tas 185 | continue 186 | 187 | changed_position = list(tas[to_print_id][0]) 188 | self.tar_agent.execute(t_res) 189 | changed_position[1] = [ 190 | t_res, 191 | self.tar_agent.current_observation, 192 | ] 193 | tas[to_print_id][0] = changed_position 194 | return printed_ids, tas 195 | 196 | return printed_ids, tas 197 | 198 | def process_remaining_interactions(self, sas, tas, flatten_tas, printed_ids, prev_steps, target_tasks): 199 | """Handle remaining interactions and process corrections.""" 200 | flatten_printed_ids = [i for ids in printed_ids if ids for i in ids] 201 | 202 | print_leftover = ( 203 | len(sas) > len(flatten_printed_ids) 204 | and len(flatten_tas) > len(flatten_printed_ids) 205 | and 
all(judge_to_be_true(sas[pid][0], tas[pid][0][1][0]) for pid in flatten_printed_ids) 206 | ) 207 | 208 | contain_wrong_result = any( 209 | not judge_to_be_true(sas[pid][0], tas[pid][0][1][0]) for pid in flatten_printed_ids 210 | ) 211 | 212 | if print_leftover and not contain_wrong_result: 213 | for print_id in range(len(flatten_printed_ids), min(len(sas),len(tas))): 214 | if not tas[print_id]: 215 | break 216 | 217 | t_res = interaction_function(sas[print_id], tas[print_id], self.logger, self.collector, prev_steps, self.config) 218 | printed_ids[print_id].append(print_id) 219 | if str(t_res) == str(tas[print_id][0][1][0]): 220 | continue 221 | 222 | changed_position = list(tas[print_id][0]) 223 | self.tar_agent.execute(t_res) 224 | changed_position[1] = [ 225 | t_res, 226 | self.tar_agent.current_observation, 227 | ] 228 | tas[print_id][0] = changed_position 229 | 230 | if not judge_to_be_true(sas[print_id][0], tas[print_id][0][1][0]): 231 | break 232 | 233 | return printed_ids, tas 234 | 235 | 236 | async def approx_gen(self, total_step_number, start): 237 | """Generate approximation agent responses.""" 238 | # Get the response from the agent 239 | action, finished = await self.app_agent.direct_act(total_step_number, self.config, self.approximation_logger) 240 | 241 | # Add approximation task to online trajectory collector 242 | end = time.time() 243 | timestamp = datetime.fromtimestamp(end) 244 | source = "Approximation" 245 | step = total_step_number+1 246 | desc = action 247 | self.collector.record_step(timestamp, source, step, desc) 248 | 249 | # Record timing 250 | a_time = round(end-start, 2) 251 | 252 | app_tokens = self.config.APPROX_NORMAL_GENERATION[total_step_number] 253 | self.approximation_logger.log( 254 | f'Approximation: Step {total_step_number+1} -Action {action} -Finished {finished} -time {str(a_time)} -token {app_tokens}' 255 | ) 256 | self.config.APPROX_NROMAL_TIME[total_step_number+1] = a_time 257 | 258 | # find action 259 | if not finished: 
260 | self.app_agent.execute(action) 261 | observation = self.app_agent.current_observation 262 | else: 263 | observation = "terminate" 264 | 265 | return action, observation 266 | 267 | async def target_gen(self, args, prediction_task, prompt, total_step_number, tas, sas, target_tasks, printed_ids, current_step, prev_steps, start): 268 | """Generate target agent responses with comprehensive error handling.""" 269 | result = None 270 | finished = False 271 | 272 | n = 0 273 | while True: 274 | try: 275 | n += 1 276 | if n >= 10: 277 | result = '' 278 | break 279 | 280 | # Set the query for the agent 281 | self.tar_agent.query = prompt 282 | 283 | # Create scratchpad 284 | scratchpad = "" 285 | scratchpad = self.tar_agent.create_scratchpad(scratchpad, prev_steps + sas) 286 | 287 | # Get response from agent 288 | action, finished = await self.tar_agent.think_and_act( 289 | scratchpad, total_step_number, self.config, self.target_logger 290 | ) 291 | result = action 292 | break 293 | except Exception: 294 | await asyncio.sleep(0.1) 295 | continue 296 | 297 | tas, printed_ids = await self.verify_process( 298 | args, prediction_task, result, finished, 299 | total_step_number, tas, sas, target_tasks, 300 | printed_ids=printed_ids, current_step=current_step, 301 | prev_steps=prev_steps, 302 | start=start 303 | ) 304 | 305 | # Record timing 306 | end = time.time() 307 | t_time = round(end-start, 2) 308 | self.config.TARGET_NORMAL_TIME[total_step_number+1] = t_time 309 | 310 | self.target_logger.log( 311 | f"Intermediate Target Step {total_step_number+1} -Action {result} -Finished {finished} -gen {self.config.TARGET_NORMAL_GENERATION[total_step_number+1]} -prompt {self.config.TARGET_NORMAL_PROMPT[total_step_number+1]} -time {t_time}" 312 | ) 313 | 314 | return tas, printed_ids 315 | 316 | async def verify_process(self, args, prediction_task, result, finished, total_step_number, tas, sas, target_tasks, printed_ids=[[]], current_step=0, prev_steps=[], start=0): 317 | 
"""Post-process target generation results and handle mismatches.""" 318 | in_step_number = total_step_number - current_step 319 | tas[in_step_number] = [[in_step_number, (result, finished)]] 320 | 321 | flatten_tas = [] 322 | for t in tas: 323 | if t: 324 | flatten_tas += t 325 | flatten_tas = sorted(flatten_tas, key=lambda x: x[0],reverse=False) 326 | printed_ids, tas = self.process_on_time_interactions( 327 | sas, tas, flatten_tas, printed_ids, 328 | prev_steps, target_tasks 329 | ) 330 | 331 | flatten_ids = [t[0] for t in flatten_tas] 332 | if flatten_ids == list(range(len(flatten_ids))): 333 | for step_number, (s, t) in enumerate(zip(sas, flatten_tas)): 334 | self.target_logger.log(f"step number:{step_number}, t[0]: {t[0]}, t[1]: {t[1]}") 335 | if t[0] == step_number and t[1][1]: # terminate 336 | if not self.mismatch_state.mismatch_detected.is_set(): 337 | self.mismatch_state.mismatch_step_id = t[0] 338 | self.mismatch_state.mismatch_detected.set() 339 | end = time.time() 340 | t_time = round(end-start, 2) 341 | self.config.TARGET_NORMAL_TIME[total_step_number+1] = t_time 342 | 343 | self.target_logger.log( 344 | f"Intermediate Target Step {total_step_number+1} -Action {result} -Finished {finished} -gen {self.config.TARGET_NORMAL_GENERATION[total_step_number+1]} -prompt {self.config.TARGET_NORMAL_PROMPT[total_step_number+1]} -time {t_time}" 345 | ) 346 | 347 | raise Exception('terminate the whole process!') 348 | 349 | flatten_tas = [] 350 | for t in tas: 351 | if t: 352 | flatten_tas += t 353 | flatten_tas = sorted(flatten_tas, key=lambda x: x[0],reverse=False) 354 | for ta in flatten_tas: 355 | if len(sas) > ta[0]: 356 | if not judge_to_be_true(sas[ta[0]][0], ta[1][0]) or ta[1][1] or ta[1][0].lower() == "terminate": 357 | if not self.mismatch_state.mismatch_detected.is_set(): 358 | self.mismatch_state.mismatch_step_id = ta[0] 359 | self.mismatch_state.mismatch_detected.set() 360 | 361 | pending_approximation_tasks = [ 362 | t for t in asyncio.all_tasks() 363 
| if not t.cancelled() and not t.done() 364 | and t not in target_tasks 365 | and t.get_name().startswith('approximation') 366 | ] 367 | 368 | for pending_approximation_task in pending_approximation_tasks: 369 | await cancel(pending_approximation_task) 370 | 371 | end = time.time() 372 | t_time = round(end-start, 2) 373 | self.config.TARGET_NORMAL_TIME[total_step_number+1] = t_time 374 | self.target_logger.log( 375 | f"Intermediate Target Step {total_step_number+1} -Action {result} -Finished {finished} -gen {self.config.TARGET_NORMAL_GENERATION[total_step_number+1]} -prompt {self.config.TARGET_NORMAL_PROMPT[total_step_number+1]} -time {t_time}" 376 | ) 377 | raise Exception(f'approximation error happen in step {total_step_number} for current step {current_step}, the target id is {ta[0]}') 378 | 379 | if self.config.ENABLE_PRED: 380 | k = await prediction_task 381 | else: 382 | k = args.k 383 | 384 | if not self.mismatch_state.mismatch_detected.is_set() and ta[0] >= k-1: 385 | self.mismatch_state.mismatch_step_id = ta[0] 386 | self.mismatch_state.mismatch_detected.set() 387 | 388 | return tas, printed_ids 389 | 390 | async def one_episode_sp(self, args, app_prompt, tar_prompt, current_step): 391 | """Run one breaking point of speculative planning.""" 392 | sas = [] # approximation 393 | tas = [] # target 394 | target_tasks = [] 395 | printed_ids = [] 396 | 397 | self.mismatch_state.mismatch_detected.clear() 398 | pred_k = 1 if self.config.ENABLE_PRED else args.k 399 | first = True 400 | 401 | i = 0 402 | prediction_task = None 403 | while i < pred_k: 404 | if self.mismatch_state.mismatch_detected.is_set(): 405 | break 406 | 407 | if self.config.ENABLE_PRED and first: 408 | prediction_task = asyncio.create_task( 409 | self.executor.async_predict(self.approximation_logger, args.offset) 410 | ) 411 | self.config.PENDING_BACKGROUND_TASKS.append(prediction_task) 412 | 413 | break_out_approximation = False 414 | tas.append([]) 415 | printed_ids.append([]) 416 | 417 | # 
Generate approximation 418 | a_start = time.time() 419 | approximation = asyncio.create_task( 420 | self.approx_gen( 421 | current_step+i, a_start 422 | ), 423 | name=f"approximation_{current_step+i}" 424 | ) 425 | 426 | # Generate target 427 | t_start = time.time() 428 | target = asyncio.create_task( 429 | self.target_gen( 430 | args, prediction_task, tar_prompt, current_step+i, tas, sas, 431 | target_tasks, printed_ids, current_step, 432 | self.steps, t_start 433 | ), 434 | name=f"target_{current_step+i}" 435 | ) 436 | target_tasks.append(target) 437 | 438 | # Track concurrent calls 439 | concurrent_api_calls = concurrent_calls() 440 | if concurrent_api_calls >= self.config.MAX_CONCURRENT_CALLS: 441 | self.config.MAX_CONCURRENT_CALLS = concurrent_api_calls 442 | 443 | try: 444 | action, observation = await approximation 445 | sa = [action, observation] 446 | sas.append(sa) 447 | if self.mismatch_state.mismatch_detected.is_set(): 448 | break 449 | except asyncio.CancelledError: 450 | pass 451 | 452 | # Check for termination conditions 453 | if observation is True or action.lower == 'terminate': # terminate 454 | break_out_approximation = True 455 | 456 | flatten_tas = [] 457 | for t in tas: 458 | if t: 459 | flatten_tas += t 460 | flatten_tas = sorted(flatten_tas, key=lambda x: x[0], reverse=False) 461 | 462 | for t in flatten_tas: 463 | if len(sas) > t[0]: 464 | if not judge_to_be_true(sas[t[0]][0], t[1][0]) or t[1][1] or t[1][0].lower() == "terminate": 465 | break_out_approximation = True 466 | self.mismatch_state.mismatch_step_id = t[0] 467 | self.mismatch_state.mismatch_detected.set() 468 | 469 | for process_id, one_task in enumerate(target_tasks): 470 | if not one_task.cancelled() and not one_task.done() and process_id > t[0]: 471 | self.target_logger.log(f'Cancel Task {len(self.steps)+process_id+1}') 472 | await cancel(one_task) 473 | break 474 | 475 | if break_out_approximation: 476 | break 477 | 478 | if self.config.ENABLE_PRED and first: 479 | pred_k = 
await prediction_task 480 | self.config.PREDICT_K.append(pred_k) 481 | self.config.PREDICT_TOTAL += 1 482 | pred_k = max(pred_k, 0) 483 | first = False 484 | i += 1 485 | 486 | # Process results after breaking point 487 | await self.mismatch_state.mismatch_detected.wait() 488 | self.target_logger.log(f"Breaking point stops at step {len(self.steps)+1+self.mismatch_state.mismatch_step_id}. Start Cancellation.") 489 | 490 | for process_id, one_task in enumerate(target_tasks): 491 | if not one_task.cancelled() and not one_task.done() and process_id > self.mismatch_state.mismatch_step_id: 492 | self.target_logger.log(f'Cancel task {len(self.steps)+process_id+1}') 493 | await cancel(one_task) 494 | 495 | pending_tasks = [t for t in target_tasks if not t.cancelled()] 496 | 497 | while pending_tasks: 498 | break_while_loop = False 499 | try: 500 | if [pending_task.done() for pending_task in pending_tasks] == [True] * len( 501 | pending_tasks 502 | ): 503 | break_while_loop = True 504 | try: 505 | await asyncio.gather(*pending_tasks, return_exceptions=False) 506 | break 507 | except Exception: 508 | break 509 | # should not await cancelled tasks 510 | # return_exceptions=False is also the default value 511 | await asyncio.gather(*pending_tasks, return_exceptions=False) 512 | break_while_loop = True 513 | break 514 | except Exception as e: 515 | if str(e) == "terminate the whole process!": 516 | # cancel all pending tasks because we have already got TERMINATE 517 | for process_id, one_task in enumerate(pending_tasks): 518 | if not one_task.cancelled() and not one_task.done() and process_id > self.mismatch_state.mismatch_step_id: 519 | self.target_logger.log(f"cancel task at line 552 {len(self.steps)+process_id+1}") 520 | await cancel(one_task) 521 | # organize the results and return the final results 522 | flatten_tas = [] 523 | for t in tas: 524 | if t: 525 | flatten_tas += t 526 | flatten_tas = sorted(flatten_tas, key=lambda x: x[0], reverse=False) 527 | printed_ids, tas = 
self.process_remaining_interactions( 528 | sas, tas, flatten_tas, printed_ids, 529 | self.steps, target_tasks 530 | ) 531 | 532 | # get the final tas result 533 | flatten_tas = [] 534 | for t in tas: 535 | if t: 536 | flatten_tas += t 537 | flatten_tas = sorted(flatten_tas, key=lambda x: x[0], reverse=False) 538 | for step_number, (s, t) in enumerate(zip(sas, flatten_tas)): 539 | if t[0] == step_number and not judge_to_be_true(s[0], t[1][0]): 540 | self.tar_agent.execute(t[1][0]) 541 | to_replace_action = [ 542 | flatten_tas[step_number][1][0], 543 | self.tar_agent.current_observation, 544 | ] 545 | sas = sas[:step_number] + [to_replace_action] 546 | self.app_agent.update_scratchpad( 547 | sas[-1][0], sas[-1][1], len(self.steps) + step_number 548 | ) 549 | self.config.PREDICT_CORRECT += self.collector.build_trajectory(self.target_logger, self.config.PREDICT_K) 550 | self.config.PREDICT_K = [] 551 | if self.config.ENABLE_TRAIN: 552 | if self.config.BUILD_TRAJ_TIMES == 0: 553 | train_task = asyncio.create_task(self.executor.async_train(self.train_logger)) 554 | self.config.PENDING_BACKGROUND_TASKS.append(train_task) 555 | self.config.BUILD_TRAJ_TIMES = (self.config.BUILD_TRAJ_TIMES + 1) % self.config.TRAIN_INTERVAL 556 | break 557 | return sas 558 | else: 559 | # cancel t_j for j > i if t_i != s_i 560 | if [pending_task.done() for pending_task in pending_tasks] == [ 561 | True 562 | ] * len(pending_tasks): 563 | break_while_loop = True 564 | try: 565 | await asyncio.gather(*pending_tasks, return_exceptions=False) 566 | break 567 | except Exception: 568 | break 569 | if break_while_loop: 570 | break 571 | mistaken_process_id = int(str(e)[-1]) 572 | 573 | for process_id, one_task in enumerate(pending_tasks): 574 | if ( 575 | not one_task.cancelled() 576 | and not one_task.done() 577 | and process_id > mistaken_process_id 578 | ): 579 | self.target_logger.log(f"cancel task at line 608 {len(self.steps)+process_id+1}") 580 | await cancel(one_task) 581 | 582 | 
pending_tasks = [ 583 | t for process_id, t in enumerate(target_tasks) 584 | if not t.cancelled() and process_id != mistaken_process_id 585 | ] 586 | 587 | if break_while_loop: 588 | break 589 | 590 | # Process leftover interactions 591 | flatten_tas = sorted([item for t in tas if t for item in t], key=lambda x: x[0], reverse=False) 592 | printed_ids, tas = self.process_remaining_interactions( 593 | sas, tas, flatten_tas, printed_ids, 594 | self.steps, target_tasks 595 | ) 596 | 597 | flatten_tas = sorted([item for t in tas if t for item in t], key=lambda x: x[0], reverse=False) 598 | 599 | # self.target_logger.log(f"sas tasks: {[(sa[0], False if sa[1] is not True else True) for sa in sas]}") 600 | # self.target_logger.log(f"flatten_tas: {flatten_tas}") 601 | 602 | all_match = True 603 | for step_number, (s, t) in enumerate(zip(sas, flatten_tas)): 604 | if (len(self.steps)+step_number+1 >= self.config.MAX_STEP) or (t[0] == step_number and (not judge_to_be_true(s[0], t[1][0]) or (t[1][0].lower() == "terminate" or t[1][1]==True))): # t[1] != s: 605 | all_match = False 606 | self.config.PREDICT_CORRECT += self.collector.build_trajectory(self.target_logger, self.config.PREDICT_K) 607 | self.config.PREDICT_K = [] 608 | 609 | if self.config.ENABLE_TRAIN: 610 | if self.config.BUILD_TRAJ_TIMES == 0: 611 | train_task = asyncio.create_task(self.executor.async_train(self.train_logger)) 612 | self.config.PENDING_BACKGROUND_TASKS.append(train_task) 613 | self.config.BUILD_TRAJ_TIMES = (self.config.BUILD_TRAJ_TIMES + 1) % self.config.TRAIN_INTERVAL 614 | 615 | t1 = time.time() 616 | self.tar_agent.execute(t[1][0]) 617 | t2 = time.time() 618 | self.config.TARGET_NORMAL_TIME[len(self.steps)+step_number+1] += round(t2-t1, 2) 619 | to_replace_action = [ 620 | flatten_tas[step_number][1][0], 621 | self.tar_agent.current_observation, 622 | ] 623 | sas = sas[:step_number] + [to_replace_action] 624 | self.app_agent.update_scratchpad( 625 | sas[-1][0], sas[-1][1], len(self.steps) + 
step_number 626 | ) 627 | 628 | break 629 | 630 | if all_match: 631 | sas = sas[:len(flatten_tas)] 632 | t2 = time.time() 633 | return sas 634 | 635 | async def run(self, args, prompt: str) -> List[str]: 636 | """Run the speculative planning process.""" 637 | # Initialize agents and state 638 | await self.initialize(args) 639 | 640 | begin_time = datetime.now() 641 | steps = [] 642 | breaking_points = 0 643 | i = 0 644 | 645 | app_prompt = prompt 646 | tar_prompt = prompt 647 | 648 | while True: 649 | result = await self.one_episode_sp( 650 | args, app_prompt, tar_prompt, len(steps) 651 | ) 652 | 653 | steps += result 654 | self.steps = steps # Update self.steps to match original behavior 655 | breaking_points += 1 656 | i += len(result) 657 | 658 | # if the last action is terminate, we break the generation process 659 | if result[-1][0].lower() == "terminate" or result[-1][1] is True or len(steps) >= self.config.MAX_STEP: 660 | break 661 | 662 | end_time = datetime.now() 663 | self.logger.log(f"{end_time} - {begin_time} = {end_time - begin_time}") 664 | self.config.TOTAL_SP_TIME = round((end_time - begin_time).total_seconds(), 2) 665 | 666 | return steps --------------------------------------------------------------------------------